Fix typo; Fix lint warning when reuse the same name (#1968)

* Fix typo; Fix lint warning when reuse the same name

* Add missing period
This commit is contained in:
Chunyang Wen 2025-03-18 22:12:24 +08:00 committed by GitHub
parent c6ea2ba329
commit 45ad06aac8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -50,19 +50,19 @@ class Optimizer:
dict_keys(['step', 'learning_rate', 'weight', 'bias'])
"""
# Iniatilize the optimizer state to match the parameter state
# Initialize the optimizer state to match the parameter state
def update_state(params, state):
if isinstance(params, (list, tuple)):
state = list(state)
for i in range(len(state)):
state[i] = update_state(params[i], state[i])
if len(state) != len(params):
state.extend(tree_map(lambda x: {}, params[len(state) :]))
state.extend(tree_map(lambda _: {}, params[len(state) :]))
return type(params)(state)
elif isinstance(params, dict):
for k, v in params.items():
if k not in state:
state[k] = tree_map(lambda x: {}, v)
state[k] = tree_map(lambda _: {}, v)
else:
state[k] = update_state(v, state[k])
return state
@ -79,6 +79,7 @@ class Optimizer:
Args:
parameter (mx.array): A single parameter that will be optimized.
state (dict): The optimizer's state.
"""
raise NotImplementedError()
@ -148,10 +149,10 @@ class Optimizer:
"""
if isinstance(param, Callable):
self._schedulers[name] = param
param = param(self.step)
parameter = param(self.step)
else:
param = mx.array(param)
self.state[name] = param
parameter = mx.array(param)
self.state[name] = parameter
class MultiOptimizer(Optimizer):