mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Minor refactor for tree_map and tree_unflatten (#311)
* Minor refact for tree_map and tree_unflatten * Remove the if statement --------- Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
This commit is contained in:
		| @@ -39,13 +39,9 @@ def tree_map(fn, tree, *rest, is_leaf=None): | ||||
|     """ | ||||
|     if is_leaf is not None and is_leaf(tree): | ||||
|         return fn(tree, *rest) | ||||
|     elif isinstance(tree, list): | ||||
|         return [ | ||||
|             tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf) | ||||
|             for i, child in enumerate(tree) | ||||
|         ] | ||||
|     elif isinstance(tree, tuple): | ||||
|         return tuple( | ||||
|     elif isinstance(tree, (list, tuple)): | ||||
|         TreeType = type(tree) | ||||
|         return TreeType( | ||||
|             tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf) | ||||
|             for i, child in enumerate(tree) | ||||
|         ) | ||||
| @@ -141,8 +137,8 @@ def tree_unflatten(tree): | ||||
|         keys = sorted((int(idx), idx) for idx in children.keys()) | ||||
|         l = [] | ||||
|         for i, k in keys: | ||||
|             while i > len(l): | ||||
|                 l.append({}) | ||||
|             # if i <= len(l), no {} will be appended. | ||||
|             l.extend([{} for _ in range(i - len(l))]) | ||||
|             l.append(tree_unflatten(children[k])) | ||||
|         return l | ||||
|     else: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Chunyang Wen
					Chunyang Wen