Fix typo; fix lint warning when reusing the same name (#1968)

* Fix typo; fix lint warning when reusing the same name

* Add missing period
Chunyang Wen 2025-03-18 22:12:24 +08:00 committed by GitHub
parent c6ea2ba329
commit 45ad06aac8

@@ -50,19 +50,19 @@ class Optimizer:
             dict_keys(['step', 'learning_rate', 'weight', 'bias'])
         """

-        # Iniatilize the optimizer state to match the parameter state
+        # Initialize the optimizer state to match the parameter state
         def update_state(params, state):
             if isinstance(params, (list, tuple)):
                 state = list(state)
                 for i in range(len(state)):
                     state[i] = update_state(params[i], state[i])
                 if len(state) != len(params):
-                    state.extend(tree_map(lambda x: {}, params[len(state) :]))
+                    state.extend(tree_map(lambda _: {}, params[len(state) :]))
                 return type(params)(state)
             elif isinstance(params, dict):
                 for k, v in params.items():
                     if k not in state:
-                        state[k] = tree_map(lambda x: {}, v)
+                        state[k] = tree_map(lambda _: {}, v)
                     else:
                         state[k] = update_state(v, state[k])
                 return state
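For context, a minimal sketch of what the tree_map(lambda _: {}, ...) call seeds (the parameter shapes here are illustrative):

import mlx.core as mx
from mlx.utils import tree_map

params = {"weight": mx.zeros((4, 4)), "bias": mx.zeros((4,))}

# Every leaf is replaced by an empty dict, producing a state tree with
# the same structure as the parameters; naming the unused lambda
# argument `_` is what silences the lint warning.
state = tree_map(lambda _: {}, params)
print(state)  # {'weight': {}, 'bias': {}}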
@@ -79,6 +79,7 @@ class Optimizer:

         Args:
             parameter (mx.array): A single parameter that will be optimized.
+            state (dict): The optimizer's state.
         """
         raise NotImplementedError()

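The new Args entry documents the per-parameter state dict that subclasses receive. A minimal sketch of a subclass implementing these hooks, assuming MLX's usual init_single/apply_single pattern (PlainSGD is a hypothetical name, not part of MLX):

from mlx.optimizers import Optimizer

class PlainSGD(Optimizer):
    # Hypothetical subclass, for illustration only.

    def __init__(self, learning_rate):
        super().__init__()
        self.learning_rate = learning_rate  # scalar or schedule callable

    def init_single(self, parameter, state):
        # Plain SGD keeps no per-parameter values, so the state dict
        # documented by this commit stays empty.
        pass

    def apply_single(self, gradient, parameter, state):
        # Read the (possibly scheduled) learning rate and return the
        # updated parameter.
        return parameter - self.learning_rate * gradient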
@@ -148,10 +149,10 @@ class Optimizer:
         """
         if isinstance(param, Callable):
             self._schedulers[name] = param
-            param = param(self.step)
+            parameter = param(self.step)
         else:
-            param = mx.array(param)
-        self.state[name] = param
+            parameter = mx.array(param)
+        self.state[name] = parameter


 class MultiOptimizer(Optimizer):
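The rename avoids rebinding the param argument, which is what triggered the lint warning; behavior is unchanged. A minimal sketch of the two forms this branch handles, assuming the public SGD and exponential_decay APIs:

import mlx.optimizers as optim

# A callable is registered as a scheduler and evaluated at the current
# step, so the stored value tracks self.step.
opt = optim.SGD(learning_rate=optim.exponential_decay(1e-1, 0.9))

# A plain scalar is simply wrapped in an mx.array.
opt_const = optim.SGD(learning_rate=1e-1)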