mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 16:21:27 +08:00)
	Distributed layers (#1270)
This commit is contained in:
Angelos Katharopoulos, committed by GitHub

parent 69e4dd506b
commit 4eef8102c9
			| @@ -761,6 +761,8 @@ def main(): | ||||
|         "--cwd", help="Set the working directory on each node to the provided one" | ||||
|     ) | ||||
|     args, rest = parser.parse_known_args() | ||||
|     if rest and rest[0] == "--": | ||||
|         rest.pop(0) | ||||
|  | ||||
|     if args.print_python: | ||||
|         print(sys.executable) | ||||
|   | ||||
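The hunk above (presumably in the mlx.launch entry point) strips a leading "--" separator from the leftovers of parse_known_args before they are treated as the command to run on each node. Below is a hedged, standalone sketch of that pattern; the option and the example command are made up for illustration and are not part of the commit.

# Hedged sketch of the separator handling above; whether argparse leaves the
# literal "--" in the unknown-args list varies with the Python version, so the
# guard is cheap and keeps the remaining tokens a clean command line.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cwd", help="working directory on each node")
args, rest = parser.parse_known_args(["--cwd", "/tmp", "--", "echo", "hello"])
if rest and rest[0] == "--":
    rest.pop(0)
print(rest)  # expected: ["echo", "hello"]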
| @@ -60,6 +60,12 @@ from mlx.nn.layers.convolution_transpose import ( | ||||
|     ConvTranspose2d, | ||||
|     ConvTranspose3d, | ||||
| ) | ||||
| from mlx.nn.layers.distributed import ( | ||||
|     AllToShardedLinear, | ||||
|     QuantizedAllToShardedLinear, | ||||
|     QuantizedShardedToAllLinear, | ||||
|     ShardedToAllLinear, | ||||
| ) | ||||
| from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d | ||||
| from mlx.nn.layers.embedding import Embedding | ||||
| from mlx.nn.layers.linear import Bilinear, Identity, Linear | ||||
|   | ||||
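This hunk (presumably the mlx.nn layers __init__) re-exports the four new layers, so they should also be reachable through the mlx.nn namespace. A minimal sketch, assuming a distributed backend has been initialized and the group size divides the output dimension:

# Minimal sketch, assuming the re-exports above surface in mlx.nn and that a
# distributed backend is running (output_dims must be divisible by group.size()).
import mlx.core as mx
import mlx.nn as nn

group = mx.distributed.init()
# Column-parallel projection: each rank holds output_dims // group.size() rows.
proj = nn.AllToShardedLinear(512, 2048, bias=True, group=group)
print(proj)  # the repr reports the full, unsharded dimensions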
							
								
								
									
python/mlx/nn/layers/distributed.py (new file, 599 lines)
							| @@ -0,0 +1,599 @@ | ||||
| # Copyright © 2024 Apple Inc. | ||||
|  | ||||
| import math | ||||
| from functools import lru_cache | ||||
| from typing import Callable, Optional, Union | ||||
|  | ||||
| import mlx.core as mx | ||||
| from mlx.nn.layers.base import Module | ||||
| from mlx.nn.layers.linear import Linear | ||||
| from mlx.nn.layers.quantized import QuantizedLinear | ||||
| from mlx.utils import tree_map_with_path | ||||
|  | ||||
|  | ||||
| @lru_cache | ||||
| def sum_gradients(group): | ||||
|     if group.size() == 1: | ||||
|         return lambda x: x | ||||
|  | ||||
|     @mx.custom_function | ||||
|     def f(x): | ||||
|         return x | ||||
|  | ||||
|     @f.vjp | ||||
|     def f(x, dx, _): | ||||
|         return mx.distributed.all_sum(dx, group=group) | ||||
|  | ||||
|     return f | ||||
|  | ||||
|  | ||||
| def _split(weight, segments, axis): | ||||
|     """Equivalent to mx.split but allows for fractional segments.""" | ||||
|     if isinstance(segments, int) or isinstance(segments[0], int): | ||||
|         return mx.split(weight, segments, axis=axis) | ||||
|  | ||||
|     N = weight.shape[axis] | ||||
|     indices = [int(s * N) for s in segments] | ||||
|     return mx.split(weight, indices, axis=axis) | ||||
|  | ||||
|  | ||||
| def _shard( | ||||
|     parameters: dict, | ||||
|     sharding_predicate: Callable, | ||||
|     group: Optional[mx.distributed.Group] = None, | ||||
| ): | ||||
|     """Returns a new parameter tree with the weights sharded according to the | ||||
|     sharding_predicate. | ||||
|  | ||||
|     The sharding predicate should return the sharding axis and optionally also | ||||
|     the segments that comprise the weight. | ||||
|     """ | ||||
|     group = group or mx.distributed.init() | ||||
|     N = group.size() | ||||
|     r = group.rank() | ||||
|  | ||||
|     def _shard_fn(path, weight): | ||||
|         if not isinstance(weight, mx.array): | ||||
|             return weight | ||||
|  | ||||
|         s = sharding_predicate(path, weight) | ||||
|         if s is None: | ||||
|             return weight | ||||
|  | ||||
|         axis = None | ||||
|         segments = 1 | ||||
|         if isinstance(s, int): | ||||
|             axis = s | ||||
|         elif isinstance(s, tuple): | ||||
|             axis, segments = s | ||||
|         else: | ||||
|             raise ValueError( | ||||
|                 "The sharding function should return int or tuple[int, list]" | ||||
|             ) | ||||
|  | ||||
|         return mx.contiguous( | ||||
|             mx.concatenate( | ||||
|                 [_split(part, N, axis)[r] for part in _split(weight, segments, axis)], | ||||
|                 axis=axis, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     return tree_map_with_path(_shard_fn, parameters) | ||||
|  | ||||
|  | ||||
| def _all_to_sharded(segments): | ||||
|     """Simple predicate to shard fully connected layers such that a common | ||||
|     representation becomes a sharded representation.""" | ||||
|  | ||||
|     def _shard_fn(path, weight): | ||||
|         return max(weight.ndim - 2, 0), segments | ||||
|  | ||||
|     return _shard_fn | ||||
|  | ||||
|  | ||||
| def _sharded_to_all(segments): | ||||
|     """Simple predicate to shard fully connected layers such that a sharded | ||||
|     representation becomes a common representation.""" | ||||
|  | ||||
|     def _shard_fn(path, weight): | ||||
|         if path.endswith("bias"): | ||||
|             return None | ||||
|         return -1, segments | ||||
|  | ||||
|     return _shard_fn | ||||
|  | ||||
|  | ||||
| def _check_sharding(sharding): | ||||
|     if sharding not in ("all-to-sharded", "sharded-to-all"): | ||||
|         raise ValueError( | ||||
|             ( | ||||
|                 f"Sharding type {sharding=} not supported, " | ||||
|                 "choose one of 'all-to-sharded' or 'sharded-to-all'" | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def shard_inplace( | ||||
|     module: Module, | ||||
|     sharding: Union[str, Callable], | ||||
|     *, | ||||
|     segments: Union[int, list] = 1, | ||||
|     group: Optional[mx.distributed.Group] = None, | ||||
| ): | ||||
|     """Shard a module in-place by updating its parameter dictionary with the | ||||
|     sharded parameter dictionary. | ||||
|  | ||||
|     The ``sharding`` argument can be any callable that, given the path and the | ||||
|     weight, returns the sharding axis and optionally also the segments that | ||||
|     comprise the unsharded weight. For instance, if the weight is a fused QKV | ||||
|     matrix the segments should be 3. | ||||
|  | ||||
|     .. note:: | ||||
|         The module itself is not changed, so for distributed communication to | ||||
|         happen the module needs to support it natively and have it enabled. | ||||
|  | ||||
|     Args: | ||||
|         module (mlx.nn.Module): The parameters of this module will be sharded | ||||
|             in-place. | ||||
|         sharding (str or callable): One of "all-to-sharded" or | ||||
|             "sharded-to-all", or a callable that returns the sharding axis and | ||||
|             segments. | ||||
|         segments (int or list): The segments to use if ``sharding`` is a | ||||
|             string. Default: ``1``. | ||||
|         group (mlx.core.distributed.Group): The distributed group to shard | ||||
|             across. If not set, the global group will be used. Default: ``None``. | ||||
|     """ | ||||
|     if isinstance(sharding, str): | ||||
|         _check_sharding(sharding) | ||||
|         sharding = ( | ||||
|             _all_to_sharded(segments) | ||||
|             if sharding == "all-to-sharded" | ||||
|             else _sharded_to_all(segments) | ||||
|         ) | ||||
|     module.update(_shard(module.parameters(), sharding, group)) | ||||
|  | ||||
|  | ||||
| def shard_linear( | ||||
|     module: Module, | ||||
|     sharding: str, | ||||
|     *, | ||||
|     segments: Union[int, list] = 1, | ||||
|     group: Optional[mx.distributed.Group] = None, | ||||
| ): | ||||
|     """Create a new linear layer that has its parameters sharded and also | ||||
|     performs distributed communication either in the forward or backward | ||||
|     pass. | ||||
|  | ||||
|     .. note:: | ||||
|         Contrary to ``shard_inplace``, the original layer is not changed but a | ||||
|         new layer is returned. | ||||
|  | ||||
|     Args: | ||||
|         module (mlx.nn.Module): The linear layer to be sharded. | ||||
|         sharding (str): One of "all-to-sharded" or | ||||
|             "sharded-to-all", defining the type of sharding to perform. | ||||
|         segments (int or list): The segments to use. Default: ``1``. | ||||
|         group (mlx.core.distributed.Group): The distributed group to shard | ||||
|             across. If not set, the global group will be used. Default: ``None``. | ||||
|     """ | ||||
|     _check_sharding(sharding) | ||||
|     fns = { | ||||
|         ("all-to-sharded", True): AllToShardedLinear.from_linear, | ||||
|         ("all-to-sharded", False): QuantizedAllToShardedLinear.from_quantized_linear, | ||||
|         ("sharded-to-all", True): ShardedToAllLinear.from_linear, | ||||
|         ("sharded-to-all", False): QuantizedShardedToAllLinear.from_quantized_linear, | ||||
|     } | ||||
|     return fns[sharding, isinstance(module, Linear)]( | ||||
|         module, segments=segments, group=group | ||||
|     ) | ||||
|  | ||||
|  | ||||
| class AllToShardedLinear(Module): | ||||
|     """Each member of the group applies part of the affine transformation such | ||||
|     that the result is sharded across the group. | ||||
|  | ||||
|     The gradients are automatically aggregated from each member of the group. | ||||
|  | ||||
|     Args: | ||||
|         input_dims (int): The dimensionality of the input features | ||||
|         output_dims (int): The dimensionality of the output features | ||||
|         bias (bool, optional): If set to ``False`` then the layer will not use a | ||||
|             bias. Default is ``True``. | ||||
|         group (mx.distributed.Group, optional): The sharding will happen across | ||||
|             this group. If not set then the global group is used. Default is | ||||
|             ``None``. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         input_dims: int, | ||||
|         output_dims: int, | ||||
|         bias: bool = True, | ||||
|         group: Optional[mx.distributed.Group] = None, | ||||
|     ): | ||||
|         super().__init__() | ||||
|  | ||||
|         # Initialize the parameters | ||||
|         scale = math.sqrt(1.0 / input_dims) | ||||
|         self.group = group or mx.distributed.init() | ||||
|         N = self.group.size() | ||||
|  | ||||
|         if (output_dims % N) != 0: | ||||
|             raise ValueError( | ||||
|                 f"Cannot shard the output of size {output_dims} across {N} devices." | ||||
|             ) | ||||
|  | ||||
|         self.weight = mx.random.uniform( | ||||
|             low=-scale, | ||||
|             high=scale, | ||||
|             shape=(output_dims // N, input_dims), | ||||
|         ) | ||||
|         if bias: | ||||
|             self.bias = mx.random.uniform( | ||||
|                 low=-scale, | ||||
|                 high=scale, | ||||
|                 shape=(output_dims // N,), | ||||
|             ) | ||||
|  | ||||
|     def _extra_repr(self) -> str: | ||||
|         out_dims, in_dims = self.weight.shape | ||||
|         N = self.group.size() | ||||
|         out_dims *= N | ||||
|         return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}" | ||||
|  | ||||
|     def __call__(self, x: mx.array) -> mx.array: | ||||
|         # Aggregate the gradients coming from each shard | ||||
|         x = sum_gradients(self.group)(x) | ||||
|  | ||||
|         # Compute the affine projection | ||||
|         if "bias" in self: | ||||
|             x = mx.addmm(self["bias"], x, self["weight"].T) | ||||
|         else: | ||||
|             x = x @ self["weight"].T | ||||
|         return x | ||||
|  | ||||
|     @classmethod | ||||
|     def from_linear( | ||||
|         cls, | ||||
|         linear_layer: Module, | ||||
|         *, | ||||
|         segments: Union[int, list] = 1, | ||||
|         group: Optional[mx.distributed.Group] = None, | ||||
|     ): | ||||
|         group = group or mx.distributed.init() | ||||
|         output_dims, input_dims = linear_layer.weight.shape | ||||
|  | ||||
|         sl = cls(input_dims, output_dims, hasattr(linear_layer, "bias"), group) | ||||
|         sl.update(_shard(linear_layer.parameters(), _all_to_sharded(segments), group)) | ||||
|  | ||||
|         return sl | ||||
|  | ||||
|  | ||||
| class ShardedToAllLinear(Module): | ||||
|     """Each member of the group applies part of the affine transformation and | ||||
|     then aggregates the results. | ||||
|  | ||||
|     All nodes will have the same exact result after this layer. | ||||
|  | ||||
|     :class:`ShardedToAllLinear` provides a classmethod :meth:`from_linear` to | ||||
|     convert linear layers to sharded :obj:`ShardedToAllLinear` layers. | ||||
|  | ||||
|     Args: | ||||
|         input_dims (int): The dimensionality of the input features | ||||
|         output_dims (int): The dimensionality of the output features | ||||
|         bias (bool, optional): If set to ``False`` then the layer will not use a | ||||
|             bias. Default is ``True``. | ||||
|         group (mx.distributed.Group, optional): The sharding will happen across | ||||
|             this group. If not set then the global group is used. Default is | ||||
|             ``None``. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         input_dims: int, | ||||
|         output_dims: int, | ||||
|         bias: bool = True, | ||||
|         group: Optional[mx.distributed.Group] = None, | ||||
|     ): | ||||
|         super().__init__() | ||||
|  | ||||
|         # Initialize the parameters | ||||
|         scale = math.sqrt(1.0 / input_dims) | ||||
|         self.group = group or mx.distributed.init() | ||||
|         N = self.group.size() | ||||
|  | ||||
|         if (input_dims % N) != 0: | ||||
|             raise ValueError( | ||||
|                 f"The input of size {input_dims} cannot be sharded across {N} devices." | ||||
|             ) | ||||
|  | ||||
|         self.weight = mx.random.uniform( | ||||
|             low=-scale, | ||||
|             high=scale, | ||||
|             shape=(output_dims, input_dims // N), | ||||
|         ) | ||||
|         if bias: | ||||
|             self.bias = mx.random.uniform( | ||||
|                 low=-scale, | ||||
|                 high=scale, | ||||
|                 shape=(output_dims,), | ||||
|             ) | ||||
|  | ||||
|     def _extra_repr(self) -> str: | ||||
|         N = self.group.size() | ||||
|         out_dims, in_dims = self.weight.shape | ||||
|         in_dims *= N | ||||
|         return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}" | ||||
|  | ||||
|     def __call__(self, x: mx.array) -> mx.array: | ||||
|         x = x @ self["weight"].T | ||||
|  | ||||
|         x = mx.distributed.all_sum(x, group=self.group) | ||||
|  | ||||
|         if "bias" in self: | ||||
|             x = x + self["bias"] | ||||
|  | ||||
|         return x | ||||
|  | ||||
|     @classmethod | ||||
|     def from_linear( | ||||
|         cls, | ||||
|         linear_layer: Module, | ||||
|         *, | ||||
|         segments: Union[int, list] = 1, | ||||
|         group: Optional[mx.distributed.Group] = None, | ||||
|     ): | ||||
|         group = group or mx.distributed.init() | ||||
|         output_dims, input_dims = linear_layer.weight.shape | ||||
|  | ||||
|         sl = cls(input_dims, output_dims, hasattr(linear_layer, "bias"), group) | ||||
|         sl.update(_shard(linear_layer.parameters(), _sharded_to_all(segments), group)) | ||||
|  | ||||
|         return sl | ||||
|  | ||||
|  | ||||
| class QuantizedAllToShardedLinear(Module): | ||||
|     """Each member of the group applies part of the affine transformation with | ||||
|     a quantized matrix such that the result is sharded across the group. | ||||
|  | ||||
|     It is the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`. | ||||
|     Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and | ||||
|     will not be included in any gradient computation. | ||||
|  | ||||
|     Args: | ||||
|         input_dims (int): The dimensionality of the input features. | ||||
|         output_dims (int): The dimensionality of the output features. | ||||
|         bias (bool, optional): If set to ``False`` then the layer will not use | ||||
|             a bias. Default: ``True``. | ||||
|         group_size (int, optional): The group size to use for the quantized | ||||
|             weight. See :func:`~mlx.core.quantize`. Default: ``64``. | ||||
|         bits (int, optional): The bit width to use for the quantized weight. | ||||
|             See :func:`~mlx.core.quantize`. Default: ``4``. | ||||
|         group (mx.distributed.Group, optional): The sharding will happen across | ||||
|             this group. If not set then the global group is used. Default is | ||||
|             ``None``. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         input_dims: int, | ||||
|         output_dims: int, | ||||
|         bias: bool = True, | ||||
|         group_size: int = 64, | ||||
|         bits: int = 4, | ||||
|         group: Optional[mx.distributed.Group] = None, | ||||
|     ): | ||||
|         super().__init__() | ||||
|  | ||||
|         # Quantization config | ||||
|         self.group_size = group_size | ||||
|         self.bits = bits | ||||
|  | ||||
|         # Initialize the quantized weight | ||||
|         scale = math.sqrt(1.0 / input_dims) | ||||
|         self.group = group or mx.distributed.init() | ||||
|         N = self.group.size() | ||||
|  | ||||
|         if (output_dims % N) != 0: | ||||
|             raise ValueError( | ||||
|                 f"Cannot shard the output of size {output_dims} across {N} devices." | ||||
|             ) | ||||
|  | ||||
|         weight = mx.random.uniform( | ||||
|             low=-scale, | ||||
|             high=scale, | ||||
|             shape=(output_dims // N, input_dims), | ||||
|         ) | ||||
|         self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits) | ||||
|  | ||||
|         # And bias if needed | ||||
|         if bias: | ||||
|             self.bias = mx.zeros((output_dims // N,)) | ||||
|  | ||||
|         # Freeze this model's parameters | ||||
|         self.freeze() | ||||
|  | ||||
|     def unfreeze(self, *args, **kwargs): | ||||
|         """Wrap unfreeze so that we unfreeze any layers we might contain but | ||||
|         our parameters will remain frozen.""" | ||||
|         super().unfreeze(*args, **kwargs) | ||||
|         self.freeze(recurse=False) | ||||
|  | ||||
|     def _extra_repr(self) -> str: | ||||
|         out_dims, in_dims = self.weight.shape | ||||
|         in_dims *= 32 // self.bits | ||||
|         out_dims *= self.group.size() | ||||
|         return ( | ||||
|             f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, " | ||||
|             f"group_size={self.group_size}, bits={self.bits}" | ||||
|         ) | ||||
|  | ||||
|     def __call__(self, x: mx.array) -> mx.array: | ||||
|         # Aggregate the gradients coming from each shard | ||||
|         x = sum_gradients(self.group)(x) | ||||
|  | ||||
|         x = mx.quantized_matmul( | ||||
|             x, | ||||
|             self["weight"], | ||||
|             scales=self["scales"], | ||||
|             biases=self["biases"], | ||||
|             transpose=True, | ||||
|             group_size=self.group_size, | ||||
|             bits=self.bits, | ||||
|         ) | ||||
|         if "bias" in self: | ||||
|             x = x + self["bias"] | ||||
|         return x | ||||
|  | ||||
|     @classmethod | ||||
|     def from_quantized_linear( | ||||
|         cls, | ||||
|         quantized_linear_layer: Module, | ||||
|         *, | ||||
|         segments: Union[int, list] = 1, | ||||
|         group: Optional[mx.distributed.Group] = None, | ||||
|     ): | ||||
|         group = group or mx.distributed.init() | ||||
|         output_dims, input_dims = quantized_linear_layer.weight.shape | ||||
|         input_dims *= 32 // quantized_linear_layer.bits | ||||
|  | ||||
|         sl = cls( | ||||
|             input_dims, | ||||
|             output_dims, | ||||
|             hasattr(quantized_linear_layer, "bias"), | ||||
|             group_size=quantized_linear_layer.group_size, | ||||
|             bits=quantized_linear_layer.bits, | ||||
|             group=group, | ||||
|         ) | ||||
|         sl.update( | ||||
|             _shard( | ||||
|                 quantized_linear_layer.parameters(), | ||||
|                 _all_to_sharded(segments), | ||||
|                 group, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         return sl | ||||
|  | ||||
|  | ||||
| class QuantizedShardedToAllLinear(Module): | ||||
|     """Each member of the group applies part of the affine transformation using | ||||
|     the quantized matrix and then aggregates the results. | ||||
|  | ||||
|     All nodes will have the same exact result after this layer. | ||||
|  | ||||
|     It is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`. | ||||
|     Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and | ||||
|     will not be included in any gradient computation. | ||||
|  | ||||
|     Args: | ||||
|         input_dims (int): The dimensionality of the input features. | ||||
|         output_dims (int): The dimensionality of the output features. | ||||
|         bias (bool, optional): If set to ``False`` then the layer will not use | ||||
|             a bias. Default: ``True``. | ||||
|         group_size (int, optional): The group size to use for the quantized | ||||
|             weight. See :func:`~mlx.core.quantize`. Default: ``64``. | ||||
|         bits (int, optional): The bit width to use for the quantized weight. | ||||
|             See :func:`~mlx.core.quantize`. Default: ``4``. | ||||
|         group (mx.distributed.Group, optional): The sharding will happen across | ||||
|             this group. If not set then the global group is used. Default is | ||||
|             ``None``. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         input_dims: int, | ||||
|         output_dims: int, | ||||
|         bias: bool = True, | ||||
|         group_size: int = 64, | ||||
|         bits: int = 4, | ||||
|         group: Optional[mx.distributed.Group] = None, | ||||
|     ): | ||||
|         super().__init__() | ||||
|  | ||||
|         # Quantization config | ||||
|         self.group_size = group_size | ||||
|         self.bits = bits | ||||
|  | ||||
|         # Initialize the quantized weight | ||||
|         scale = math.sqrt(1.0 / input_dims) | ||||
|         self.group = group or mx.distributed.init() | ||||
|         N = self.group.size() | ||||
|  | ||||
|         if (input_dims % N) != 0: | ||||
|             raise ValueError( | ||||
|                 f"The input of size {input_dims} cannot be sharded across {N} devices." | ||||
|             ) | ||||
|  | ||||
|         weight = mx.random.uniform( | ||||
|             low=-scale, | ||||
|             high=scale, | ||||
|             shape=(output_dims, input_dims // N), | ||||
|         ) | ||||
|         self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits) | ||||
|  | ||||
|         # And bias if needed | ||||
|         if bias: | ||||
|             self.bias = mx.zeros((output_dims,)) | ||||
|  | ||||
|         # Freeze this model's parameters | ||||
|         self.freeze() | ||||
|  | ||||
|     def unfreeze(self, *args, **kwargs): | ||||
|         """Wrap unfreeze so that we unfreeze any layers we might contain but | ||||
|         our parameters will remain frozen.""" | ||||
|         super().unfreeze(*args, **kwargs) | ||||
|         self.freeze(recurse=False) | ||||
|  | ||||
|     def _extra_repr(self) -> str: | ||||
|         out_dims, in_dims = self.weight.shape | ||||
|         in_dims *= (32 // self.bits) * self.group.size() | ||||
|         return ( | ||||
|             f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, " | ||||
|             f"group_size={self.group_size}, bits={self.bits}" | ||||
|         ) | ||||
|  | ||||
|     def __call__(self, x: mx.array) -> mx.array: | ||||
|         x = mx.quantized_matmul( | ||||
|             x, | ||||
|             self["weight"], | ||||
|             scales=self["scales"], | ||||
|             biases=self["biases"], | ||||
|             transpose=True, | ||||
|             group_size=self.group_size, | ||||
|             bits=self.bits, | ||||
|         ) | ||||
|         x = mx.distributed.all_sum(x, group=self.group) | ||||
|         if "bias" in self: | ||||
|             x = x + self["bias"] | ||||
|         return x | ||||
|  | ||||
|     @classmethod | ||||
|     def from_quantized_linear( | ||||
|         cls, | ||||
|         quantized_linear_layer: Module, | ||||
|         *, | ||||
|         segments: Union[int, list] = 1, | ||||
|         group: Optional[mx.distributed.Group] = None, | ||||
|     ): | ||||
|         group = group or mx.distributed.init() | ||||
|         output_dims, input_dims = quantized_linear_layer.weight.shape | ||||
|         input_dims *= 32 // quantized_linear_layer.bits | ||||
|  | ||||
|         sl = cls( | ||||
|             input_dims, | ||||
|             output_dims, | ||||
|             hasattr(quantized_linear_layer, "bias"), | ||||
|             group_size=quantized_linear_layer.group_size, | ||||
|             bits=quantized_linear_layer.bits, | ||||
|             group=group, | ||||
|         ) | ||||
|         sl.update( | ||||
|             _shard( | ||||
|                 quantized_linear_layer.parameters(), | ||||
|                 _sharded_to_all(segments), | ||||
|                 group, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         return sl | ||||
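Two pieces of the new module are worth calling out. sum_gradients is an identity in the forward pass whose custom VJP all-sums the incoming gradient, which is how the AllToSharded* layers aggregate gradients without adding any forward communication. The two layer families are designed to be alternated: an "all-to-sharded" layer leaves its activation sharded, and a following "sharded-to-all" layer consumes that shard and all-sums the partial results back to a replicated tensor. A hedged end-to-end sketch, mirroring the new tests further down; it assumes a distributed backend has been launched with more than one process and that the dimensions divide by the group size.

# A sketch of tensor-parallel conversion with the new helpers; not the only way
# to wire this up, just the pattern the tests below exercise.
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_linear

group = mx.distributed.init()

mlp = nn.Sequential(
    nn.Linear(1024, 4096),
    nn.Linear(4096, 1024),
)

# Alternate the two modes: the first layer leaves its activation sharded, the
# second consumes the shard and all-sums the partial results back together.
tp_mlp = nn.Sequential(
    shard_linear(mlp.layers[0], "all-to-sharded", group=group),
    shard_linear(mlp.layers[1], "sharded-to-all", group=group),
)

x = mx.random.normal((4, 1024))
y = tp_mlp(x)  # matches mlp(x) on every rank, up to floating point error
mx.eval(y)

Note that shard_linear returns a new layer, while shard_inplace only rewrites the parameter tree of an existing module and leaves its forward pass, and therefore any needed communication, to the module itself.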
| @@ -5124,4 +5124,23 @@ void init_ops(nb::module_& m) { | ||||
|                  [0, 1, 0], | ||||
|                  [0, 1, 0]], dtype=float32) | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "contiguous", | ||||
|       &mx::contiguous, | ||||
|       nb::arg(), | ||||
|       "allow_col_major"_a = false, | ||||
|       nb::kw_only(), | ||||
|       "stream"_a = nb::none(), | ||||
|       nb::sig( | ||||
|           "def contiguous(a: array, /, allow_col_major: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), | ||||
|       R"pbdoc( | ||||
|       Force an array to be row contiguous. Copy if necessary. | ||||
|  | ||||
|       Args: | ||||
|         a (array): The input to make contiguous. | ||||
|         allow_col_major (bool): Consider column-major arrays as contiguous and do not copy. | ||||
|  | ||||
|       Returns: | ||||
|         array: The row or column contiguous output. | ||||
|     )pbdoc"); | ||||
| } | ||||
|   | ||||
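The new mx.contiguous op is what _shard uses to materialize each rank's slice as a compact row-major array after the split-and-concatenate. A hedged usage sketch based only on the signature bound above:

# Hedged sketch of mx.contiguous: a transposed view is typically not row
# contiguous, so forcing contiguity copies it; allow_col_major=True accepts a
# column-major layout without copying (per the docstring above).
import mlx.core as mx

a = mx.arange(12, dtype=mx.float32).reshape(3, 4).T
b = mx.contiguous(a)                        # row-contiguous copy
c = mx.contiguous(a, allow_col_major=True)  # may avoid the copy
mx.eval(b, c)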
							
								
								
									
python/tests/mlx_distributed_tests.py (new file, 250 lines)
							| @@ -0,0 +1,250 @@ | ||||
| # Copyright © 2025 Apple Inc. | ||||
|  | ||||
| import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| import mlx_tests | ||||
| from mlx.nn.layers.distributed import shard_inplace, shard_linear | ||||
| from mlx.nn.utils import average_gradients | ||||
|  | ||||
|  | ||||
| class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase): | ||||
|     def test_average_gradients(self): | ||||
|         original_all_sum = mx.distributed.all_sum | ||||
|         n_calls = 0 | ||||
|         xtype = None | ||||
|  | ||||
|         def new_all_sum(x, **kwargs): | ||||
|             nonlocal n_calls | ||||
|             nonlocal xtype | ||||
|  | ||||
|             n_calls += 1 | ||||
|             if xtype is not None: | ||||
|                 self.assertEqual(xtype, x.dtype) | ||||
|  | ||||
|             return original_all_sum(x, **kwargs) | ||||
|  | ||||
|         mx.distributed.all_sum = new_all_sum | ||||
|  | ||||
|         try: | ||||
|             grads = [mx.ones(10) for i in range(10)] | ||||
|             new_grads = average_gradients(grads) | ||||
|             mx.eval(new_grads) | ||||
|             self.assertEqual(len(new_grads), 10) | ||||
|             self.assertTrue(all(mx.all(g == 1) for g in new_grads)) | ||||
|             self.assertEqual(n_calls, 1) | ||||
|  | ||||
|             n_calls = 0 | ||||
|             new_grads = average_gradients(grads, all_reduce_size=4 * 50) | ||||
|             mx.eval(new_grads) | ||||
|             self.assertEqual(len(new_grads), 10) | ||||
|             self.assertTrue(all(mx.all(g == 1) for g in new_grads)) | ||||
|             self.assertEqual(n_calls, 2) | ||||
|  | ||||
|             n_calls = 0 | ||||
|             new_grads = average_gradients(grads, all_reduce_size=0) | ||||
|             mx.eval(new_grads) | ||||
|             self.assertEqual(len(new_grads), 10) | ||||
|             self.assertTrue(all(mx.all(g == 1) for g in new_grads)) | ||||
|             self.assertEqual(n_calls, 10) | ||||
|  | ||||
|             n_calls = 0 | ||||
|             xtype = mx.float16 | ||||
|             new_grads = average_gradients( | ||||
|                 grads, all_reduce_size=2 * 50, communication_type=mx.float16 | ||||
|             ) | ||||
|             mx.eval(new_grads) | ||||
|             self.assertEqual(len(new_grads), 10) | ||||
|             self.assertTrue(all(g.dtype == mx.float32 for g in new_grads)) | ||||
|             self.assertTrue(all(mx.all(g == 1) for g in new_grads)) | ||||
|             self.assertEqual(n_calls, 2) | ||||
|  | ||||
|         finally: | ||||
|             mx.distributed.all_sum = original_all_sum | ||||
|  | ||||
|     def test_donation(self): | ||||
|         x = mx.random.normal((1024,)) | ||||
|         mx.eval(x) | ||||
|         mx.synchronize() | ||||
|  | ||||
|         mx.reset_peak_memory() | ||||
|         scale = mx.array(2.0) | ||||
|         y = mx.distributed.all_sum(x) | ||||
|         mx.eval(y) | ||||
|         mx.synchronize() | ||||
|         all_sum_only = mx.get_peak_memory() | ||||
|         y = mx.distributed.all_sum(x) * scale | ||||
|         mx.eval(y) | ||||
|         mx.synchronize() | ||||
|         all_sum_with_binary = mx.get_peak_memory() | ||||
|  | ||||
|         self.assertEqual(all_sum_only, all_sum_with_binary) | ||||
|  | ||||
|     def test_shard_linear(self): | ||||
|         # Seed the prng to have the same inputs and weights generated everywhere | ||||
|         mx.random.seed(0xF0F0F0F0) | ||||
|  | ||||
|         # Prepare inputs | ||||
|         world = mx.distributed.init() | ||||
|         part = ( | ||||
|             slice(None), | ||||
|             slice( | ||||
|                 world.rank() * 1024 // world.size(), | ||||
|                 (world.rank() + 1) * 1024 // world.size(), | ||||
|             ), | ||||
|         ) | ||||
|         x = mx.random.normal((4, 1024)) | ||||
|  | ||||
|         # Create and shard some linear layers | ||||
|         lin = nn.Linear(1024, 1024, bias=True) | ||||
|         slin1 = shard_linear(lin, "all-to-sharded") | ||||
|         slin2 = shard_linear(lin, "sharded-to-all") | ||||
|         y = lin(x) | ||||
|         y1 = slin1(x) | ||||
|         y2 = slin2(x[part]) | ||||
|         self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4)) | ||||
|         self.assertTrue(mx.allclose(y[part], y1)) | ||||
|  | ||||
|         # And their quant versions | ||||
|         qlin = lin.to_quantized() | ||||
|         slin1 = shard_linear(qlin, "all-to-sharded") | ||||
|         slin2 = shard_linear(qlin, "sharded-to-all") | ||||
|         y = qlin(x) | ||||
|         y1 = slin1(x) | ||||
|         y2 = slin2(x[part]) | ||||
|         self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4)) | ||||
|         self.assertTrue(mx.allclose(y[part], y1)) | ||||
|  | ||||
|         # Check the backward works as expected | ||||
|         def dummy_loss(model, x, y): | ||||
|             return (model(x) * y).sum() | ||||
|  | ||||
|         mod = nn.Sequential( | ||||
|             nn.Linear(128, 128), | ||||
|             nn.Linear(128, 128), | ||||
|             nn.Linear(128, 128), | ||||
|             nn.Linear(128, 128), | ||||
|         ) | ||||
|         smod = nn.Sequential( | ||||
|             shard_linear(mod.layers[0], "all-to-sharded"), | ||||
|             shard_linear(mod.layers[1], "sharded-to-all"), | ||||
|             shard_linear(mod.layers[2], "all-to-sharded"), | ||||
|             shard_linear(mod.layers[3], "sharded-to-all"), | ||||
|         ) | ||||
|  | ||||
|         grad1 = nn.value_and_grad(mod, dummy_loss) | ||||
|         grad2 = nn.value_and_grad(smod, dummy_loss) | ||||
|  | ||||
|         x = mx.random.normal((4, 128)) | ||||
|         y = mx.random.normal((4, 128)) | ||||
|  | ||||
|         l1, g1 = grad1(mod, x, y) | ||||
|         l2, g2 = grad2(smod, x, y) | ||||
|         mx.eval(l1, g1, l2, g2) | ||||
|  | ||||
|         part = slice( | ||||
|             world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size() | ||||
|         ) | ||||
|         self.assertTrue(mx.allclose(l1, l2)) | ||||
|         self.assertTrue( | ||||
|             mx.allclose( | ||||
|                 g1["layers"][0]["weight"][part], | ||||
|                 g2["layers"][0]["weight"], | ||||
|                 atol=1e-6, | ||||
|                 rtol=1e-4, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             mx.allclose( | ||||
|                 g1["layers"][2]["weight"][part], | ||||
|                 g2["layers"][2]["weight"], | ||||
|                 atol=1e-6, | ||||
|                 rtol=1e-4, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             mx.allclose( | ||||
|                 g1["layers"][1]["weight"][:, part], | ||||
|                 g2["layers"][1]["weight"], | ||||
|                 atol=1e-6, | ||||
|                 rtol=1e-4, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             mx.allclose( | ||||
|                 g1["layers"][3]["weight"][:, part], | ||||
|                 g2["layers"][3]["weight"], | ||||
|                 atol=1e-6, | ||||
|                 rtol=1e-4, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             mx.allclose( | ||||
|                 g1["layers"][0]["bias"][part], | ||||
|                 g2["layers"][0]["bias"], | ||||
|                 atol=1e-6, | ||||
|                 rtol=1e-4, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             mx.allclose( | ||||
|                 g1["layers"][2]["bias"][part], | ||||
|                 g2["layers"][2]["bias"], | ||||
|                 atol=1e-6, | ||||
|                 rtol=1e-4, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             mx.allclose( | ||||
|                 g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4 | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             mx.allclose( | ||||
|                 g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4 | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     def test_shard_predicate(self): | ||||
|         mx.random.seed(0xF0F0F0F0) | ||||
|  | ||||
|         class MyConv(nn.Module): | ||||
|             def __init__(self, *args, **kwargs): | ||||
|                 super().__init__() | ||||
|                 self.aggregate = kwargs.pop("aggregate", False) | ||||
|                 self.conv = nn.Conv2d(*args, **kwargs) | ||||
|  | ||||
|             def __call__(self, x): | ||||
|                 x = self.conv(x) | ||||
|                 if self.aggregate: | ||||
|                     x = mx.distributed.all_sum(x) | ||||
|                 return x | ||||
|  | ||||
|         def sharding(path, weight): | ||||
|             parts = path.split(".") | ||||
|             even = int(parts[1]) % 2 == 0 | ||||
|             if even: | ||||
|                 return 0 | ||||
|             else: | ||||
|                 return -1 if parts[-1] != "bias" else None | ||||
|  | ||||
|         mod = nn.Sequential( | ||||
|             MyConv(3, 128, kernel_size=3), | ||||
|             MyConv(128, 128, kernel_size=3), | ||||
|             MyConv(128, 128, kernel_size=3), | ||||
|             MyConv(128, 3, kernel_size=3), | ||||
|         ) | ||||
|         smod = nn.Sequential( | ||||
|             MyConv(3, 128, kernel_size=3), | ||||
|             MyConv(128, 128, kernel_size=3, aggregate=True), | ||||
|             MyConv(128, 128, kernel_size=3), | ||||
|             MyConv(128, 3, kernel_size=3, aggregate=True), | ||||
|         ) | ||||
|         smod.update(mod.parameters()) | ||||
|         shard_inplace(smod, sharding) | ||||
|  | ||||
|         x = mx.random.normal((4, 16, 16, 3)) | ||||
|         y1 = mod(x) | ||||
|         y2 = smod(x) | ||||
|         self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4)) | ||||
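The custom-predicate test above shards convolution weights by hand; the same machinery handles fused projections through the segments argument, which shard_inplace's docstring calls out for fused QKV matrices. A hedged sketch follows; the attention wiring around it is omitted, the group size is assumed to divide the dimensions, and a module sharded in place must still perform any needed all_sum itself.

# Hedged sketch: sharding a fused QKV projection. segments=3 makes the splitter
# treat the output dimension as three equal [Q|K|V] blocks and shard each block
# separately, so every rank keeps a contiguous slice of Q, K and V.
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_inplace, shard_linear

dims = 512
group = mx.distributed.init()
qkv = nn.Linear(dims, 3 * dims, bias=False)

# Rewrite the existing layer's parameters in place (adds no communication).
shard_inplace(qkv, "all-to-sharded", segments=3, group=group)

# Or build a new layer that also all-sums gradients in the backward pass.
qkv2 = shard_linear(
    nn.Linear(dims, 3 * dims, bias=False),
    "all-to-sharded",
    segments=3,
    group=group,
)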
| @@ -3,11 +3,14 @@ | ||||
| import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx_tests | ||||
| from mlx.nn.utils import average_gradients | ||||
| import mlx_distributed_tests | ||||
|  | ||||
|  | ||||
| class TestDistributed(mlx_tests.MLXTestCase): | ||||
| class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): | ||||
|     @classmethod | ||||
|     def setUpClass(cls): | ||||
|         world = mx.distributed.init(strict=True, backend="mpi") | ||||
|  | ||||
|     def test_groups(self): | ||||
|         world = mx.distributed.init() | ||||
|         self.assertEqual(world.size(), 8) | ||||
| @@ -121,77 +124,6 @@ class TestDistributed(mlx_tests.MLXTestCase): | ||||
|             x = mx.distributed.recv_like(x, neighbor, group=pairs) | ||||
|         mx.eval(y, x) | ||||
|  | ||||
|     def test_average_gradients(self): | ||||
|         original_all_sum = mx.distributed.all_sum | ||||
|         n_calls = 0 | ||||
|         xtype = None | ||||
|  | ||||
|         def new_all_sum(x, **kwargs): | ||||
|             nonlocal n_calls | ||||
|             nonlocal xtype | ||||
|  | ||||
|             n_calls += 1 | ||||
|             if xtype is not None: | ||||
|                 self.assertEqual(xtype, x.dtype) | ||||
|  | ||||
|             return original_all_sum(x, **kwargs) | ||||
|  | ||||
|         mx.distributed.all_sum = new_all_sum | ||||
|  | ||||
|         try: | ||||
|             grads = [mx.ones(10) for i in range(10)] | ||||
|             new_grads = average_gradients(grads) | ||||
|             mx.eval(new_grads) | ||||
|             self.assertEqual(len(new_grads), 10) | ||||
|             self.assertTrue(all(mx.all(g == 1) for g in new_grads)) | ||||
|             self.assertEqual(n_calls, 1) | ||||
|  | ||||
|             n_calls = 0 | ||||
|             new_grads = average_gradients(grads, all_reduce_size=4 * 50) | ||||
|             mx.eval(new_grads) | ||||
|             self.assertEqual(len(new_grads), 10) | ||||
|             self.assertTrue(all(mx.all(g == 1) for g in new_grads)) | ||||
|             self.assertEqual(n_calls, 2) | ||||
|  | ||||
|             n_calls = 0 | ||||
|             new_grads = average_gradients(grads, all_reduce_size=0) | ||||
|             mx.eval(new_grads) | ||||
|             self.assertEqual(len(new_grads), 10) | ||||
|             self.assertTrue(all(mx.all(g == 1) for g in new_grads)) | ||||
|             self.assertEqual(n_calls, 10) | ||||
|  | ||||
|             n_calls = 0 | ||||
|             xtype = mx.float16 | ||||
|             new_grads = average_gradients( | ||||
|                 grads, all_reduce_size=2 * 50, communication_type=mx.float16 | ||||
|             ) | ||||
|             mx.eval(new_grads) | ||||
|             self.assertEqual(len(new_grads), 10) | ||||
|             self.assertTrue(all(g.dtype == mx.float32 for g in new_grads)) | ||||
|             self.assertTrue(all(mx.all(g == 1) for g in new_grads)) | ||||
|             self.assertEqual(n_calls, 2) | ||||
|  | ||||
|         finally: | ||||
|             mx.distributed.all_sum = original_all_sum | ||||
|  | ||||
|     def test_donation(self): | ||||
|         x = mx.random.normal((1024,)) | ||||
|         mx.eval(x) | ||||
|         mx.synchronize() | ||||
|  | ||||
|         mx.reset_peak_memory() | ||||
|         scale = mx.array(2.0) | ||||
|         y = mx.distributed.all_sum(x) | ||||
|         mx.eval(y) | ||||
|         mx.synchronize() | ||||
|         all_sum_only = mx.get_peak_memory() | ||||
|         y = mx.distributed.all_sum(x) * scale | ||||
|         mx.eval(y) | ||||
|         mx.synchronize() | ||||
|         all_sum_with_binary = mx.get_peak_memory() | ||||
|  | ||||
|         self.assertEqual(all_sum_only, all_sum_with_binary) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -3,10 +3,10 @@ | ||||
| import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx_tests | ||||
| import mlx_distributed_tests | ||||
|  | ||||
|  | ||||
| class TestRingDistributed(mlx_tests.MLXTestCase): | ||||
| class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): | ||||
|     @classmethod | ||||
|     def setUpClass(cls): | ||||
|         world = mx.distributed.init(strict=True, backend="ring") | ||||
|   | ||||