Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-28 12:13:21 +08:00
add Muon optimizer
This commit is contained in:
parent e8ff59451e
commit aef320bf84
@@ -865,19 +865,19 @@ def _zeropower_via_newtonschulz5(
     a, b, c = (3.4445, -4.7750, 2.0315)
     X = G.astype(mx.bfloat16)
     if G.shape[-2] > G.shape[-1]:
-        X = mx.transpose(X, (-2, -1))
+        X = X.T

     # Frobenius-norm normalisation (≈ spectral-norm ≤ 1)
     X = X / (mx.linalg.norm(X, ord="fro", axis=(-2, -1), keepdims=True) + eps)

     # Perform Newton-Schulz iteration
     for _ in range(steps):
-        A = X @ mx.transpose(X, (-2, -1))
+        A = X @ X.T
         B = b * A + c * A @ A
         X = a * X + B @ X

     if G.shape[-2] > G.shape[-1]:
-        X = mx.transpose(X, (-2, -1))
+        X = X.T

     return X

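For context, here is a minimal standalone sketch of the quintic Newton-Schulz orthogonalization this hunk touches. The name orthogonalize and the test shapes are illustrative only; the coefficients and the iteration body mirror the lines above.

import mlx.core as mx

def orthogonalize(G, steps=5, eps=1e-7):
    # Quintic Newton-Schulz coefficients, as in the diff above
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.astype(mx.bfloat16)
    # Work in the "wide" orientation so X @ X.T is the smaller Gram matrix
    if G.shape[-2] > G.shape[-1]:
        X = X.T
    # Frobenius norm >= spectral norm, so this bounds the spectral norm by ~1
    X = X / (mx.linalg.norm(X, ord="fro", axis=(-2, -1), keepdims=True) + eps)
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.shape[-2] > G.shape[-1]:
        X = X.T
    return X

G = mx.random.normal((128, 256))
X = orthogonalize(G).astype(mx.float32)
# The iteration pushes singular values into a band around 1 rather than exactly 1,
# so X @ X.T is only roughly the identity (plus bfloat16 rounding).
print(mx.max(mx.abs(X @ X.T - mx.eye(128))))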
@@ -918,6 +918,7 @@ class Muon(Optimizer):
         self.weight_decay = weight_decay
         self.nesterov = nesterov
         self.ns_steps = ns_steps
+        self.backup_optimizer = AdamW(learning_rate=1e-3, weight_decay=weight_decay)

     def _flatten_if_conv(self, x: mx.array) -> mx.array:
         if x.ndim == 4:  # [out, in, kH, kW] → [out, -1]
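The new backup_optimizer reflects a design choice: Newton-Schulz orthogonalization only makes sense for matrices, so parameters that are not (and cannot be flattened to) 2D weights are handed to a plain AdamW instance with a fixed 1e-3 learning rate. Conv kernels are viewed as matrices by _flatten_if_conv, as the comment above hints. A tiny illustration of that flattening (shapes are made up, and the reshape shown is an assumption about the helper, not a quote from it):

import mlx.core as mx

w = mx.random.normal((32, 16, 3, 3))     # conv kernel: [out, in, kH, kW]
w2d = mx.reshape(w, (w.shape[0], -1))    # -> (32, 144): treated as a matrix
# ...the orthogonalized update is computed on this 2D view...
update = mx.reshape(w2d, w.shape)        # restored to the kernel shape afterwards
print(w2d.shape, update.shape)           # (32, 144) (32, 16, 3, 3)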
@@ -925,11 +926,21 @@ class Muon(Optimizer):
         return x

     def init_single(self, parameter: mx.array, state: dict):
+        # If the parameter is not a 2D array, use the backup optimizer
+        if parameter.ndim != 2:
+            return self.backup_optimizer.init_single(parameter, state)
         state["B"] = mx.zeros_like(parameter)

     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+        # If the parameter is not a 2D array, use the backup optimizer
+        if "B" not in state:
+            return self.backup_optimizer.apply_single(
+                gradient=gradient, parameter=parameter, state=state
+            )
+        # Record the original dtype and shape
         original_dtype = gradient.dtype
         original_shape = gradient.shape
+        # Cast lr to the original dtype
         lr = self.learning_rate.astype(original_dtype)
         # Compute new buffer and store
         state["B"] = self.momentum * state["B"] + gradient
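Note how the fallback is detected in apply_single: parameters skipped by init_single never receive a "B" momentum buffer, so the absence of that key routes the step to the backup AdamW. For the 2D path, the sketch below condenses one parameter update under the assumption that the lines elided between this hunk and the next follow the usual Muon recipe (Nesterov blend, then orthogonalization); it reuses the orthogonalize sketch from earlier and is not a literal copy of the commit's code.

def muon_step(parameter, B, gradient, lr,
              momentum=0.95, weight_decay=0.01, nesterov=True, ns_steps=5):
    # Heavy-ball momentum buffer, as in the hunk above
    B = momentum * B + gradient
    # Standard Muon choice (assumed, not shown in this diff): Nesterov-style blend
    g = gradient + momentum * B if nesterov else B
    g = orthogonalize(g, steps=ns_steps)   # Newton-Schulz sketch from earlier
    g = g.astype(gradient.dtype)           # back to the original dtype
    # Scale-invariant step plus decoupled weight decay (see the next hunk)
    scale = max(1.0, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
    new_parameter = parameter * (1 - lr * weight_decay) - lr * scale * g
    return new_parameter, B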
@@ -947,10 +958,11 @@ class Muon(Optimizer):
         # Reshape back to original shape
         gradient = mx.reshape(gradient, original_shape)
         # scale-invariant step size
-        scale = mx.sqrt(
-            mx.maximum(1.0, parameter.shape[-2] / parameter.shape[-1])
-        ).astype(original_dtype)
-        return parameter * (1 - lr * self.weight_decay) - lr * scale * gradient
+        scale = mx.maximum(1.0, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
+        return (
+            parameter * (1 - lr * self.weight_decay)
+            - lr * scale.astype(original_dtype) * gradient
+        )


 def clip_grad_norm(grads, max_norm):
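The rewritten tail computes the same quantity as before; the change is stylistic (a single expression for scale, the dtype cast moved next to the multiply, a multi-line return). The rule itself is scale = sqrt(max(1, rows/cols)), so weights with more rows than columns take a proportionally larger step, and decoupled weight decay shrinks the parameter before the step is applied. A small worked example with made-up shapes and hyperparameters:

lr, weight_decay = 0.02, 0.01   # made-up hyperparameters

for rows, cols in [(1024, 64), (64, 1024), (512, 512)]:
    scale = max(1.0, rows / cols) ** 0.5
    print((rows, cols), scale)   # 4.0 for the tall matrix, 1.0 for the others

# The step itself: decoupled weight decay, then a scaled move along the
# orthogonalized gradient:
#   parameter <- parameter * (1 - lr * weight_decay) - lr * scale * gradient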