Fix typo; Fix lint warning when reusing the same name (#1968)

* Fix typo; Fix lint warning when reusing the same name
* Add missing period
@@ -50,19 +50,19 @@ class Optimizer:
             dict_keys(['step', 'learning_rate', 'weight', 'bias'])
         """
 
-        # Iniatilize the optimizer state to match the parameter state
+        # Initialize the optimizer state to match the parameter state
         def update_state(params, state):
             if isinstance(params, (list, tuple)):
                 state = list(state)
                 for i in range(len(state)):
                     state[i] = update_state(params[i], state[i])
                 if len(state) != len(params):
-                    state.extend(tree_map(lambda x: {}, params[len(state) :]))
+                    state.extend(tree_map(lambda _: {}, params[len(state) :]))
                 return type(params)(state)
             elif isinstance(params, dict):
                 for k, v in params.items():
                     if k not in state:
-                        state[k] = tree_map(lambda x: {}, v)
+                        state[k] = tree_map(lambda _: {}, v)
                     else:
                         state[k] = update_state(v, state[k])
                 return state
@@ -79,6 +79,7 @@ class Optimizer:
 
         Args:
             parameter (mx.array): A single parameter that will be optimized.
+            state (dict): The optimizer's state.
         """
         raise NotImplementedError()
 
@@ -148,10 +149,10 @@ class Optimizer:
         """
         if isinstance(param, Callable):
             self._schedulers[name] = param
-            param = param(self.step)
+            parameter = param(self.step)
         else:
-            param = mx.array(param)
+            parameter = mx.array(param)
-        self.state[name] = param
+        self.state[name] = parameter
 
 
 class MultiOptimizer(Optimizer):
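For context on the last hunk: reusing the name `param` for both the schedule callable and its evaluated value is what triggered the lint warning, so the evaluated value now gets its own name, `parameter`. The underlying pattern is that an optimizer hyperparameter can be passed either as a plain scalar or as a step-dependent schedule. Below is a minimal sketch of that pattern, assuming only that `mlx.core` is installed; the `resolve_hyperparameter` helper and the decay schedule are illustrative, not part of the MLX API:

    import mlx.core as mx

    def resolve_hyperparameter(param, step):
        # Mirrors the Callable branch above: a schedule is evaluated at the
        # current step, while a plain scalar is wrapped in an mx.array.
        if callable(param):
            return param(step)
        return mx.array(param)

    # A constant learning rate and a hypothetical exponential-decay schedule.
    constant_lr = resolve_hyperparameter(1e-3, step=0)
    scheduled_lr = resolve_hyperparameter(lambda step: mx.array(1e-3 * 0.9**step), step=2)

With this split, the value stored in `self.state[name]` is always an array, while the original callable stays registered in `self._schedulers` so it can be re-evaluated at each optimization step.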
Author: Chunyang Wen