mirror of https://github.com/ml-explore/mlx.git
Distributed layers (#1270)
parent 69e4dd506b · commit 4eef8102c9
@@ -205,8 +205,10 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
 void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  if (in.flags().row_contiguous ||
-      (allow_col_major_ && in.flags().col_contiguous)) {
+  constexpr size_t extra_bytes = 16384;
+  if (in.buffer_size() <= out.nbytes() + extra_bytes &&
+      (in.flags().row_contiguous ||
+       (allow_col_major_ && in.flags().col_contiguous))) {
     out.copy_shared_buffer(in);
   } else {
     copy(in, out, CopyType::General, stream());

@@ -251,8 +251,10 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
 void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  if (in.flags().row_contiguous ||
-      (allow_col_major_ && in.flags().col_contiguous)) {
+  constexpr size_t extra_bytes = 16384;
+  if (in.buffer_size() <= out.nbytes() + extra_bytes &&
+      (in.flags().row_contiguous ||
+       (allow_col_major_ && in.flags().col_contiguous))) {
     out.copy_shared_buffer(in);
   } else {
     copy_gpu(in, out, CopyType::General);

@@ -993,6 +993,9 @@ array concatenate(
     throw std::invalid_argument(
         "[concatenate] No arrays provided for concatenation");
   }
+  if (arrays.size() == 1) {
+    return arrays[0];
+  }
 
   auto ax = normalize_axis_index(axis, arrays[0].ndim(), "[concatenate] ");
 
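
The hunk above adds an early return so that concatenating a single array hands the input back instead of going through the general concatenation path. A minimal sketch of the resulting behavior from Python; the values are illustrative:

import mlx.core as mx

a = mx.arange(4)
# With exactly one input there is nothing to join, so the result is
# equivalent to the input array itself.
b = mx.concatenate([a], axis=0)
assert mx.array_equal(a, b)
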
@@ -761,6 +761,8 @@ def main():
         "--cwd", help="Set the working directory on each node to the provided one"
     )
     args, rest = parser.parse_known_args()
+    if rest[0] == "--":
+        rest.pop(0)
 
     if args.print_python:
         print(sys.executable)

@@ -60,6 +60,12 @@ from mlx.nn.layers.convolution_transpose import (
     ConvTranspose2d,
     ConvTranspose3d,
 )
+from mlx.nn.layers.distributed import (
+    AllToShardedLinear,
+    QuantizedAllToShardedLinear,
+    QuantizedShardedToAllLinear,
+    ShardedToAllLinear,
+)
 from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Bilinear, Identity, Linear

python/mlx/nn/layers/distributed.py (new file, 599 lines)
@@ -0,0 +1,599 @@
# Copyright © 2024 Apple Inc.

import math
from functools import lru_cache
from typing import Callable, Optional, Union

import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.utils import tree_map_with_path


@lru_cache
def sum_gradients(group):
    if group.size() == 1:
        return lambda x: x

    @mx.custom_function
    def f(x):
        return x

    @f.vjp
    def f(x, dx, _):
        return mx.distributed.all_sum(dx, group=group)

    return f


def _split(weight, segments, axis):
    """Equivalent to mx.split but allows for fractional segments."""
    if isinstance(segments, int) or isinstance(segments[0], int):
        return mx.split(weight, segments, axis=axis)

    N = weight.shape[axis]
    indices = [int(s * N) for s in segments]
    return mx.split(weight, indices, axis=axis)


def _shard(
    parameters: dict,
    sharding_predicate: Callable,
    group: Optional[mx.distributed.Group] = None,
):
    """Returns a new parameter tree with the weights sharded according to the
    sharding_predicate.

    The sharding predicate should return the sharding axis and optionally also
    the segments that comprise the weight.
    """
    group = group or mx.distributed.init()
    N = group.size()
    r = group.rank()

    def _shard_fn(path, weight):
        if not isinstance(weight, mx.array):
            return weight

        s = sharding_predicate(path, weight)
        if s is None:
            return weight

        axis = None
        segments = 1
        if isinstance(s, int):
            axis = s
        elif isinstance(s, tuple):
            axis, segments = s
        else:
            raise ValueError(
                "The sharding function should return int or tuple[int, list]"
            )

        return mx.contiguous(
            mx.concatenate(
                [_split(part, N, axis)[r] for part in _split(weight, segments, axis)],
                axis=axis,
            )
        )

    return tree_map_with_path(_shard_fn, parameters)


def _all_to_sharded(segments):
    """Simple predicate to shard fully connected layers such that a common
    representation becomes a sharded representation."""

    def _shard_fn(path, weight):
        return max(weight.ndim - 2, 0), segments

    return _shard_fn


def _sharded_to_all(segments):
    """Simple predicate to shard fully connected layers such that a sharded
    representation becomes a common representation."""

    def _shard_fn(path, weight):
        if path.endswith("bias"):
            return None
        return -1, segments

    return _shard_fn


def _check_sharding(sharding):
    if sharding not in ("all-to-sharded", "sharded-to-all"):
        raise ValueError(
            (
                f"Sharding type {sharding=} not supported, "
                "choose one of 'all-to-sharded' or 'sharded-to-all'"
            )
        )


def shard_inplace(
    module: Module,
    sharding: Union[str, Callable],
    *,
    segments: Union[int, list] = 1,
    group: Optional[mx.distributed.Group] = None,
):
    """Shard a module in-place by updating its parameter dictionary with the
    sharded parameter dictionary.

    The ``sharding`` argument can be any callable that given the path and the
    weight returns the sharding axis and optionally also the segments that
    comprise the unsharded weight. For instance if the weight is a fused QKV
    matrix the segments should be 3.

    .. note::
       The module doesn't change so in order for distributed communication to
       happen the module needs to natively support it and for it to be enabled.

    Args:
        module (mlx.nn.Module): The parameters of this module will be sharded
            in-place.
        sharding (str or callable): One of "all-to-sharded" and
            "sharded-to-all" or a callable that returns the sharding axis and
            segments.
        segments (int or list): The segments to use if ``sharding`` is a
            string. Default: ``1``.
        group (mlx.core.distributed.Group): The distributed group to shard
            across. If not set, the global group will be used. Default: ``None``.
    """
    if isinstance(sharding, str):
        _check_sharding(sharding)
        sharding = (
            _all_to_sharded(segments)
            if sharding == "all-to-sharded"
            else _sharded_to_all(segments)
        )
    module.update(_shard(module.parameters(), sharding, group))


def shard_linear(
    module: Module,
    sharding: str,
    *,
    segments: Union[int, list] = 1,
    group: Optional[mx.distributed.Group] = None,
):
    """Create a new linear layer that has its parameters sharded and also
    performs distributed communication either in the forward or backward
    pass.

    .. note::
       Contrary to ``shard_inplace``, the original layer is not changed but a
       new layer is returned.

    Args:
        module (mlx.nn.Module): The linear layer to be sharded.
        sharding (str): One of "all-to-sharded" and
            "sharded-to-all" that defines the type of sharding to perform.
        segments (int or list): The segments to use. Default: ``1``.
        group (mlx.core.distributed.Group): The distributed group to shard
            across. If not set, the global group will be used. Default: ``None``.
    """
    _check_sharding(sharding)
    fns = {
        ("all-to-sharded", True): AllToShardedLinear.from_linear,
        ("all-to-sharded", False): QuantizedAllToShardedLinear.from_quantized_linear,
        ("sharded-to-all", True): ShardedToAllLinear.from_linear,
        ("sharded-to-all", False): QuantizedShardedToAllLinear.from_quantized_linear,
    }
    return fns[sharding, isinstance(module, Linear)](
        module, segments=segments, group=group
    )


class AllToShardedLinear(Module):
    """Each member of the group applies part of the affine transformation such
    that the result is sharded across the group.

    The gradients are automatically aggregated from each member of the group.

    Args:
        input_dims (int): The dimensionality of the input features
        output_dims (int): The dimensionality of the output features
        bias (bool, optional): If set to ``False`` the layer will not use a
            bias. Default is ``True``.
        group (mx.distributed.Group, optional): The sharding will happen across
            this group. If not set then the global group is used. Default is
            ``None``.
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        bias: bool = True,
        group: Optional[mx.distributed.Group] = None,
    ):
        super().__init__()

        # Initialize the parameters
        scale = math.sqrt(1.0 / input_dims)
        self.group = group or mx.distributed.init()
        N = self.group.size()

        if (output_dims % N) != 0:
            raise ValueError(
                f"Cannot shard the output of size {output_dims} across {N} devices."
            )

        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims // N, input_dims),
        )
        if bias:
            self.bias = mx.random.uniform(
                low=-scale,
                high=scale,
                shape=(output_dims // N,),
            )

    def _extra_repr(self) -> str:
        out_dims, in_dims = self.weight.shape
        N = self.group.size()
        out_dims *= N
        return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}"

    def __call__(self, x: mx.array) -> mx.array:
        # Aggregate the gradients coming from each shard
        x = sum_gradients(self.group)(x)

        # Compute the affine projection
        if "bias" in self:
            x = mx.addmm(self["bias"], x, self["weight"].T)
        else:
            x = x @ self["weight"].T
        return x

    @classmethod
    def from_linear(
        cls,
        linear_layer: Module,
        *,
        segments: Union[int, list] = 1,
        group: Optional[mx.distributed.Group] = None,
    ):
        group = group or mx.distributed.init()
        output_dims, input_dims = linear_layer.weight.shape

        sl = cls(input_dims, output_dims, hasattr(linear_layer, "bias"), group)
        sl.update(_shard(linear_layer.parameters(), _all_to_sharded(segments), group))

        return sl


class ShardedToAllLinear(Module):
    """Each member of the group applies part of the affine transformation and
    then aggregates the results.

    All nodes will have the same exact result after this layer.

    :class:`ShardedToAllLinear` provides a classmethod :meth:`from_linear` to
    convert linear layers to sharded :obj:`ShardedToAllLinear` layers.

    Args:
        input_dims (int): The dimensionality of the input features
        output_dims (int): The dimensionality of the output features
        bias (bool, optional): If set to ``False`` the layer will not use a
            bias. Default is ``True``.
        group (mx.distributed.Group, optional): The sharding will happen across
            this group. If not set then the global group is used. Default is
            ``None``.
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        bias: bool = True,
        group: Optional[mx.distributed.Group] = None,
    ):
        super().__init__()

        # Initialize the parameters
        scale = math.sqrt(1.0 / input_dims)
        self.group = group or mx.distributed.init()
        N = self.group.size()

        if (input_dims % N) != 0:
            raise ValueError(
                f"The input of size {input_dims} cannot be sharded across {N} devices."
            )

        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input_dims // N),
        )
        if bias:
            self.bias = mx.random.uniform(
                low=-scale,
                high=scale,
                shape=(output_dims,),
            )

    def _extra_repr(self) -> str:
        N = self.group.size()
        out_dims, in_dims = self.weight.shape
        in_dims *= N
        return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}"

    def __call__(self, x: mx.array) -> mx.array:
        x = x @ self["weight"].T

        x = mx.distributed.all_sum(x, group=self.group)

        if "bias" in self:
            x = x + self["bias"]

        return x

    @classmethod
    def from_linear(
        cls,
        linear_layer: Module,
        *,
        segments: Union[int, list] = 1,
        group: Optional[mx.distributed.Group] = None,
    ):
        group = group or mx.distributed.init()
        output_dims, input_dims = linear_layer.weight.shape

        sl = cls(input_dims, output_dims, hasattr(linear_layer, "bias"), group)
        sl.update(_shard(linear_layer.parameters(), _sharded_to_all(segments), group))

        return sl


class QuantizedAllToShardedLinear(Module):
    """Each member of the group applies part of the affine transformation with
    a quantized matrix such that the result is sharded across the group.

    It is the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`.
    Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and
    will not be included in any gradient computation.

    Args:
        input_dims (int): The dimensionality of the input features.
        output_dims (int): The dimensionality of the output features.
        bias (bool, optional): If set to ``False`` then the layer will not use
            a bias. Default: ``True``.
        group_size (int, optional): The group size to use for the quantized
            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
        bits (int, optional): The bit width to use for the quantized weight.
            See :func:`~mlx.core.quantize`. Default: ``4``.
        group (mx.distributed.Group, optional): The sharding will happen across
            this group. If not set then the global group is used. Default is
            ``None``.
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        bias: bool = True,
        group_size: int = 64,
        bits: int = 4,
        group: Optional[mx.distributed.Group] = None,
    ):
        super().__init__()

        # Quantization config
        self.group_size = group_size
        self.bits = bits

        # Initialize the quantized weight
        scale = math.sqrt(1.0 / input_dims)
        self.group = group or mx.distributed.init()
        N = self.group.size()

        if (output_dims % N) != 0:
            raise ValueError(
                f"Cannot shard the output of size {output_dims} across {N} devices."
            )

        weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims // N, input_dims),
        )
        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)

        # And bias if needed
        if bias:
            self.bias = mx.zeros((output_dims // N,))

        # Freeze this model's parameters
        self.freeze()

    def unfreeze(self, *args, **kwargs):
        """Wrap unfreeze so that we unfreeze any layers we might contain but
        our parameters will remain frozen."""
        super().unfreeze(*args, **kwargs)
        self.freeze(recurse=False)

    def _extra_repr(self) -> str:
        out_dims, in_dims = self.weight.shape
        in_dims *= 32 // self.bits
        out_dims *= self.group.size()
        return (
            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
            f"group_size={self.group_size}, bits={self.bits}"
        )

    def __call__(self, x: mx.array) -> mx.array:
        # Aggregate the gradients coming from each shard
        x = sum_gradients(self.group)(x)

        x = mx.quantized_matmul(
            x,
            self["weight"],
            scales=self["scales"],
            biases=self["biases"],
            transpose=True,
            group_size=self.group_size,
            bits=self.bits,
        )
        if "bias" in self:
            x = x + self["bias"]
        return x

    @classmethod
    def from_quantized_linear(
        cls,
        quantized_linear_layer: Module,
        *,
        segments: Union[int, list] = 1,
        group: Optional[mx.distributed.Group] = None,
    ):
        group = group or mx.distributed.init()
        output_dims, input_dims = quantized_linear_layer.weight.shape
        input_dims *= 32 // quantized_linear_layer.bits

        sl = cls(
            input_dims,
            output_dims,
            hasattr(quantized_linear_layer, "bias"),
            group_size=quantized_linear_layer.group_size,
            bits=quantized_linear_layer.bits,
            group=group,
        )
        sl.update(
            _shard(
                quantized_linear_layer.parameters(),
                _all_to_sharded(segments),
                group,
            )
        )

        return sl


class QuantizedShardedToAllLinear(Module):
    """Each member of the group applies part of the affine transformation using
    the quantized matrix and then aggregates the results.

    All nodes will have the same exact result after this layer.

    It is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`.
    Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and
    will not be included in any gradient computation.

    Args:
        input_dims (int): The dimensionality of the input features.
        output_dims (int): The dimensionality of the output features.
        bias (bool, optional): If set to ``False`` then the layer will not use
            a bias. Default: ``True``.
        group_size (int, optional): The group size to use for the quantized
            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
        bits (int, optional): The bit width to use for the quantized weight.
            See :func:`~mlx.core.quantize`. Default: ``4``.
        group (mx.distributed.Group, optional): The sharding will happen across
            this group. If not set then the global group is used. Default is
            ``None``.
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        bias: bool = True,
        group_size: int = 64,
        bits: int = 4,
        group: Optional[mx.distributed.Group] = None,
    ):
        super().__init__()

        # Quantization config
        self.group_size = group_size
        self.bits = bits

        # Initialize the quantized weight
        scale = math.sqrt(1.0 / input_dims)
        self.group = group or mx.distributed.init()
        N = self.group.size()

        if (input_dims % N) != 0:
            raise ValueError(
                f"The input of size {input_dims} cannot be sharded across {N} devices."
            )

        weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input_dims // N),
        )
        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)

        # And bias if needed
        if bias:
            self.bias = mx.zeros((output_dims,))

        # Freeze this model's parameters
        self.freeze()

    def unfreeze(self, *args, **kwargs):
        """Wrap unfreeze so that we unfreeze any layers we might contain but
        our parameters will remain frozen."""
        super().unfreeze(*args, **kwargs)
        self.freeze(recurse=False)

    def _extra_repr(self) -> str:
        out_dims, in_dims = self.weight.shape
        in_dims *= (32 // self.bits) * self.group.size()
        return (
            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
            f"group_size={self.group_size}, bits={self.bits}"
        )

    def __call__(self, x: mx.array) -> mx.array:
        x = mx.quantized_matmul(
            x,
            self["weight"],
            scales=self["scales"],
            biases=self["biases"],
            transpose=True,
            group_size=self.group_size,
            bits=self.bits,
        )
        x = mx.distributed.all_sum(x, group=self.group)
        if "bias" in self:
            x = x + self["bias"]
        return x

    @classmethod
    def from_quantized_linear(
        cls,
        quantized_linear_layer: Module,
        *,
        segments: Union[int, list] = 1,
        group: Optional[mx.distributed.Group] = None,
    ):
        group = group or mx.distributed.init()
        output_dims, input_dims = quantized_linear_layer.weight.shape
        input_dims *= 32 // quantized_linear_layer.bits

        sl = cls(
            input_dims,
            output_dims,
            hasattr(quantized_linear_layer, "bias"),
            group_size=quantized_linear_layer.group_size,
            bits=quantized_linear_layer.bits,
            group=group,
        )
        sl.update(
            _shard(
                quantized_linear_layer.parameters(),
                _sharded_to_all(segments),
                group,
            )
        )

        return sl
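
A minimal usage sketch of the helpers added in this file, assuming the script is launched on several ranks with a working distributed backend; the layer names and dimensions are illustrative, and the same seed is used so every rank builds identical base layers before sharding:

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_linear

group = mx.distributed.init()

# Plain layers created identically on every rank, then turned into a pair of
# sharded layers: the first shards its output across the group, the second
# consumes the sharded activation and all-sums the result.
mx.random.seed(0)
up = nn.Linear(1024, 4096)
down = nn.Linear(4096, 1024)
up_sharded = shard_linear(up, "all-to-sharded", group=group)
down_sharded = shard_linear(down, "sharded-to-all", group=group)

x = mx.random.normal((4, 1024))
y = down_sharded(nn.relu(up_sharded(x)))  # every rank ends up with the same y
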
@@ -5124,4 +5124,23 @@ void init_ops(nb::module_& m) {
              [0, 1, 0],
              [0, 1, 0]], dtype=float32)
       )pbdoc");
+  m.def(
+      "contiguous",
+      &mx::contiguous,
+      nb::arg(),
+      "allow_col_major"_a = false,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def contiguous(a: array, /, allow_col_major: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Force an array to be row contiguous. Copy if necessary.
+
+        Args:
+            a (array): The input to make contiguous
+            allow_col_major (bool): Consider column major as contiguous and don't copy
+
+        Returns:
+            array: The row or col contiguous output.
+      )pbdoc");
 }
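
The binding above exposes the new contiguous op to Python. A small sketch of how it can be used; the array values are illustrative:

import mlx.core as mx

a = mx.arange(6).reshape(2, 3)
t = a.T                                      # a transposed, non row-contiguous view
b = mx.contiguous(t)                         # forces a row-contiguous result, copying if needed
c = mx.contiguous(t, allow_col_major=True)   # may keep the column-major layout without copying
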
python/tests/mlx_distributed_tests.py (new file, 250 lines)
@@ -0,0 +1,250 @@
# Copyright © 2025 Apple Inc.

import unittest

import mlx.core as mx
import mlx.nn as nn
import mlx_tests
from mlx.nn.layers.distributed import shard_inplace, shard_linear
from mlx.nn.utils import average_gradients


class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
    def test_average_gradients(self):
        original_all_sum = mx.distributed.all_sum
        n_calls = 0
        xtype = None

        def new_all_sum(x, **kwargs):
            nonlocal n_calls
            nonlocal xtype

            n_calls += 1
            if xtype is not None:
                self.assertEqual(xtype, x.dtype)

            return original_all_sum(x, **kwargs)

        mx.distributed.all_sum = new_all_sum

        try:
            grads = [mx.ones(10) for i in range(10)]
            new_grads = average_gradients(grads)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 1)

            n_calls = 0
            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 2)

            n_calls = 0
            new_grads = average_gradients(grads, all_reduce_size=0)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 10)

            n_calls = 0
            xtype = mx.float16
            new_grads = average_gradients(
                grads, all_reduce_size=2 * 50, communication_type=mx.float16
            )
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 2)

        finally:
            mx.distributed.all_sum = original_all_sum

    def test_donation(self):
        x = mx.random.normal((1024,))
        mx.eval(x)
        mx.synchronize()

        mx.reset_peak_memory()
        scale = mx.array(2.0)
        y = mx.distributed.all_sum(x)
        mx.eval(y)
        mx.synchronize()
        all_sum_only = mx.get_peak_memory()
        y = mx.distributed.all_sum(x) * scale
        mx.eval(y)
        mx.synchronize()
        all_sum_with_binary = mx.get_peak_memory()

        self.assertEqual(all_sum_only, all_sum_with_binary)

    def test_shard_linear(self):
        # Seed the prng to have the same inputs and weights generated everywhere
        mx.random.seed(0xF0F0F0F0)

        # Prepare inputs
        world = mx.distributed.init()
        part = (
            slice(None),
            slice(
                world.rank() * 1024 // world.size(),
                (world.rank() + 1) * 1024 // world.size(),
            ),
        )
        x = mx.random.normal((4, 1024))

        # Create and shard some linear layers
        lin = nn.Linear(1024, 1024, bias=True)
        slin1 = shard_linear(lin, "all-to-sharded")
        slin2 = shard_linear(lin, "sharded-to-all")
        y = lin(x)
        y1 = slin1(x)
        y2 = slin2(x[part])
        self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
        self.assertTrue(mx.allclose(y[part], y1))

        # And their quant versions
        qlin = lin.to_quantized()
        slin1 = shard_linear(qlin, "all-to-sharded")
        slin2 = shard_linear(qlin, "sharded-to-all")
        y = qlin(x)
        y1 = slin1(x)
        y2 = slin2(x[part])
        self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
        self.assertTrue(mx.allclose(y[part], y1))

        # Check the backward works as expected
        def dummy_loss(model, x, y):
            return (model(x) * y).sum()

        mod = nn.Sequential(
            nn.Linear(128, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 128),
        )
        smod = nn.Sequential(
            shard_linear(mod.layers[0], "all-to-sharded"),
            shard_linear(mod.layers[1], "sharded-to-all"),
            shard_linear(mod.layers[2], "all-to-sharded"),
            shard_linear(mod.layers[3], "sharded-to-all"),
        )

        grad1 = nn.value_and_grad(mod, dummy_loss)
        grad2 = nn.value_and_grad(smod, dummy_loss)

        x = mx.random.normal((4, 128))
        y = mx.random.normal((4, 128))

        l1, g1 = grad1(mod, x, y)
        l2, g2 = grad2(smod, x, y)
        mx.eval(l1, g1, l2, g2)

        part = slice(
            world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
        )
        self.assertTrue(mx.allclose(l1, l2))
        self.assertTrue(
            mx.allclose(
                g1["layers"][0]["weight"][part],
                g2["layers"][0]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][2]["weight"][part],
                g2["layers"][2]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][1]["weight"][:, part],
                g2["layers"][1]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][3]["weight"][:, part],
                g2["layers"][3]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][0]["bias"][part],
                g2["layers"][0]["bias"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][2]["bias"][part],
                g2["layers"][2]["bias"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
            )
        )

    def test_shard_predicate(self):
        mx.random.seed(0xF0F0F0F0)

        class MyConv(nn.Module):
            def __init__(self, *args, **kwargs):
                super().__init__()
                self.aggregate = kwargs.pop("aggregate", False)
                self.conv = nn.Conv2d(*args, **kwargs)

            def __call__(self, x):
                x = self.conv(x)
                if self.aggregate:
                    x = mx.distributed.all_sum(x)
                return x

        def sharding(path, weight):
            parts = path.split(".")
            even = int(parts[1]) % 2 == 0
            if even:
                return 0
            else:
                return -1 if parts[-1] != "bias" else None

        mod = nn.Sequential(
            MyConv(3, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 3, kernel_size=3),
        )
        smod = nn.Sequential(
            MyConv(3, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3, aggregate=True),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 3, kernel_size=3, aggregate=True),
        )
        smod.update(mod.parameters())
        shard_inplace(smod, sharding)

        x = mx.random.normal((4, 16, 16, 3))
        y1 = mod(x)
        y2 = smod(x)
        self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
@@ -3,11 +3,14 @@
 import unittest
 
 import mlx.core as mx
-import mlx_tests
-from mlx.nn.utils import average_gradients
+import mlx_distributed_tests
 
 
-class TestDistributed(mlx_tests.MLXTestCase):
+class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
+    @classmethod
+    def setUpClass(cls):
+        world = mx.distributed.init(strict=True, backend="mpi")
+
     def test_groups(self):
         world = mx.distributed.init()
         self.assertEqual(world.size(), 8)

@@ -121,77 +124,6 @@ class TestDistributed(mlx_tests.MLXTestCase):
         x = mx.distributed.recv_like(x, neighbor, group=pairs)
         mx.eval(y, x)
 
-    def test_average_gradients(self):
-        ...
-
-    def test_donation(self):
-        ...
 
 if __name__ == "__main__":
     unittest.main()

(The deleted bodies of test_average_gradients and test_donation are identical to the methods of the same name added to MLXDistributedCommonTestCase in mlx_distributed_tests.py above.)

@@ -3,10 +3,10 @@
 import unittest
 
 import mlx.core as mx
-import mlx_tests
+import mlx_distributed_tests
 
 
-class TestRingDistributed(mlx_tests.MLXTestCase):
+class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
     @classmethod
     def setUpClass(cls):
         world = mx.distributed.init(strict=True, backend="ring")