mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 23:51:14 +08:00
Fix typo; Fix lint warning when reuse the same name (#1968)
* Fix typo; Fix lint warning when reuse the same name * Add missing period
This commit is contained in:
parent
c6ea2ba329
commit
45ad06aac8
@ -50,19 +50,19 @@ class Optimizer:
|
|||||||
dict_keys(['step', 'learning_rate', 'weight', 'bias'])
|
dict_keys(['step', 'learning_rate', 'weight', 'bias'])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Iniatilize the optimizer state to match the parameter state
|
# Initialize the optimizer state to match the parameter state
|
||||||
def update_state(params, state):
|
def update_state(params, state):
|
||||||
if isinstance(params, (list, tuple)):
|
if isinstance(params, (list, tuple)):
|
||||||
state = list(state)
|
state = list(state)
|
||||||
for i in range(len(state)):
|
for i in range(len(state)):
|
||||||
state[i] = update_state(params[i], state[i])
|
state[i] = update_state(params[i], state[i])
|
||||||
if len(state) != len(params):
|
if len(state) != len(params):
|
||||||
state.extend(tree_map(lambda x: {}, params[len(state) :]))
|
state.extend(tree_map(lambda _: {}, params[len(state) :]))
|
||||||
return type(params)(state)
|
return type(params)(state)
|
||||||
elif isinstance(params, dict):
|
elif isinstance(params, dict):
|
||||||
for k, v in params.items():
|
for k, v in params.items():
|
||||||
if k not in state:
|
if k not in state:
|
||||||
state[k] = tree_map(lambda x: {}, v)
|
state[k] = tree_map(lambda _: {}, v)
|
||||||
else:
|
else:
|
||||||
state[k] = update_state(v, state[k])
|
state[k] = update_state(v, state[k])
|
||||||
return state
|
return state
|
||||||
@ -79,6 +79,7 @@ class Optimizer:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
parameter (mx.array): A single parameter that will be optimized.
|
parameter (mx.array): A single parameter that will be optimized.
|
||||||
|
state (dict): The optimizer's state.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -148,10 +149,10 @@ class Optimizer:
|
|||||||
"""
|
"""
|
||||||
if isinstance(param, Callable):
|
if isinstance(param, Callable):
|
||||||
self._schedulers[name] = param
|
self._schedulers[name] = param
|
||||||
param = param(self.step)
|
parameter = param(self.step)
|
||||||
else:
|
else:
|
||||||
param = mx.array(param)
|
parameter = mx.array(param)
|
||||||
self.state[name] = param
|
self.state[name] = parameter
|
||||||
|
|
||||||
|
|
||||||
class MultiOptimizer(Optimizer):
|
class MultiOptimizer(Optimizer):
|
||||||
|
Loading…
Reference in New Issue
Block a user