Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-28 20:41:15 +08:00)
add Muon optimizer
commit aef320bf84 (parent e8ff59451e)

@@ -865,19 +865,19 @@ def _zeropower_via_newtonschulz5(
     a, b, c = (3.4445, -4.7750, 2.0315)
     X = G.astype(mx.bfloat16)
     if G.shape[-2] > G.shape[-1]:
-        X = mx.transpose(X, (-2, -1))
+        X = X.T

     # Frobenius-norm normalisation (≈ spectral-norm ≤ 1)
     X = X / (mx.linalg.norm(X, ord="fro", axis=(-2, -1), keepdims=True) + eps)

     # Perform Newton-Schulz iteration
     for _ in range(steps):
-        A = X @ mx.transpose(X, (-2, -1))
+        A = X @ X.T
         B = b * A + c * A @ A
         X = a * X + B @ X

     if G.shape[-2] > G.shape[-1]:
-        X = mx.transpose(X, (-2, -1))
+        X = X.T

     return X

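
For readers skimming the diff: the loop above is the quintic Newton-Schulz iteration that Muon uses to turn the momentum-averaged gradient G into an approximately semi-orthogonal matrix (roughly the U V^T factor of its SVD) using only matrix multiplies, which is part of why the bfloat16 cast above is viable. Below is a minimal standalone sketch of the same idea in NumPy so it can be sanity-checked outside MLX; the function name, the float32 cast, and the final check are illustrative and not part of this commit.

import numpy as np

def newton_schulz_orthogonalize(G, steps=5, eps=1e-7):
    # Quintic Newton-Schulz iteration: pushes the singular values of X toward 1
    # while preserving the singular vectors, i.e. X -> U @ V.T for G = U S V.T.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.astype(np.float32)
    transposed = X.shape[-2] > X.shape[-1]
    if transposed:
        X = X.T
    # Frobenius-norm normalisation; the Frobenius norm bounds the spectral norm,
    # so afterwards all singular values are <= 1 and the iteration is stable.
    X = X / (np.linalg.norm(X) + eps)
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X

if __name__ == "__main__":
    G = np.random.randn(64, 128)
    O = newton_schulz_orthogonalize(G)
    # Rows should be roughly orthonormal; the iteration only approximates
    # orthogonality, so expect small-but-not-tiny off-diagonal entries.
    print(np.abs(O @ O.T - np.eye(64)).max())

With these coefficients a handful of iterations is enough to land the singular values near 1, which is why the ns_steps used by the optimizer can stay small.
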
@@ -918,6 +918,7 @@ class Muon(Optimizer):
         self.weight_decay = weight_decay
         self.nesterov = nesterov
         self.ns_steps = ns_steps
+        self.backup_optimizer = AdamW(learning_rate=1e-3, weight_decay=weight_decay)

     def _flatten_if_conv(self, x: mx.array) -> mx.array:
         if x.ndim == 4:  # [out, in, kH, kW] → [out, -1]
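
Two small pieces of plumbing appear here: the new backup_optimizer line gives the class an AdamW instance (with a fixed learning_rate of 1e-3, sharing the same weight_decay) for parameters the orthogonalized update does not apply to, and _flatten_if_conv collapses a 4D convolution kernel into a matrix with one row per output channel so the 2D machinery can operate on it. A shape-only illustration of that flattening, assuming mlx.core is importable; the variable names are mine, not the diff's:

import mlx.core as mx

w = mx.zeros((32, 16, 3, 3))            # [out, in, kH, kW]
w2d = w.reshape(w.shape[0], -1)         # -> 32 x 144, one row per output channel
print(w2d.shape)
print(mx.reshape(w2d, w.shape).shape)   # restored to [32, 16, 3, 3], mirroring
                                        # the reshape later in apply_single
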
@@ -925,11 +926,21 @@ class Muon(Optimizer):
         return x

     def init_single(self, parameter: mx.array, state: dict):
+        # If the parameter is not a 2D array, use the backup optimizer
+        if parameter.ndim != 2:
+            return self.backup_optimizer.init_single(parameter, state)
         state["B"] = mx.zeros_like(parameter)

     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+        # If the parameter is not a 2D array, use the backup optimizer
+        if "B" not in state:
+            return self.backup_optimizer.apply_single(
+                gradient=gradient, parameter=parameter, state=state
+            )
         # Record the original dtype and shape
         original_dtype = gradient.dtype
         original_shape = gradient.shape
         # Cast lr to the original dtype
         lr = self.learning_rate.astype(original_dtype)
         # Compute new buffer and store
         state["B"] = self.momentum * state["B"] + gradient
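
The two guards added above are worth spelling out: init_single allocates the momentum buffer state["B"] only for 2D parameters and delegates everything else to the AdamW backup, and apply_single later detects delegated parameters by the absence of that "B" key instead of re-checking ndim. A toy NumPy simulation of that dispatch plus the buffer update; these stand-ins are illustrative and are not the MLX classes:

import numpy as np

def init_single(parameter, state):
    # Only 2D parameters get a Muon momentum buffer (mirrors the rule above).
    if parameter.ndim != 2:
        state["fallback"] = True          # stand-in for backup_optimizer.init_single
        return
    state["B"] = np.zeros_like(parameter)

def apply_single(gradient, parameter, state, momentum=0.95):
    # Absence of "B" marks a parameter that was routed to the backup optimizer.
    if "B" not in state:
        return parameter                  # stand-in for backup_optimizer.apply_single
    state["B"] = momentum * state["B"] + gradient
    return parameter                      # the real update continues past this hunk

weight, bias = np.zeros((4, 3)), np.zeros(3)
w_state, b_state = {}, {}
init_single(weight, w_state)
init_single(bias, b_state)
print("B" in w_state, "B" in b_state)     # True False
apply_single(np.ones((4, 3)), weight, w_state)
print(w_state["B"][0, 0])                 # 1.0 after the first step (0.95 * 0 + 1)
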
@@ -947,10 +958,11 @@ class Muon(Optimizer):
         # Reshape back to original shape
         gradient = mx.reshape(gradient, original_shape)
         # scale-invariant step size
-        scale = mx.sqrt(
-            mx.maximum(1.0, parameter.shape[-2] / parameter.shape[-1])
-        ).astype(original_dtype)
-        return parameter * (1 - lr * self.weight_decay) - lr * scale * gradient
+        scale = mx.maximum(1.0, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
+        return (
+            parameter * (1 - lr * self.weight_decay)
+            - lr * scale.astype(original_dtype) * gradient
+        )


 def clip_grad_norm(grads, max_norm):
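
The rewritten tail is value-preserving as far as the arithmetic goes: max(1, rows/cols) ** 0.5 computes the same number as the old mx.sqrt form, the dtype cast simply moves onto scale, and parameter * (1 - lr * weight_decay) is the usual decoupled (AdamW-style) weight decay. The scale factor keeps the effective step size comparable for tall matrices, where rows exceed columns. For completeness, a hedged usage sketch, assuming the class is exposed as mlx.optimizers.Muon with a constructor that mirrors the attributes referenced in the diff (learning_rate, momentum, nesterov, weight_decay, ns_steps); treat the exact signature and defaults as assumptions and check the merged API:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(128, 64)
# learning_rate/weight_decay values here are placeholders, not recommendations
opt = optim.Muon(learning_rate=0.02, weight_decay=0.01)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x = mx.random.normal((8, 128))
y = mx.random.normal((8, 64))
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
opt.update(model, grads)                  # the 2D weight takes the Muon path,
mx.eval(model.parameters(), opt.state)    # the 1D bias falls back to AdamW
print(loss.item())
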