Compare commits

...

3 Commits

Author            SHA1         Message                                                      Date
Awni Hannun       7f39e9c299   nits                                                         2025-07-17 06:26:43 -07:00
Gökdeniz Gülmez   baad6e392b   Merge branch 'ml-explore:main' into adding-Muon-optimizer   2025-07-17 13:07:54 +02:00
Awni Hannun       b2273733ea   Test with CUDA 12.2 (#2375)                                  2025-07-16 13:00:37 -07:00
                               * Test with CUDA 12.0
                               * try older image
                               * fix cpu sort
6 changed files with 68 additions and 57 deletions

View File

@@ -201,7 +201,7 @@ jobs:
cuda_build_and_test:
machine:
- image: linux-cuda-12:default
+ image: linux-cuda-12:2023.11.1
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
@@ -210,7 +210,7 @@ jobs:
command: |
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
- python -m venv env
+ python3 -m venv env
source env/bin/activate
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install -e ".[dev]"
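The step above creates a Python 3 virtual environment and builds MLX from source with the CUDA backend enabled. As an illustration only (not part of this change), a minimal post-build smoke test might run a small op and force evaluation:

import mlx.core as mx

# Hypothetical smoke test after building with MLX_BUILD_CUDA=ON: run a small
# matmul and force lazy evaluation; on a successful CUDA build this executes
# on the default GPU device.
a = mx.random.normal((256, 256))
b = a @ a.T
mx.eval(b)
print(mx.default_device(), b.shape)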

View File

@@ -19,3 +19,4 @@ Common Optimizers
Adamax
Lion
MultiOptimizer
+ Muon

View File

@@ -334,7 +334,9 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Copy input to output
- CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
+ CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
+     ? CopyType::Vector
+     : CopyType::General;
copy_cpu(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
@@ -426,7 +428,9 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Copy input to output
- CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
+ CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
+     ? CopyType::Vector
+     : CopyType::General;
copy_cpu(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
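The added check targets inputs that are flagged contiguous but have a zero stride along the sort or partition axis, as happens with broadcast arrays; a plain vector copy would otherwise lay the data out incorrectly. A small illustration from the Python side (assuming the usual mx.broadcast_to, mx.sort, and mx.partition APIs; not part of this diff):

import mlx.core as mx

# A broadcast input has stride 0 along axis 0, which is the kind of array the
# (in.strides()[axis_] != 0) guard above routes to the general copy path.
x = mx.broadcast_to(mx.array([3.0, 1.0, 2.0]), (4, 3))
print(mx.sort(x, axis=0))        # each column repeats a single value four times
print(mx.partition(x, kth=1, axis=0))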

View File

@@ -36,7 +36,8 @@ affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
- auto grid_dim = cg::this_grid().dim_threads();
+ auto grid_dim_x =
+     cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
constexpr float eps = 1e-7;
constexpr int simd_size = WARP_SIZE;
constexpr float n_bins = (1 << bits) - 1;
@@ -48,7 +49,7 @@ affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
- size_t offset = tidx + grid_dim.x * size_t(tidy);
+ size_t offset = tidx + grid_dim_x * size_t(tidy);
size_t in_index = offset * values_per_reduce;
if (in_index >= size) {
return;
@@ -153,12 +154,13 @@ __global__ void affine_dequantize(
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
- auto grid_dim = cg::this_grid().dim_threads();
+ auto grid_dim_x =
+     cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
- size_t offset = tidx + grid_dim.x * size_t(tidy);
+ size_t offset = tidx + grid_dim_x * size_t(tidy);
size_t oindex = offset * pack_factor;
if (oindex >= size) {
@@ -349,7 +351,8 @@ void fast::AffineQuantize::eval_gpu(
dispatch_bits(bits_, [&](auto bits) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if (dequantize_) {
- auto kernel = cu::affine_dequantize<DataType, group_size(), bits()>;
+ auto kernel =
+     cu::affine_dequantize<DataType, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
@@ -362,7 +365,8 @@ void fast::AffineQuantize::eval_gpu(
out.data<DataType>(),
out.size());
} else {
- auto kernel = cu::affine_quantize<DataType, group_size(), bits()>;
+ auto kernel =
+     cu::affine_quantize<DataType, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(

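For reference, these kernels back MLX's affine quantization ops. A short round-trip sketch of the Python-level API they implement (standard mx.quantize / mx.dequantize usage, shown only for context):

import mlx.core as mx

# Quantize packs several low-bit values per output element together with
# per-group scales and biases; dequantize reconstructs an approximation of the
# original weights, which is what affine_quantize / affine_dequantize compute.
w = mx.random.normal((128, 256))
wq, scales, biases = mx.quantize(w, group_size=64, bits=4)
w_hat = mx.dequantize(wq, scales, biases, group_size=64, bits=4)
print(mx.abs(w - w_hat).max())   # small reconstruction error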
View File

@@ -849,28 +849,28 @@ class Adafactor(Optimizer):
class Muon(Optimizer):
r"""The Muon optimizer - MomentUm Orthogonalized by Newton-schulz.
r"""The Muon optimizer.
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. To efficiently orthogonalize each update, a Newton-Schulz iteration is used, which has
the advantage that it can be stably run in bfloat16 on the GPU.
For more details, see: https://kellerjordan.github.io/posts/muon/
Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
original implementation: `Muon: An optimizer for hidden layers in neural
networks <https://kellerjordan.github.io/posts/muon/>`_
Note:
- This optimizer may not be optimal for the embedding layer, the final fully connected layer,
or any 0D/1D parameters; those should be optimized by a standard method (e.g., AdamW).
- For 4D convolutional filters, it works by flattening their last dimensions.
- Muon may be sub-optimal for the embedding layer, the final fully
connected layer, or any 0D/1D parameters. Those should be optimized
by a different method (e.g., :class:`AdamW`).
- For 4D convolutional filters, it works by flattening their last
dimensions.
Args:
learning_rate (float or callable): The learning rate used by the internal SGD.
learning_rate (float or callable): The learning rate.
momentum (float, optional): The momentum strength. Default: ``0.95``
weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0.01``
nesterov (bool, optional): Enables Nesterov momentum. Recommended for better performance.
Default: ``True``
ns_steps (int, optional): Number of Newton-Schulz iteration steps for orthogonalization.
Default: ``5``
weight_decay (float, optional): The weight decay (L2 penalty).
Default: ``0.01``
nesterov (bool, optional): Enables Nesterov momentum. Recommended for
better performance. Default: ``True``
ns_steps (int, optional): Number of Newton-Schulz iteration steps for
orthogonalization. Default: ``5``
"""
def __init__(
@@ -882,7 +882,7 @@ class Muon(Optimizer):
ns_steps: int = 5,
):
super().__init__()
self._maybe_schedule("learning_rate", learning_rate)
self.momentum = momentum
self.weight_decay = weight_decay
@@ -894,55 +894,46 @@ class Muon(Optimizer):
state["v"] = mx.zeros_like(parameter)
def _zeropower_via_newtonschulz5(self, G, steps: int):
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert G.ndim >= 2
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.astype(mx.bfloat16)
transpose_needed = G.shape[-2] > G.shape[-1]
if transpose_needed:
X = X.T
# Ensure spectral norm is at most 1
norm = mx.sqrt(mx.sum(X * X, axis=(-2, -1), keepdims=True) + 1e-7)
X = X / norm
# Perform the NS iterations
for _ in range(steps):
A = X @ X.T
B = b * A + c * (A @ A)
X = a * X + B @ X
if transpose_needed:
X = X.T
return X
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Muon parameter update"""
# Apply weight decay
if self.weight_decay != 0:
gradient = gradient + self.weight_decay * parameter
# Update momentum buffer
v = self.momentum * state["v"]
v = v + (1 - self.momentum) * gradient
state["v"] = v
# Get effective gradient
if self.nesterov:
effective_grad = gradient * (1 - self.momentum) + v * self.momentum
else:
effective_grad = v
# For tensors with fewer than 2 dimensions, skip Newton-Schulz
if effective_grad.ndim < 2:
orthogonalized_grad = effective_grad
@@ -951,22 +942,33 @@ class Muon(Optimizer):
# Save original shape for 4D conv filters
original_shape = effective_grad.shape
reshape_needed = effective_grad.ndim > 2
if reshape_needed:
- effective_grad = mx.reshape(effective_grad, (effective_grad.shape[0], -1))
+ effective_grad = mx.reshape(
+     effective_grad, (effective_grad.shape[0], -1)
+ )
# Apply Newton-Schulz orthogonalization
- orthogonalized_grad = self._zeropower_via_newtonschulz5(effective_grad, steps=self.ns_steps)
+ orthogonalized_grad = self._zeropower_via_newtonschulz5(
+     effective_grad, steps=self.ns_steps
+ )
# Reshape back if needed
if reshape_needed:
orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
# Calculate scaling factor
# scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
- scale_factor = max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
- return parameter - self.learning_rate.astype(gradient.dtype) * orthogonalized_grad * scale_factor
+ scale_factor = (
+     max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
+ )
+ return (
+     parameter
+     - self.learning_rate.astype(gradient.dtype)
+     * orthogonalized_grad
+     * scale_factor
+ )
def clip_grad_norm(grads, max_norm):

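The new class plugs into the standard mlx.optimizers interface. A minimal usage sketch mirroring the test below (parameter names and shapes are made up for illustration):

import mlx.core as mx
import mlx.optimizers as opt

# Muon orthogonalizes 2D updates via Newton-Schulz, so it is aimed at weight
# matrices; per the docstring, 0D/1D parameters are better served by e.g. AdamW.
params = {"w1": mx.random.normal((32, 16)), "w2": mx.random.normal((16, 8))}
grads = {k: mx.random.normal(v.shape) for k, v in params.items()}

optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True, ns_steps=5)
new_params = optim.apply_gradients(grads, params)
print(new_params["w1"].shape, new_params["w2"].shape)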
View File

@@ -307,7 +307,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
# Test update
updated_params = optim.apply_gradients(grads, params)
# Check that shapes are preserved
self.assertTrue(
tree_equal(
@@ -316,7 +316,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
updated_params,
)
)
# Check that parameters actually changed
self.assertFalse(
tree_equal(
@@ -325,11 +325,11 @@ class TestOptimizers(mlx_tests.MLXTestCase):
updated_params,
)
)
# Test with different configurations
optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
optim_no_nesterov.apply_gradients(grads, params)
optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
optim_no_momentum.apply_gradients(grads, params)