mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
Fix typo; Fix lint warning when reuse the same name (#1968)
* Fix typo; Fix lint warning when reuse the same name * Add missing period
This commit is contained in:
parent
c6ea2ba329
commit
45ad06aac8
@ -50,19 +50,19 @@ class Optimizer:
|
||||
dict_keys(['step', 'learning_rate', 'weight', 'bias'])
|
||||
"""
|
||||
|
||||
# Iniatilize the optimizer state to match the parameter state
|
||||
# Initialize the optimizer state to match the parameter state
|
||||
def update_state(params, state):
|
||||
if isinstance(params, (list, tuple)):
|
||||
state = list(state)
|
||||
for i in range(len(state)):
|
||||
state[i] = update_state(params[i], state[i])
|
||||
if len(state) != len(params):
|
||||
state.extend(tree_map(lambda x: {}, params[len(state) :]))
|
||||
state.extend(tree_map(lambda _: {}, params[len(state) :]))
|
||||
return type(params)(state)
|
||||
elif isinstance(params, dict):
|
||||
for k, v in params.items():
|
||||
if k not in state:
|
||||
state[k] = tree_map(lambda x: {}, v)
|
||||
state[k] = tree_map(lambda _: {}, v)
|
||||
else:
|
||||
state[k] = update_state(v, state[k])
|
||||
return state
|
||||
@ -79,6 +79,7 @@ class Optimizer:
|
||||
|
||||
Args:
|
||||
parameter (mx.array): A single parameter that will be optimized.
|
||||
state (dict): The optimizer's state.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -148,10 +149,10 @@ class Optimizer:
|
||||
"""
|
||||
if isinstance(param, Callable):
|
||||
self._schedulers[name] = param
|
||||
param = param(self.step)
|
||||
parameter = param(self.step)
|
||||
else:
|
||||
param = mx.array(param)
|
||||
self.state[name] = param
|
||||
parameter = mx.array(param)
|
||||
self.state[name] = parameter
|
||||
|
||||
|
||||
class MultiOptimizer(Optimizer):
|
||||
|
Loading…
Reference in New Issue
Block a user