Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
Merge branch 'ml-explore:main' into adding-Muon-optimizer
@@ -373,7 +373,7 @@ def smooth_l1_loss(
             f"targets shape {targets.shape}."
         )

-    diff = predictions - targets
+    diff = mx.abs(predictions - targets)
     loss = mx.where(
         diff < beta, 0.5 * mx.square(diff) / beta, mx.abs(diff) - 0.5 * beta
     )
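Editor's note: the change above fixes branch selection in Smooth L1. `mx.where` compares `diff` against `beta`, so `diff` must already be the absolute error; with a signed difference, every negative error satisfies `diff < beta` and is scored by the quadratic branch regardless of magnitude. A minimal sketch of the effect (the helper names `smooth_l1_old` / `smooth_l1_new` are illustrative, not library code):

import mlx.core as mx

def smooth_l1_old(predictions, targets, beta=1.0):
    diff = predictions - targets          # signed difference (buggy)
    return mx.where(diff < beta, 0.5 * mx.square(diff) / beta, mx.abs(diff) - 0.5 * beta)

def smooth_l1_new(predictions, targets, beta=1.0):
    diff = mx.abs(predictions - targets)  # absolute difference (fixed)
    return mx.where(diff < beta, 0.5 * mx.square(diff) / beta, mx.abs(diff) - 0.5 * beta)

p, t = mx.array([0.0]), mx.array([5.0])   # error of -5 before taking the absolute value
print(smooth_l1_old(p, t))                # ~12.5: quadratic branch misapplied to a large error
print(smooth_l1_new(p, t))                # ~4.5:  linear branch, as Smooth L1 intends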
@@ -50,19 +50,19 @@ class Optimizer:
         dict_keys(['step', 'learning_rate', 'weight', 'bias'])
         """

-        # Iniatilize the optimizer state to match the parameter state
+        # Initialize the optimizer state to match the parameter state
         def update_state(params, state):
             if isinstance(params, (list, tuple)):
                 state = list(state)
                 for i in range(len(state)):
                     state[i] = update_state(params[i], state[i])
                 if len(state) != len(params):
-                    state.extend(tree_map(lambda x: {}, params[len(state) :]))
+                    state.extend(tree_map(lambda _: {}, params[len(state) :]))
                 return type(params)(state)
             elif isinstance(params, dict):
                 for k, v in params.items():
                     if k not in state:
-                        state[k] = tree_map(lambda x: {}, v)
+                        state[k] = tree_map(lambda _: {}, v)
                     else:
                         state[k] = update_state(v, state[k])
                 return state
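Editor's note: the recursive `update_state` above mirrors the parameter tree with an (initially empty) state dict per leaf; renaming `lambda x: {}` to `lambda _: {}` is purely cosmetic and signals that the leaf value is unused. A small sketch of what that mapping produces, using a made-up parameter tree:

from mlx.utils import tree_map

# Hypothetical nested parameters, only to show the shape of the mirrored state.
params = {"layers": [{"weight": 1.0}, {"bias": 2.0}]}
empty_state = tree_map(lambda _: {}, params)
print(empty_state)   # {'layers': [{'weight': {}}, {'bias': {}}]}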
@@ -79,6 +79,7 @@ class Optimizer:

         Args:
             parameter (mx.array): A single parameter that will be optimized.
+            state (dict): The optimizer's state.
         """
         raise NotImplementedError()

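Editor's note: the docstring completed here documents the per-parameter state-initialization hook that concrete optimizers override, alongside a matching per-parameter update hook. A hedged sketch of how a subclass might use these hooks, assuming the `init_single` / `apply_single` names and the private `_maybe_schedule` helper whose body appears in the next hunk; this is illustrative only and is not the Muon optimizer from this branch:

import mlx.core as mx
from mlx.optimizers import Optimizer

class PlainSGD(Optimizer):
    """Toy optimizer, used only to illustrate the documented hooks."""

    def __init__(self, learning_rate: float = 1e-2):
        super().__init__()
        self._maybe_schedule("learning_rate", learning_rate)

    def init_single(self, parameter: mx.array, state: dict):
        # Plain SGD keeps no per-parameter state, so there is nothing to initialize.
        pass

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        # Return the updated parameter; the base class walks the parameter tree.
        return parameter - self.learning_rate.astype(gradient.dtype) * gradient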
@@ -148,10 +149,10 @@ class Optimizer:
         """
         if isinstance(param, Callable):
             self._schedulers[name] = param
-            param = param(self.step)
+            parameter = param(self.step)
         else:
-            param = mx.array(param)
-        self.state[name] = param
+            parameter = mx.array(param)
+        self.state[name] = parameter


 class MultiOptimizer(Optimizer):
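Editor's note: the rename in this hunk keeps the incoming `param` argument distinct from the value written into the optimizer state. Either a plain number or a schedule (a callable of the step counter) can be passed, and the current value is stored under the hyperparameter's name. A brief hedged illustration:

import mlx.optimizers as optim

# A fixed hyperparameter is wrapped into an mx.array...
opt_fixed = optim.SGD(learning_rate=1e-2)
print(opt_fixed.learning_rate)                 # array(0.01, dtype=float32)

# ...while a schedule is evaluated at the current step.
opt_sched = optim.SGD(learning_rate=optim.cosine_decay(1e-2, 1000))
print(opt_sched.learning_rate)                 # the schedule's value at step 0, i.e. 0.01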
@@ -80,12 +80,12 @@ def cosine_decay(init: float, decay_steps: int, end: float = 0.0) -> Callable:
     array(0.0999961, dtype=float32)
     """

-    def scheduler(step):
+    def schedule(step):
         s = mx.minimum(step, decay_steps)
         decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))
         return end + decay * (init - end)

-    return scheduler
+    return schedule


 def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable:
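Editor's note: as a sanity check on the formula inside the renamed closure, here is the same arithmetic in plain Python (an editorial illustration, not part of the diff). The schedule starts at `init`, reaches the midpoint of `init` and `end` halfway through, and is clamped at `end` once `decay_steps` is exceeded:

import math

def cosine_decay_value(step, init=1e-1, decay_steps=200, end=0.0):
    s = min(step, decay_steps)
    decay = 0.5 * (1.0 + math.cos((math.pi / decay_steps) * s))
    return end + decay * (init - end)

print(cosine_decay_value(0))     # 0.1   -> starts at init
print(cosine_decay_value(100))   # 0.05  -> halfway: (init + end) / 2
print(cosine_decay_value(200))   # ~0.0  -> reaches end at decay_steps
print(cosine_decay_value(500))   # ~0.0  -> step is clamped to decay_steps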
@@ -99,9 +99,9 @@ def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable
             that indicates when to transition between schedules.

     Example:
-        >>> warmup = optim.linear_schedule(0, 1e-1, steps=10)
+        >>> linear = optim.linear_schedule(0, 1e-1, steps=10)
         >>> cosine = optim.cosine_decay(1e-1, 200)
-        >>> lr_schedule = optim.join_schedules([warmup, cosine], [10])
+        >>> lr_schedule = optim.join_schedules([linear, cosine], [10])
         >>> optimizer = optim.Adam(learning_rate=lr_schedule)
         >>> optimizer.learning_rate
         array(0.0, dtype=float32)
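Editor's note: continuing the docstring example above, the joined schedule is consulted on every optimizer update, so after the boundary at step 10 the learning rate leaves the linear ramp and follows the cosine decay. A hedged sketch (the `update({}, {})` trick to advance the step counter is taken from the linear_schedule docstring further down):

import mlx.optimizers as optim

linear = optim.linear_schedule(0, 1e-1, steps=10)
cosine = optim.cosine_decay(1e-1, 200)
lr_schedule = optim.join_schedules([linear, cosine], [10])
optimizer = optim.Adam(learning_rate=lr_schedule)
for _ in range(12):
    optimizer.update({}, {})       # no real params/grads needed just to advance the step
print(optimizer.learning_rate)     # slightly below 0.1, now on the cosine segment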
@@ -139,8 +139,8 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:

     Example:

-        >>> warmup = optim.linear_schedule(0, 1e-1, 100)
-        >>> optimizer = optim.Adam(learning_rate=warmup)
+        >>> lr_schedule = optim.linear_schedule(0, 1e-1, 100)
+        >>> optimizer = optim.Adam(learning_rate=lr_schedule)
         >>> optimizer.learning_rate
         array(0.0, dtype=float32)
         >>> for _ in range(101): optimizer.update({}, {})
@@ -151,8 +151,8 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:
     if steps < 1:
         raise ValueError(f"steps must be greater than 0, but got {steps}.")

-    def step_fn(step):
+    def schedule(step):
         step = mx.minimum(step, steps)
         return step * ((end - init) / steps) + init

-    return step_fn
+    return schedule
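Editor's note: a plain-Python check of the renamed linear closure (editorial illustration). It starts at `init`, grows linearly, and clamps once the step exceeds `steps`, which is why the docstring example above would sit at the end value (0.1) after 101 updates:

def linear_schedule_value(step, init=0.0, end=1e-1, steps=100):
    # Mirrors the closure above: clamp the step, then interpolate from init to end.
    step = min(step, steps)
    return step * ((end - init) / steps) + init

print(linear_schedule_value(0))     # 0.0   -> starts at init
print(linear_schedule_value(50))    # 0.05  -> halfway between init and end
print(linear_schedule_value(101))   # 0.1   -> clamped at end once step exceeds steps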