Merge branch 'ml-explore:main' into adding-Muon-optimizer

This commit is contained in:
Gökdeniz Gülmez
2025-03-21 08:50:43 +01:00
committed by GitHub
84 changed files with 901 additions and 484 deletions

View File

@@ -373,7 +373,7 @@ def smooth_l1_loss(
f"targets shape {targets.shape}."
)
diff = predictions - targets
diff = mx.abs(predictions - targets)
loss = mx.where(
diff < beta, 0.5 * mx.square(diff) / beta, mx.abs(diff) - 0.5 * beta
)

View File

@@ -50,19 +50,19 @@ class Optimizer:
dict_keys(['step', 'learning_rate', 'weight', 'bias'])
"""
# Iniatilize the optimizer state to match the parameter state
# Initialize the optimizer state to match the parameter state
def update_state(params, state):
if isinstance(params, (list, tuple)):
state = list(state)
for i in range(len(state)):
state[i] = update_state(params[i], state[i])
if len(state) != len(params):
state.extend(tree_map(lambda x: {}, params[len(state) :]))
state.extend(tree_map(lambda _: {}, params[len(state) :]))
return type(params)(state)
elif isinstance(params, dict):
for k, v in params.items():
if k not in state:
state[k] = tree_map(lambda x: {}, v)
state[k] = tree_map(lambda _: {}, v)
else:
state[k] = update_state(v, state[k])
return state
@@ -79,6 +79,7 @@ class Optimizer:
Args:
parameter (mx.array): A single parameter that will be optimized.
state (dict): The optimizer's state.
"""
raise NotImplementedError()
@@ -148,10 +149,10 @@ class Optimizer:
"""
if isinstance(param, Callable):
self._schedulers[name] = param
param = param(self.step)
parameter = param(self.step)
else:
param = mx.array(param)
self.state[name] = param
parameter = mx.array(param)
self.state[name] = parameter
class MultiOptimizer(Optimizer):

View File

@@ -80,12 +80,12 @@ def cosine_decay(init: float, decay_steps: int, end: float = 0.0) -> Callable:
array(0.0999961, dtype=float32)
"""
def scheduler(step):
def schedule(step):
s = mx.minimum(step, decay_steps)
decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))
return end + decay * (init - end)
return scheduler
return schedule
def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable:
@@ -99,9 +99,9 @@ def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable
that indicates when to transition between schedules.
Example:
>>> warmup = optim.linear_schedule(0, 1e-1, steps=10)
>>> linear = optim.linear_schedule(0, 1e-1, steps=10)
>>> cosine = optim.cosine_decay(1e-1, 200)
>>> lr_schedule = optim.join_schedules([warmup, cosine], [10])
>>> lr_schedule = optim.join_schedules([linear, cosine], [10])
>>> optimizer = optim.Adam(learning_rate=lr_schedule)
>>> optimizer.learning_rate
array(0.0, dtype=float32)
@@ -139,8 +139,8 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:
Example:
>>> warmup = optim.linear_schedule(0, 1e-1, 100)
>>> optimizer = optim.Adam(learning_rate=warmup)
>>> lr_schedule = optim.linear_schedule(0, 1e-1, 100)
>>> optimizer = optim.Adam(learning_rate=lr_schedule)
>>> optimizer.learning_rate
array(0.0, dtype=float32)
>>> for _ in range(101): optimizer.update({}, {})
@@ -151,8 +151,8 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:
if steps < 1:
raise ValueError(f"steps must be greater than 0, but got {steps}.")
def step_fn(step):
def schedule(step):
step = mx.minimum(step, steps)
return step * ((end - init) / steps) + init
return step_fn
return schedule