mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Use defaultdict (#307)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
This commit is contained in:
		| @@ -1,5 +1,7 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| from collections import defaultdict | ||||
|  | ||||
|  | ||||
| def tree_map(fn, tree, *rest, is_leaf=None): | ||||
|     """Applies ``fn`` to the leaves of the python tree ``tree`` and | ||||
| @@ -128,12 +130,10 @@ def tree_unflatten(tree): | ||||
|         is_list = False | ||||
|  | ||||
|     # collect children | ||||
|     children = {} | ||||
|     children = defaultdict(list) | ||||
|     for key, value in tree: | ||||
|         current_idx, *next_idx = key.split(".", maxsplit=1) | ||||
|         next_idx = "" if not next_idx else next_idx[0] | ||||
|         if current_idx not in children: | ||||
|             children[current_idx] = [] | ||||
|         children[current_idx].append((next_idx, value)) | ||||
|  | ||||
|     # recursively map them to the original container | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Chunyang Wen
					Chunyang Wen