diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp
index 1dfae8524..2a612a2d9 100644
--- a/mlx/backend/cpu/primitives.cpp
+++ b/mlx/backend/cpu/primitives.cpp
@@ -205,8 +205,10 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
 void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  if (in.flags().row_contiguous ||
-      (allow_col_major_ && in.flags().col_contiguous)) {
+  constexpr size_t extra_bytes = 16384;
+  if (in.buffer_size() <= out.nbytes() + extra_bytes &&
+      (in.flags().row_contiguous ||
+       (allow_col_major_ && in.flags().col_contiguous))) {
     out.copy_shared_buffer(in);
   } else {
     copy(in, out, CopyType::General, stream());
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 67576f03f..6946ffb9e 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -251,8 +251,10 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
 void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  if (in.flags().row_contiguous ||
-      (allow_col_major_ && in.flags().col_contiguous)) {
+  constexpr size_t extra_bytes = 16384;
+  if (in.buffer_size() <= out.nbytes() + extra_bytes &&
+      (in.flags().row_contiguous ||
+       (allow_col_major_ && in.flags().col_contiguous))) {
     out.copy_shared_buffer(in);
   } else {
     copy_gpu(in, out, CopyType::General);
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 4e147487d..5a64a7852 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -993,6 +993,9 @@ array concatenate(
     throw std::invalid_argument(
         "[concatenate] No arrays provided for concatenation");
   }
+  if (arrays.size() == 1) {
+    return arrays[0];
+  }
 
   auto ax = normalize_axis_index(axis, arrays[0].ndim(), "[concatenate] ");
diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py
index 5d6bc4383..3391ba620 100644
--- a/python/mlx/distributed_run.py
+++ b/python/mlx/distributed_run.py
@@ -761,6 +761,8 @@ def main():
         "--cwd", help="Set the working directory on each node to the provided one"
     )
     args, rest = parser.parse_known_args()
+    if rest and rest[0] == "--":
+        rest.pop(0)
 
     if args.print_python:
         print(sys.executable)
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index c1d89fed9..26f77917f 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -60,6 +60,12 @@ from mlx.nn.layers.convolution_transpose import (
     ConvTranspose2d,
     ConvTranspose3d,
 )
+from mlx.nn.layers.distributed import (
+    AllToShardedLinear,
+    QuantizedAllToShardedLinear,
+    QuantizedShardedToAllLinear,
+    ShardedToAllLinear,
+)
 from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Bilinear, Identity, Linear
diff --git a/python/mlx/nn/layers/distributed.py b/python/mlx/nn/layers/distributed.py
new file mode 100644
index 000000000..92acde8f6
--- /dev/null
+++ b/python/mlx/nn/layers/distributed.py
@@ -0,0 +1,599 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+from functools import lru_cache
+from typing import Callable, Optional, Union
+
+import mlx.core as mx
+from mlx.nn.layers.base import Module
+from mlx.nn.layers.linear import Linear
+from mlx.nn.layers.quantized import QuantizedLinear
+from mlx.utils import tree_map_with_path
+
+
+@lru_cache
+def sum_gradients(group):
+    if group.size() == 1:
+        return lambda x: x
+
+    @mx.custom_function
+    def f(x):
+        return x
+
+    @f.vjp
+    def f(x, dx, _):
+        return mx.distributed.all_sum(dx, group=group)
+
+    return f
+
+
+def _split(weight, segments, axis):
+    """Equivalent to mx.split but allows for fractional segments."""
+    if isinstance(segments, int) or isinstance(segments[0], int):
+        return mx.split(weight, segments, axis=axis)
+
+    N = weight.shape[axis]
+    indices = [int(s * N) for s in segments]
+    return mx.split(weight, indices, axis=axis)
+
+
+def _shard(
+    parameters: dict,
+    sharding_predicate: Callable,
+    group: Optional[mx.distributed.Group] = None,
+):
+    """Returns a new parameter tree with the weights sharded according to the
+    sharding_predicate.
+
+    The sharding predicate should return the sharding axis and optionally also
+    the segments that comprise the weight.
+    """
+    group = group or mx.distributed.init()
+    N = group.size()
+    r = group.rank()
+
+    def _shard_fn(path, weight):
+        if not isinstance(weight, mx.array):
+            return weight
+
+        s = sharding_predicate(path, weight)
+        if s is None:
+            return weight
+
+        axis = None
+        segments = 1
+        if isinstance(s, int):
+            axis = s
+        elif isinstance(s, tuple):
+            axis, segments = s
+        else:
+            raise ValueError(
+                "The sharding function should return int or tuple[int, list]"
+            )
+
+        return mx.contiguous(
+            mx.concatenate(
+                [_split(part, N, axis)[r] for part in _split(weight, segments, axis)],
+                axis=axis,
+            )
+        )
+
+    return tree_map_with_path(_shard_fn, parameters)
+
+
+def _all_to_sharded(segments):
+    """Simple predicate to shard fully connected layers such that a common
+    representation becomes a sharded representation."""
+
+    def _shard_fn(path, weight):
+        return max(weight.ndim - 2, 0), segments
+
+    return _shard_fn
+
+
+def _sharded_to_all(segments):
+    """Simple predicate to shard fully connected layers such that a sharded
+    representation becomes a common representation."""
+
+    def _shard_fn(path, weight):
+        if path.endswith("bias"):
+            return None
+        return -1, segments
+
+    return _shard_fn
+
+
+def _check_sharding(sharding):
+    if sharding not in ("all-to-sharded", "sharded-to-all"):
+        raise ValueError(
+            (
+                f"Sharding type {sharding=} not supported, "
+                "choose one of 'all-to-sharded' or 'sharded-to-all'"
+            )
+        )
+
+
+def shard_inplace(
+    module: Module,
+    sharding: Union[str, Callable],
+    *,
+    segments: Union[int, list] = 1,
+    group: Optional[mx.distributed.Group] = None,
+):
+    """Shard a module in-place by updating its parameter dictionary with the
+    sharded parameter dictionary.
+
+    The ``sharding`` argument can be any callable that, given the path and the
+    weight, returns the sharding axis and optionally also the segments that
+    comprise the unsharded weight. For instance, if the weight is a fused QKV
+    matrix the segments should be 3.
+
+    .. note::
+        The module itself doesn't change, so for distributed communication to
+        happen the module needs to natively support it and have it enabled.
+
+    Args:
+        module (mlx.nn.Module): The parameters of this module will be sharded
+            in-place.
+        sharding (str or callable): One of "all-to-sharded" and
+            "sharded-to-all" or a callable that returns the sharding axis and
+            segments.
+        segments (int or list): The segments to use if ``sharding`` is a
+            string. Default: ``1``.
+        group (mlx.core.distributed.Group): The distributed group to shard
+            across. If not set, the global group will be used. Default: ``None``.
+    """
+    if isinstance(sharding, str):
+        _check_sharding(sharding)
+        sharding = (
+            _all_to_sharded(segments)
+            if sharding == "all-to-sharded"
+            else _sharded_to_all(segments)
+        )
+    module.update(_shard(module.parameters(), sharding, group))
+
+
+def shard_linear(
+    module: Module,
+    sharding: str,
+    *,
+    segments: Union[int, list] = 1,
+    group: Optional[mx.distributed.Group] = None,
+):
+    """Create a new linear layer that has its parameters sharded and also
+    performs distributed communication either in the forward or backward
+    pass.
+
+    .. note::
+        Contrary to ``shard_inplace``, the original layer is not changed; a
+        new layer is returned instead.
+
+    Args:
+        module (mlx.nn.Module): The linear layer to be sharded.
+        sharding (str): One of "all-to-sharded" and "sharded-to-all" that
+            defines the type of sharding to perform.
+        segments (int or list): The segments to use. Default: ``1``.
+        group (mlx.core.distributed.Group): The distributed group to shard
+            across. If not set, the global group will be used. Default: ``None``.
+    """
+    _check_sharding(sharding)
+    fns = {
+        ("all-to-sharded", True): AllToShardedLinear.from_linear,
+        ("all-to-sharded", False): QuantizedAllToShardedLinear.from_quantized_linear,
+        ("sharded-to-all", True): ShardedToAllLinear.from_linear,
+        ("sharded-to-all", False): QuantizedShardedToAllLinear.from_quantized_linear,
+    }
+    return fns[sharding, isinstance(module, Linear)](
+        module, segments=segments, group=group
+    )
+
+
+class AllToShardedLinear(Module):
+    """Each member of the group applies part of the affine transformation such
+    that the result is sharded across the group.
+
+    The gradients are automatically aggregated from each member of the group.
+
+    Args:
+        input_dims (int): The dimensionality of the input features
+        output_dims (int): The dimensionality of the output features
+        bias (bool, optional): If set to ``False`` then the layer will not use
+            a bias. Default is ``True``.
+        group (mx.distributed.Group, optional): The sharding will happen across
+            this group. If not set then the global group is used. Default is
+            ``None``.
+    """
+
+    def __init__(
+        self,
+        input_dims: int,
+        output_dims: int,
+        bias: bool = True,
+        group: Optional[mx.distributed.Group] = None,
+    ):
+        super().__init__()
+
+        # Initialize the parameters
+        scale = math.sqrt(1.0 / input_dims)
+        self.group = group or mx.distributed.init()
+        N = self.group.size()
+
+        if (output_dims % N) != 0:
+            raise ValueError(
+                f"Cannot shard the output of size {output_dims} across {N} devices."
+ ) + + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims // N, input_dims), + ) + if bias: + self.bias = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims // N,), + ) + + def _extra_repr(self) -> str: + out_dims, in_dims = self.weight.shape + N = self.group.size() + out_dims *= N + return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}" + + def __call__(self, x: mx.array) -> mx.array: + # Aggregate the gradients coming from each shard + x = sum_gradients(self.group)(x) + + # Compute the affine projection + if "bias" in self: + x = mx.addmm(self["bias"], x, self["weight"].T) + else: + x = x @ self["weight"].T + return x + + @classmethod + def from_linear( + cls, + linear_layer: Module, + *, + segments: Union[int, list] = 1, + group: Optional[mx.distributed.Group] = None, + ): + group = group or mx.distributed.init() + output_dims, input_dims = linear_layer.weight.shape + + sl = cls(input_dims, output_dims, hasattr(linear_layer, "bias"), group) + sl.update(_shard(linear_layer.parameters(), _all_to_sharded(segments), group)) + + return sl + + +class ShardedToAllLinear(Module): + """Each member of the group applies part of the affine transformation and + then aggregates the results. + + All nodes will have the same exact result after this layer. + + :class:`ShardedToAllLinear` provides a classmethod :meth:`from_linear` to + convert linear layers to sharded :obj:`ShardedToAllLinear` layers. + + Args: + input_dims (int): The dimensionality of the input features + output_dims (int): The dimensionality of the output features + bias (bool, optional): If set to ``False`` the the layer will not use a + bias. Default is ``True``. + group (mx.distributed.Group, optional): The sharding will happen across + this group. If not set then the global group is used. Default is + ``None``. + """ + + def __init__( + self, + input_dims: int, + output_dims: int, + bias: bool = True, + group: Optional[mx.distributed.Group] = None, + ): + super().__init__() + + # Initialize the parameters + scale = math.sqrt(1.0 / input_dims) + self.group = group or mx.distributed.init() + N = self.group.size() + + if (input_dims % N) != 0: + raise ValueError( + f"The input of size {input_dims} cannot be sharded across {N} devices." + ) + + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims, input_dims // N), + ) + if bias: + self.bias = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims,), + ) + + def _extra_repr(self) -> str: + N = self.group.size() + out_dims, in_dims = self.weight.shape + in_dims *= N + return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}" + + def __call__(self, x: mx.array) -> mx.array: + x = x @ self["weight"].T + + x = mx.distributed.all_sum(x, group=self.group) + + if "bias" in self: + x = x + self["bias"] + + return x + + @classmethod + def from_linear( + cls, + linear_layer: Module, + *, + segments: Union[int, list] = 1, + group: Optional[mx.distributed.Group] = None, + ): + group = group or mx.distributed.init() + output_dims, input_dims = linear_layer.weight.shape + + sl = cls(input_dims, output_dims, hasattr(linear_layer, "bias"), group) + sl.update(_shard(linear_layer.parameters(), _sharded_to_all(segments), group)) + + return sl + + +class QuantizedAllToShardedLinear(Module): + """Each member of the group applies part of the affine transformation with + a quantized matrix such that the result is sharded across the group. 
+
+    It is the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`.
+    Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and
+    will not be included in any gradient computation.
+
+    Args:
+        input_dims (int): The dimensionality of the input features.
+        output_dims (int): The dimensionality of the output features.
+        bias (bool, optional): If set to ``False`` then the layer will not use
+            a bias. Default: ``True``.
+        group_size (int, optional): The group size to use for the quantized
+            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
+        bits (int, optional): The bit width to use for the quantized weight.
+            See :func:`~mlx.core.quantize`. Default: ``4``.
+        group (mx.distributed.Group, optional): The sharding will happen across
+            this group. If not set then the global group is used. Default is
+            ``None``.
+    """
+
+    def __init__(
+        self,
+        input_dims: int,
+        output_dims: int,
+        bias: bool = True,
+        group_size: int = 64,
+        bits: int = 4,
+        group: Optional[mx.distributed.Group] = None,
+    ):
+        super().__init__()
+
+        # Quantization config
+        self.group_size = group_size
+        self.bits = bits
+
+        # Initialize the quantized weight
+        scale = math.sqrt(1.0 / input_dims)
+        self.group = group or mx.distributed.init()
+        N = self.group.size()
+
+        if (output_dims % N) != 0:
+            raise ValueError(
+                f"Cannot shard the output of size {output_dims} across {N} devices."
+            )
+
+        weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(output_dims // N, input_dims),
+        )
+        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+
+        # And bias if needed
+        if bias:
+            self.bias = mx.zeros((output_dims // N,))
+
+        # Freeze this model's parameters
+        self.freeze()
+
+    def unfreeze(self, *args, **kwargs):
+        """Wrap unfreeze so that we unfreeze any layers we might contain but
+        our parameters will remain frozen."""
+        super().unfreeze(*args, **kwargs)
+        self.freeze(recurse=False)
+
+    def _extra_repr(self) -> str:
+        out_dims, in_dims = self.weight.shape
+        in_dims *= 32 // self.bits
+        out_dims *= self.group.size()
+        return (
+            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
+            f"group_size={self.group_size}, bits={self.bits}"
+        )
+
+    def __call__(self, x: mx.array) -> mx.array:
+        # Aggregate the gradients coming from each shard
+        x = sum_gradients(self.group)(x)
+
+        x = mx.quantized_matmul(
+            x,
+            self["weight"],
+            scales=self["scales"],
+            biases=self["biases"],
+            transpose=True,
+            group_size=self.group_size,
+            bits=self.bits,
+        )
+        if "bias" in self:
+            x = x + self["bias"]
+        return x
+
+    @classmethod
+    def from_quantized_linear(
+        cls,
+        quantized_linear_layer: Module,
+        *,
+        segments: Union[int, list] = 1,
+        group: Optional[mx.distributed.Group] = None,
+    ):
+        group = group or mx.distributed.init()
+        output_dims, input_dims = quantized_linear_layer.weight.shape
+        input_dims *= 32 // quantized_linear_layer.bits
+
+        sl = cls(
+            input_dims,
+            output_dims,
+            hasattr(quantized_linear_layer, "bias"),
+            group_size=quantized_linear_layer.group_size,
+            bits=quantized_linear_layer.bits,
+            group=group,
+        )
+        sl.update(
+            _shard(
+                quantized_linear_layer.parameters(),
+                _all_to_sharded(segments),
+                group,
+            )
+        )
+
+        return sl
+
+
+class QuantizedShardedToAllLinear(Module):
+    """Each member of the group applies part of the affine transformation using
+    the quantized matrix and then aggregates the results.
+
+    All nodes will have the same exact result after this layer.
+
+    It is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`.
+    Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and
+    will not be included in any gradient computation.
+
+    Args:
+        input_dims (int): The dimensionality of the input features.
+        output_dims (int): The dimensionality of the output features.
+        bias (bool, optional): If set to ``False`` then the layer will not use
+            a bias. Default: ``True``.
+        group_size (int, optional): The group size to use for the quantized
+            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
+        bits (int, optional): The bit width to use for the quantized weight.
+            See :func:`~mlx.core.quantize`. Default: ``4``.
+        group (mx.distributed.Group, optional): The sharding will happen across
+            this group. If not set then the global group is used. Default is
+            ``None``.
+    """
+
+    def __init__(
+        self,
+        input_dims: int,
+        output_dims: int,
+        bias: bool = True,
+        group_size: int = 64,
+        bits: int = 4,
+        group: Optional[mx.distributed.Group] = None,
+    ):
+        super().__init__()
+
+        # Quantization config
+        self.group_size = group_size
+        self.bits = bits
+
+        # Initialize the quantized weight
+        scale = math.sqrt(1.0 / input_dims)
+        self.group = group or mx.distributed.init()
+        N = self.group.size()
+
+        if (input_dims % N) != 0:
+            raise ValueError(
+                f"The input of size {input_dims} cannot be sharded across {N} devices."
+            )
+
+        weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(output_dims, input_dims // N),
+        )
+        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+
+        # And bias if needed
+        if bias:
+            self.bias = mx.zeros((output_dims,))
+
+        # Freeze this model's parameters
+        self.freeze()
+
+    def unfreeze(self, *args, **kwargs):
+        """Wrap unfreeze so that we unfreeze any layers we might contain but
+        our parameters will remain frozen."""
+        super().unfreeze(*args, **kwargs)
+        self.freeze(recurse=False)
+
+    def _extra_repr(self) -> str:
+        out_dims, in_dims = self.weight.shape
+        in_dims *= (32 // self.bits) * self.group.size()
+        return (
+            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
+            f"group_size={self.group_size}, bits={self.bits}"
+        )
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = mx.quantized_matmul(
+            x,
+            self["weight"],
+            scales=self["scales"],
+            biases=self["biases"],
+            transpose=True,
+            group_size=self.group_size,
+            bits=self.bits,
+        )
+        x = mx.distributed.all_sum(x, group=self.group)
+        if "bias" in self:
+            x = x + self["bias"]
+        return x
+
+    @classmethod
+    def from_quantized_linear(
+        cls,
+        quantized_linear_layer: Module,
+        *,
+        segments: Union[int, list] = 1,
+        group: Optional[mx.distributed.Group] = None,
+    ):
+        group = group or mx.distributed.init()
+        output_dims, input_dims = quantized_linear_layer.weight.shape
+        input_dims *= 32 // quantized_linear_layer.bits
+
+        sl = cls(
+            input_dims,
+            output_dims,
+            hasattr(quantized_linear_layer, "bias"),
+            group_size=quantized_linear_layer.group_size,
+            bits=quantized_linear_layer.bits,
+            group=group,
+        )
+        sl.update(
+            _shard(
+                quantized_linear_layer.parameters(),
+                _sharded_to_all(segments),
+                group,
+            )
+        )
+
+        return sl
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 1577cae18..6de580d1b 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -5124,4 +5124,23 @@ void init_ops(nb::module_& m) {
              [0, 1, 0],
              [0, 1, 0]], dtype=float32)
   )pbdoc");
+  m.def(
+      "contiguous",
+      &mx::contiguous,
+      nb::arg(),
+      "allow_col_major"_a = false,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def contiguous(a: array, /, allow_col_major: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Force an array to be row contiguous. Copy if necessary.
+
+        Args:
+            a (array): The input to make contiguous.
+            allow_col_major (bool): Consider column-major as contiguous and don't copy.
+
+        Returns:
+            array: The row or column contiguous output.
+      )pbdoc");
 }
diff --git a/python/tests/mlx_distributed_tests.py b/python/tests/mlx_distributed_tests.py
new file mode 100644
index 000000000..5feb51bc9
--- /dev/null
+++ b/python/tests/mlx_distributed_tests.py
@@ -0,0 +1,250 @@
+# Copyright © 2025 Apple Inc.
+
+import unittest
+
+import mlx.core as mx
+import mlx.nn as nn
+import mlx_tests
+from mlx.nn.layers.distributed import shard_inplace, shard_linear
+from mlx.nn.utils import average_gradients
+
+
+class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
+    def test_average_gradients(self):
+        original_all_sum = mx.distributed.all_sum
+        n_calls = 0
+        xtype = None
+
+        def new_all_sum(x, **kwargs):
+            nonlocal n_calls
+            nonlocal xtype
+
+            n_calls += 1
+            if xtype is not None:
+                self.assertEqual(xtype, x.dtype)
+
+            return original_all_sum(x, **kwargs)
+
+        mx.distributed.all_sum = new_all_sum
+
+        try:
+            grads = [mx.ones(10) for i in range(10)]
+            new_grads = average_gradients(grads)
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 1)
+
+            n_calls = 0
+            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 2)
+
+            n_calls = 0
+            new_grads = average_gradients(grads, all_reduce_size=0)
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 10)
+
+            n_calls = 0
+            xtype = mx.float16
+            new_grads = average_gradients(
+                grads, all_reduce_size=2 * 50, communication_type=mx.float16
+            )
+            mx.eval(new_grads)
+            self.assertEqual(len(new_grads), 10)
+            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
+            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
+            self.assertEqual(n_calls, 2)
+
+        finally:
+            mx.distributed.all_sum = original_all_sum
+
+    def test_donation(self):
+        x = mx.random.normal((1024,))
+        mx.eval(x)
+        mx.synchronize()
+
+        mx.reset_peak_memory()
+        scale = mx.array(2.0)
+        y = mx.distributed.all_sum(x)
+        mx.eval(y)
+        mx.synchronize()
+        all_sum_only = mx.get_peak_memory()
+        y = mx.distributed.all_sum(x) * scale
+        mx.eval(y)
+        mx.synchronize()
+        all_sum_with_binary = mx.get_peak_memory()
+
+        self.assertEqual(all_sum_only, all_sum_with_binary)
+
+    def test_shard_linear(self):
+        # Seed the prng to have the same inputs and weights generated everywhere
+        mx.random.seed(0xF0F0F0F0)
+
+        # Prepare inputs
+        world = mx.distributed.init()
+        part = (
+            slice(None),
+            slice(
+                world.rank() * 1024 // world.size(),
+                (world.rank() + 1) * 1024 // world.size(),
+            ),
+        )
+        x = mx.random.normal((4, 1024))
+
+        # Create and shard some linear layers
+        lin = nn.Linear(1024, 1024, bias=True)
+        slin1 = shard_linear(lin, "all-to-sharded")
+        slin2 = shard_linear(lin, "sharded-to-all")
+        y = lin(x)
+        y1 = slin1(x)
+        y2 = slin2(x[part])
+        self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
+        self.assertTrue(mx.allclose(y[part], y1))
+
+        # And their quant versions
+        qlin = lin.to_quantized()
+        slin1 = shard_linear(qlin, "all-to-sharded")
+        slin2 = shard_linear(qlin, "sharded-to-all")
+        y = qlin(x)
+        y1 = slin1(x)
+        y2 = slin2(x[part])
+        self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
+        self.assertTrue(mx.allclose(y[part], y1))
+
+        # Check the backward works as expected
+        def dummy_loss(model, x, y):
+            return (model(x) * y).sum()
+
+        mod = nn.Sequential(
+            nn.Linear(128, 128),
+            nn.Linear(128, 128),
+            nn.Linear(128, 128),
+            nn.Linear(128, 128),
+        )
+        smod = nn.Sequential(
+            shard_linear(mod.layers[0], "all-to-sharded"),
+            shard_linear(mod.layers[1], "sharded-to-all"),
+            shard_linear(mod.layers[2], "all-to-sharded"),
+            shard_linear(mod.layers[3], "sharded-to-all"),
+        )
+
+        grad1 = nn.value_and_grad(mod, dummy_loss)
+        grad2 = nn.value_and_grad(smod, dummy_loss)
+
+        x = mx.random.normal((4, 128))
+        y = mx.random.normal((4, 128))
+
+        l1, g1 = grad1(mod, x, y)
+        l2, g2 = grad2(smod, x, y)
+        mx.eval(l1, g1, l2, g2)
+
+        part = slice(
+            world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
+        )
+        self.assertTrue(mx.allclose(l1, l2))
+        self.assertTrue(
+            mx.allclose(
+                g1["layers"][0]["weight"][part],
+                g2["layers"][0]["weight"],
+                atol=1e-6,
+                rtol=1e-4,
+            )
+        )
+        self.assertTrue(
+            mx.allclose(
+                g1["layers"][2]["weight"][part],
+                g2["layers"][2]["weight"],
+                atol=1e-6,
+                rtol=1e-4,
+            )
+        )
+        self.assertTrue(
+            mx.allclose(
+                g1["layers"][1]["weight"][:, part],
+                g2["layers"][1]["weight"],
+                atol=1e-6,
+                rtol=1e-4,
+            )
+        )
+        self.assertTrue(
+            mx.allclose(
+                g1["layers"][3]["weight"][:, part],
+                g2["layers"][3]["weight"],
+                atol=1e-6,
+                rtol=1e-4,
+            )
+        )
+        self.assertTrue(
+            mx.allclose(
+                g1["layers"][0]["bias"][part],
+                g2["layers"][0]["bias"],
+                atol=1e-6,
+                rtol=1e-4,
+            )
+        )
+        self.assertTrue(
+            mx.allclose(
+                g1["layers"][2]["bias"][part],
+                g2["layers"][2]["bias"],
+                atol=1e-6,
+                rtol=1e-4,
+            )
+        )
+        self.assertTrue(
+            mx.allclose(
+                g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
+            )
+        )
+        self.assertTrue(
+            mx.allclose(
+                g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
+            )
+        )
+
+    def test_shard_predicate(self):
+        mx.random.seed(0xF0F0F0F0)
+
+        class MyConv(nn.Module):
+            def __init__(self, *args, **kwargs):
+                super().__init__()
+                self.aggregate = kwargs.pop("aggregate", False)
+                self.conv = nn.Conv2d(*args, **kwargs)
+
+            def __call__(self, x):
+                x = self.conv(x)
+                if self.aggregate:
+                    x = mx.distributed.all_sum(x)
+                return x
+
+        def sharding(path, weight):
+            parts = path.split(".")
+            even = int(parts[1]) % 2 == 0
+            if even:
+                return 0
+            else:
+                return -1 if parts[-1] != "bias" else None
+
+        mod = nn.Sequential(
+            MyConv(3, 128, kernel_size=3),
+            MyConv(128, 128, kernel_size=3),
+            MyConv(128, 128, kernel_size=3),
+            MyConv(128, 3, kernel_size=3),
+        )
+        smod = nn.Sequential(
+            MyConv(3, 128, kernel_size=3),
+            MyConv(128, 128, kernel_size=3, aggregate=True),
+            MyConv(128, 128, kernel_size=3),
+            MyConv(128, 3, kernel_size=3, aggregate=True),
+        )
+        smod.update(mod.parameters())
+        shard_inplace(smod, sharding)
+
+        x = mx.random.normal((4, 16, 16, 3))
+        y1 = mod(x)
+        y2 = smod(x)
+        self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
diff --git a/python/tests/mpi_test_distributed.py b/python/tests/mpi_test_distributed.py
index a6467568e..ebc8ad728 100644
--- a/python/tests/mpi_test_distributed.py
+++ b/python/tests/mpi_test_distributed.py
@@ -3,11 +3,14 @@
 import unittest
 
 import mlx.core as mx
-import mlx_tests
-from mlx.nn.utils import average_gradients
+import mlx_distributed_tests
 
 
-class TestDistributed(mlx_tests.MLXTestCase):
+class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
+    @classmethod
+    def setUpClass(cls):
+        world = mx.distributed.init(strict=True, backend="mpi")
+
     def test_groups(self):
         world = mx.distributed.init()
         self.assertEqual(world.size(), 8)
@@ -121,77 +124,6 @@ class TestDistributed(mlx_tests.MLXTestCase):
             x = mx.distributed.recv_like(x, neighbor, group=pairs)
             mx.eval(y, x)
 
-    def test_average_gradients(self):
-        original_all_sum = mx.distributed.all_sum
-        n_calls = 0
-        xtype = None
-
-        def new_all_sum(x, **kwargs):
-            nonlocal n_calls
-            nonlocal xtype
-
-            n_calls += 1
-            if xtype is not None:
-                self.assertEqual(xtype, x.dtype)
-
-            return original_all_sum(x, **kwargs)
-
-        mx.distributed.all_sum = new_all_sum
-
-        try:
-            grads = [mx.ones(10) for i in range(10)]
-            new_grads = average_gradients(grads)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 1)
-
-            n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 2)
-
-            n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=0)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 10)
-
-            n_calls = 0
-            xtype = mx.float16
-            new_grads = average_gradients(
-                grads, all_reduce_size=2 * 50, communication_type=mx.float16
-            )
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 2)
-
-        finally:
-            mx.distributed.all_sum = original_all_sum
-
-    def test_donation(self):
-        x = mx.random.normal((1024,))
-        mx.eval(x)
-        mx.synchronize()
-
-        mx.reset_peak_memory()
-        scale = mx.array(2.0)
-        y = mx.distributed.all_sum(x)
-        mx.eval(y)
-        mx.synchronize()
-        all_sum_only = mx.get_peak_memory()
-        y = mx.distributed.all_sum(x) * scale
-        mx.eval(y)
-        mx.synchronize()
-        all_sum_with_binary = mx.get_peak_memory()
-
-        self.assertEqual(all_sum_only, all_sum_with_binary)
-
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/tests/ring_test_distributed.py b/python/tests/ring_test_distributed.py
index 93039095e..169889559 100644
--- a/python/tests/ring_test_distributed.py
+++ b/python/tests/ring_test_distributed.py
@@ -3,10 +3,10 @@
 import unittest
 
 import mlx.core as mx
-import mlx_tests
+import mlx_distributed_tests
 
 
-class TestRingDistributed(mlx_tests.MLXTestCase):
+class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
     @classmethod
     def setUpClass(cls):
         world = mx.distributed.init(strict=True, backend="ring")
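
For reference, a minimal usage sketch of the sharding helpers introduced by this patch (not part of the diff). It assumes the patch is applied and the script is launched across several ranks, for example with a distributed launcher such as mlx.launch or mpirun, so that mx.distributed.init() returns a group larger than one; the layer sizes below are illustrative and must keep the hidden dimension divisible by the number of ranks.

# Hypothetical example, not part of the patch: tensor-parallel two-layer MLP.
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_linear

group = mx.distributed.init()

# Shard a regular MLP: the first projection scatters the hidden dimension
# across ranks, the second sums the partial results back with all_sum.
up = nn.Linear(512, 2048)
down = nn.Linear(2048, 512)
up_sharded = shard_linear(up, "all-to-sharded", group=group)
down_sharded = shard_linear(down, "sharded-to-all", group=group)

x = mx.random.normal((4, 512))
y = down_sharded(nn.relu(up_sharded(x)))  # every rank ends up with the same y
mx.eval(y)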