mlx.optimizers.Muon

class Muon(learning_rate: float | Callable[[array], array], momentum: float = 0.95, weight_decay: float = 0.01, nesterov: bool = True, ns_steps: int = 5)

The Muon optimizer.

Our Muon (MomentUm Orthogonalized by Newton-Schulz) optimizer follows the original implementation described in "Muon: An optimizer for hidden layers in neural networks".

Note

  • Muon may be sub-optimal for the embedding layer, the final fully connected layer, or any 0D/1D parameters. Those should be optimized by a different method (e.g., AdamW); see the sketch after this note.

  • For 4D convolutional filters, Muon works by flattening their last dimensions into a 2D matrix.
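
In practice this means sending only the hidden-layer weight matrices through Muon and handing everything else (embeddings, the output head, biases, and other 0D/1D parameters) to an optimizer such as AdamW. Below is a minimal sketch of one way to set this up with mlx.optimizers.MultiOptimizer; the predicate signature (weight path and weight) and the hyperparameters are assumptions to check against the current MLX documentation.

```python
import mlx.optimizers as optim

# Assumption: MultiOptimizer routes each parameter to the first optimizer whose
# predicate returns True; the last optimizer in the list acts as the fallback.
muon = optim.Muon(learning_rate=0.02, momentum=0.95, nesterov=True)
adamw = optim.AdamW(learning_rate=3e-4, weight_decay=0.01)

optimizer = optim.MultiOptimizer(
    [muon, adamw],
    # Route 2D (and higher) weights to Muon; 0D/1D parameters fall back to AdamW.
    # In a real model you would also exclude embedding and output-head weights,
    # e.g. by matching on their path names.
    [lambda path, weight: weight.ndim >= 2],
)
```

The wrapped optimizer is then used like any other MLX optimizer, e.g. optimizer.update(model, grads).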

Parameters:
  • learning_rate (float or callable) – The learning rate.

  • momentum (float, optional) – The momentum strength. Default: 0.95

  • weight_decay (float, optional) – The weight decay (L2 penalty). Default: 0.01

  • nesterov (bool, optional) – Enables Nesterov momentum. Recommended for better performance. Default: True

  • ns_steps (int, optional) – Number of Newton-Schulz iteration steps for orthogonalization. Default: 5
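
The ns_steps parameter controls how many Newton-Schulz iterations are run to approximately orthogonalize each momentum-smoothed gradient matrix before the update is applied. The sketch below illustrates the kind of iteration involved, using the quintic coefficients from the original Muon reference implementation; it is illustrative only and not MLX's internal code.

```python
import mlx.core as mx

def orthogonalize(G: mx.array, steps: int = 5) -> mx.array:
    """Approximately map G to the nearest semi-orthogonal matrix (illustrative sketch)."""
    # Quintic iteration coefficients from the original Muon reference implementation.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (mx.linalg.norm(G) + 1e-7)  # scale so the spectral norm is at most ~1
    transpose = X.shape[0] > X.shape[1]
    if transpose:  # iterate on the wide orientation for efficiency
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if transpose else X
```

More steps give a closer approximation to an exact orthogonalization at proportionally higher cost; the default of 5 follows the original implementation.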

Methods

__init__(learning_rate[, momentum, ...])

apply_single(gradient, parameter, state)

Performs the Muon parameter update.

init_single(parameter, state)

Initializes the optimizer state.
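
A minimal end-to-end usage sketch follows. The model and hyperparameters are illustrative; biases are disabled so that every parameter handled by Muon is a 2D weight matrix, in line with the note above.

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Toy two-layer MLP; bias=False keeps all parameters 2D so Muon suits them all.
model = nn.Sequential(
    nn.Linear(32, 64, bias=False),
    nn.ReLU(),
    nn.Linear(64, 10, bias=False),
)
optimizer = optim.Muon(learning_rate=0.02, momentum=0.95, nesterov=True)

def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y, reduction="mean")

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

x = mx.random.normal((16, 32))        # dummy batch
y = mx.random.randint(0, 10, (16,))   # dummy labels

for step in range(10):
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)    # apply the Muon update to the model in place
    mx.eval(model.parameters(), optimizer.state)
```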