Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Compare commits: 784e0716fe ... 7f39e9c299 (3 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 7f39e9c299 |  |
|  | baad6e392b |  |
|  | b2273733ea |  |
@@ -201,7 +201,7 @@ jobs:
   cuda_build_and_test:
     machine:
-      image: linux-cuda-12:default
+      image: linux-cuda-12:2023.11.1
       resource_class: gpu.nvidia.small.gen2
     steps:
       - checkout
@@ -210,7 +210,7 @@ jobs:
          command: |
            sudo apt-get update
            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
-           python -m venv env
+           python3 -m venv env
            source env/bin/activate
            CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
            pip install -e ".[dev]"
@@ -19,3 +19,4 @@ Common Optimizers
    Adamax
    Lion
    MultiOptimizer
+   Muon
@@ -334,7 +334,9 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];

   // Copy input to output
-  CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
+  CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
+      ? CopyType::Vector
+      : CopyType::General;
   copy_cpu(in, out, ctype, stream());

   auto& encoder = cpu::get_command_encoder(stream());
@@ -426,7 +428,9 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];

   // Copy input to output
-  CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
+  CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
+      ? CopyType::Vector
+      : CopyType::General;
   copy_cpu(in, out, ctype, stream());

   auto& encoder = cpu::get_command_encoder(stream());
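The guard added above matters for inputs that advertise contiguous flags but have a zero stride along the sort/partition axis (a broadcast view is the typical case): a flat Vector copy would re-walk the single stored row, so the copy must fall back to General. A minimal Python sketch of the triggering input, assuming the usual MLX API (`mx.broadcast_to`, `mx.sort`); it illustrates the shape/stride situation, not the internal copy path:

import mlx.core as mx

# Broadcast one row to (4, 3): the result has stride 0 along axis 0,
# so all four "rows" alias the same three stored values.
row = mx.array([3.0, 1.0, 2.0])
x = mx.broadcast_to(row, (4, 3))

# Sorting first copies the input into the output buffer. With a zero
# stride on the sort axis, that copy cannot be a flat vector copy.
print(mx.sort(x, axis=0))  # every column is constant
print(mx.sort(x, axis=1))  # [1, 2, 3] repeated in each row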
@@ -36,7 +36,8 @@ affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
   auto tidx = block_idx.x * block_size.x + idx_in_block.x;
   auto tidy = block_idx.y * block_size.y + idx_in_block.y;

-  auto grid_dim = cg::this_grid().dim_threads();
+  auto grid_dim_x =
+      cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
   constexpr float eps = 1e-7;
   constexpr int simd_size = WARP_SIZE;
   constexpr float n_bins = (1 << bits) - 1;
@@ -48,7 +49,7 @@ affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
       writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;
   constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;

-  size_t offset = tidx + grid_dim.x * size_t(tidy);
+  size_t offset = tidx + grid_dim_x * size_t(tidy);
   size_t in_index = offset * values_per_reduce;
   if (in_index >= size) {
     return;
@@ -153,12 +154,13 @@ __global__ void affine_dequantize(
   auto tidx = block_idx.x * block_size.x + idx_in_block.x;
   auto tidy = block_idx.y * block_size.y + idx_in_block.y;

-  auto grid_dim = cg::this_grid().dim_threads();
+  auto grid_dim_x =
+      cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;

   constexpr int pack_factor = get_pack_factor<bits, 8>();
   constexpr int bytes_per_pack = get_bytes_per_pack<bits>();

-  size_t offset = tidx + grid_dim.x * size_t(tidy);
+  size_t offset = tidx + grid_dim_x * size_t(tidy);
   size_t oindex = offset * pack_factor;

   if (oindex >= size) {
@@ -349,7 +351,8 @@ void fast::AffineQuantize::eval_gpu(
   dispatch_bits(bits_, [&](auto bits) {
     using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
     if (dequantize_) {
-      auto kernel = cu::affine_dequantize<DataType, group_size(), bits()>;
+      auto kernel =
+          cu::affine_dequantize<DataType, group_size.value, bits.value>;
       auto [num_blocks, block_dims] =
           get_launch_args(kernel, size, grid_shape, w.strides(), large);
       enc.add_kernel_node(
@@ -362,7 +365,8 @@ void fast::AffineQuantize::eval_gpu(
           out.data<DataType>(),
           out.size());
     } else {
-      auto kernel = cu::affine_quantize<DataType, group_size(), bits()>;
+      auto kernel =
+          cu::affine_quantize<DataType, group_size.value, bits.value>;
       auto [num_blocks, block_dims] =
           get_launch_args(kernel, size, grid_shape, w.strides(), large);
       enc.add_kernel_node(
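For context, these kernels implement group-wise affine quantization: each group of `group_size` values is encoded into `bits`-bit integers against `n_bins = (1 << bits) - 1` levels with a per-group scale and bias, and the offset computed from `tidx`/`tidy` selects which group a thread reduces. The `group_size.value`/`bits.value` change suggests the dispatch tags are integral-constant-like objects whose compile-time value is read from the `value` member rather than by invoking them. Below is a NumPy sketch of the same per-group arithmetic; the simple min/max choice of scale and bias is an assumption for illustration, not the kernel's exact selection:

import numpy as np

def affine_quantize_ref(w, group_size=64, bits=4):
    # Per-group affine quantization: q = round((w - bias) / scale),
    # with n_bins = (1 << bits) - 1 representable levels per group.
    n_bins = (1 << bits) - 1
    eps = 1e-7  # same epsilon as the kernel above
    groups = w.reshape(-1, group_size)
    w_min = groups.min(axis=1, keepdims=True)
    w_max = groups.max(axis=1, keepdims=True)
    scale = np.maximum((w_max - w_min) / n_bins, eps)
    bias = w_min
    q = np.clip(np.round((groups - bias) / scale), 0, n_bins)
    return q.astype(np.uint8), scale, bias

def affine_dequantize_ref(q, scale, bias):
    # Inverse map back to approximate floats.
    return q * scale + bias

w = np.random.randn(4, 128).astype(np.float32)
q, scale, bias = affine_quantize_ref(w, group_size=64, bits=4)
w_hat = affine_dequantize_ref(q, scale, bias).reshape(w.shape)
print(np.abs(w - w_hat).max() <= scale.max() / 2 + 1e-6)  # True

The round-trip error is bounded by half a quantization step per group, which is what makes the per-group scale/bias worth the extra storage.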
@@ -849,28 +849,28 @@ class Adafactor(Optimizer):


 class Muon(Optimizer):
-    r"""The Muon optimizer - MomentUm Orthogonalized by Newton-schulz.
+    r"""The Muon optimizer.

-    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
-    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
-    matrix. To efficiently orthogonalize each update, a Newton-Schulz iteration is used, which has
-    the advantage that it can be stably run in bfloat16 on the GPU.
-
-    For more details, see: https://kellerjordan.github.io/posts/muon/
+    Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
+    original implementation: `Muon: An optimizer for hidden layers in neural
+    networks <https://kellerjordan.github.io/posts/muon/>`_

     Note:
-        - This optimizer may not be optimal for the embedding layer, the final fully connected layer,
-          or any 0D/1D parameters; those should be optimized by a standard method (e.g., AdamW).
-        - For 4D convolutional filters, it works by flattening their last dimensions.
+        - Muon may be sub-optimal for the embedding layer, the final fully
+          connected layer, or any 0D/1D parameters. Those should be optimized
+          by a different method (e.g., :class:`AdamW`).
+        - For 4D convolutional filters, it works by flattening their last
+          dimensions.

     Args:
-        learning_rate (float or callable): The learning rate used by the internal SGD.
+        learning_rate (float or callable): The learning rate.
         momentum (float, optional): The momentum strength. Default: ``0.95``
-        weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0.01``
-        nesterov (bool, optional): Enables Nesterov momentum. Recommended for better performance.
-            Default: ``True``
-        ns_steps (int, optional): Number of Newton-Schulz iteration steps for orthogonalization.
-            Default: ``5``
+        weight_decay (float, optional): The weight decay (L2 penalty).
+            Default: ``0.01``
+        nesterov (bool, optional): Enables Nesterov momentum. Recommended for
+            better performance. Default: ``True``
+        ns_steps (int, optional): Number of Newton-Schulz iteration steps for
+            orthogonalization. Default: ``5``
     """

     def __init__(
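Per the revised docstring, Muon should only see parameters with two or more dimensions, with 0D/1D parameters routed to something like AdamW. A hedged usage sketch pairing the two through the `MultiOptimizer` added to the docs above; the filter signature (a callable on the parameter path and array, with the last optimizer acting as the fallback) is an assumption about the MultiOptimizer API:

import mlx.optimizers as opt

# Muon for 2D+ weight matrices, AdamW for biases and other 0D/1D
# parameters, per the Note in the docstring above.
muon = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
adamw = opt.AdamW(learning_rate=3e-4)

# Assumption: MultiOptimizer routes each parameter to the first optimizer
# whose filter returns True; the last optimizer catches everything else.
optimizer = opt.MultiOptimizer(
    [muon, adamw],
    filters=[lambda path, weight: weight.ndim >= 2],
)

A subsequent `optimizer.update(model, grads)` would then apply Muon or AdamW per parameter.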
@@ -882,7 +882,7 @@ class Muon(Optimizer):
         ns_steps: int = 5,
     ):
         super().__init__()

         self._maybe_schedule("learning_rate", learning_rate)
         self.momentum = momentum
         self.weight_decay = weight_decay
@@ -894,55 +894,46 @@ class Muon(Optimizer):
             state["v"] = mx.zeros_like(parameter)

     def _zeropower_via_newtonschulz5(self, G, steps: int):
-        """
-        Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
-        quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
-        of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
-        zero even beyond the point where the iteration no longer converges all the way to one everywhere
-        on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-        where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
-        performance at all relative to UV^T, where USV^T = G is the SVD.
-        """
         assert G.ndim >= 2
         a, b, c = (3.4445, -4.7750, 2.0315)
         X = G.astype(mx.bfloat16)
         transpose_needed = G.shape[-2] > G.shape[-1]

         if transpose_needed:
             X = X.T

         # Ensure spectral norm is at most 1
         norm = mx.sqrt(mx.sum(X * X, axis=(-2, -1), keepdims=True) + 1e-7)
         X = X / norm

         # Perform the NS iterations
         for _ in range(steps):
             A = X @ X.T
             B = b * A + c * (A @ A)
             X = a * X + B @ X

         if transpose_needed:
             X = X.T
         return X

     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         """Performs the Muon parameter update"""

         # Apply weight decay
         if self.weight_decay != 0:
             gradient = gradient + self.weight_decay * parameter

         # Update momentum buffer
         v = self.momentum * state["v"]
         v = v + (1 - self.momentum) * gradient
         state["v"] = v

         # Get effective gradient
         if self.nesterov:
             effective_grad = gradient * (1 - self.momentum) + v * self.momentum
         else:
             effective_grad = v

         # For tensors with fewer than 2 dimensions, skip Newton-Schulz
         if effective_grad.ndim < 2:
             orthogonalized_grad = effective_grad
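The removed docstring still describes what the loop does: a quintic Newton-Schulz iteration whose coefficients maximize the slope at zero, producing roughly US'V^T with the diagonal of S' in approximately (0.5, 1.5) rather than an exact UV^T. A standalone sanity check of that behavior, using the same coefficients and Frobenius normalization as the method above:

import mlx.core as mx

a, b, c = (3.4445, -4.7750, 2.0315)  # same quintic coefficients
G = mx.random.normal((64, 256))      # rows <= cols, so no transpose needed
X = G.astype(mx.bfloat16)

# Frobenius normalization bounds the spectral norm by 1, as in the method.
X = X / mx.sqrt(mx.sum(X * X) + 1e-7)
for _ in range(5):
    A = X @ X.T
    X = a * X + (b * A + c * (A @ A)) @ X

# X @ X.T should now be close to the identity; per the removed docstring
# the singular values land roughly in (0.5, 1.5), not exactly at 1.
print(mx.abs(X @ X.T - mx.eye(64)).max().item())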
@@ -951,22 +942,33 @@ class Muon(Optimizer):
             # Save original shape for 4D conv filters
             original_shape = effective_grad.shape
             reshape_needed = effective_grad.ndim > 2

             if reshape_needed:
-                effective_grad = mx.reshape(effective_grad, (effective_grad.shape[0], -1))
+                effective_grad = mx.reshape(
+                    effective_grad, (effective_grad.shape[0], -1)
+                )

             # Apply Newton-Schulz orthogonalization
-            orthogonalized_grad = self._zeropower_via_newtonschulz5(effective_grad, steps=self.ns_steps)
+            orthogonalized_grad = self._zeropower_via_newtonschulz5(
+                effective_grad, steps=self.ns_steps
+            )

             # Reshape back if needed
             if reshape_needed:
                 orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)

             # Calculate scaling factor
             # scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
-            scale_factor = max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
-
-            return parameter - self.learning_rate.astype(gradient.dtype) * orthogonalized_grad * scale_factor
+            scale_factor = (
+                max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
+            )
+
+            return (
+                parameter
+                - self.learning_rate.astype(gradient.dtype)
+                * orthogonalized_grad
+                * scale_factor
+            )


 def clip_grad_norm(grads, max_norm):
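Written out, the update that the reformatted return statement computes is (a restatement of the code above, with \(\lambda\) = weight_decay, \(\mu\) = momentum, \(\eta\) = learning_rate, and \((m, n)\) the last two dimensions of the flattened gradient):

g_t = \nabla_p \mathcal{L}(p_t) + \lambda p_t
v_t = \mu v_{t-1} + (1 - \mu) g_t
\tilde{g}_t = (1 - \mu) g_t + \mu v_t \quad (\text{Nesterov}), \qquad \tilde{g}_t = v_t \quad (\text{otherwise})
p_{t+1} = p_t - \eta \, \sqrt{\max(1, m/n)} \, \operatorname{NS}_5(\tilde{g}_t)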
@@ -307,7 +307,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):

        # Test update
        updated_params = optim.apply_gradients(grads, params)

        # Check that shapes are preserved
        self.assertTrue(
            tree_equal(
@@ -316,7 +316,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
                updated_params,
            )
        )

        # Check that parameters actually changed
        self.assertFalse(
            tree_equal(
@@ -325,11 +325,11 @@ class TestOptimizers(mlx_tests.MLXTestCase):
                updated_params,
            )
        )

        # Test with different configurations
        optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
        optim_no_nesterov.apply_gradients(grads, params)

        optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
        optim_no_momentum.apply_gradients(grads, params)