From 0b04742985565fe4beb3903e68b0ef0d3bc2bcd8 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Fri, 12 Jul 2024 11:39:21 -0700
Subject: [PATCH 1/4] Add the distributed linear layers

---
 python/mlx/nn/layers/distributed.py | 156 ++++++++++++++++++++++++++++
 1 file changed, 156 insertions(+)
 create mode 100644 python/mlx/nn/layers/distributed.py

diff --git a/python/mlx/nn/layers/distributed.py b/python/mlx/nn/layers/distributed.py
new file mode 100644
index 000000000..01686792d
--- /dev/null
+++ b/python/mlx/nn/layers/distributed.py
@@ -0,0 +1,156 @@
+# Copyright © 2024 Apple Inc.
+
+from functools import lru_cache
+from typing import Optional
+
+import mlx.core as mx
+from mlx.nn.layers.base import Module
+
+
+@lru_cache
+def sum_gradients(group):
+    if group.size() == 1:
+        return lambda x: x
+
+    @mx.custom_function
+    def f(x):
+        return x
+
+    @f.vjp
+    def f(x, dx, _):
+        return mx.distributed.all_sum(dx, group=group)
+
+    return f
+
+
+class AllToShardedLinear(Module):
+    """Each member of the group applies part of the affine transformation such
+    that the result is sharded across the group.
+
+    The gradients are automatically aggregated from each member of the group.
+
+    Args:
+        input_dims (int): The dimensionality of the input features
+        output_dims (int): The dimensionality of the output features
+        bias (bool, optional): If set to ``False`` then the layer will not use a
+            bias. Default is ``True``.
+        group (mx.distributed.Group, optional): The sharding will happen across
+            this group. If not set then the global group is used. Default is
+            ``None``.
+    """
+
+    def __init__(
+        self,
+        input_dims: int,
+        output_dims: int,
+        bias: bool = True,
+        group: Optional[mx.distributed.Group] = None,
+    ):
+        super().__init__()
+
+        # Initialize the parameters
+        scale = math.sqrt(1.0 / input_dims)
+        self.group = group or mx.distributed.init()
+        N = self.group.size()
+
+        if (output_dims % N) != 0:
+            raise ValueError(
+                f"Cannot shard the output of size {output_dims} across {N} devices."
+            )
+
+        self.weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(output_dims // N, input_dims),
+        )
+        if bias:
+            self.bias = mx.random.uniform(
+                low=-scale,
+                high=scale,
+                shape=(output_dims // N,),
+            )
+
+    def _extra_repr(self) -> str:
+        N = self.group.size()
+        return f"input_dims={self.weight.shape[1]}, output_dims={N * self.weight.shape[0]}, bias={'bias' in self}"
+
+    def __call__(self, x: mx.array) -> mx.array:
+        # Aggregate the gradients coming from each shard
+        if self.group.size() > 1:
+            x = sum_gradients(self.group)(x)
+
+        # Compute the affine projection
+        if "bias" in self:
+            x = mx.addmm(self["bias"], x, self["weight"].T)
+        else:
+            x = x @ self["weight"].T
+        return x
+
+
+class ShardedToAllLinear(Module):
+    """Each member of the group applies part of the affine transformation and
+    then aggregates the results.
+
+    All nodes will have the same exact result after this layer.
+
+    Args:
+        input_dims (int): The dimensionality of the input features
+        output_dims (int): The dimensionality of the output features
+        bias (bool, optional): If set to ``False`` then the layer will not use a
+            bias. Default is ``True``.
+        group (mx.distributed.Group, optional): The sharding will happen across
+            this group. If not set then the global group is used. Default is
+            ``None``.
+ """ + + def __init__( + self, + input_dims: int, + output_dims: int, + bias: bool = True, + group: Optional[mx.distributed.Group] = None, + ): + super().__init__() + + # Initialize the parameters + scale = math.sqrt(1.0 / input_dims) + self.group = group or mx.distributed.init() + N = self.group.size() + + if (input_dims % N) != 0: + raise ValueError( + f"The input of size {input_dims} cannot be sharded across {N} devices." + ) + + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims, input_dims // N), + ) + if bias: + self.bias = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims,), + ) + + def _extra_repr(self) -> str: + N = self.group.size() + return f"input_dims={N * self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}" + + def __call__(self, x: mx.array) -> mx.array: + if self.group.size() > 1: + # Perform the local projection and aggregate the results + x = x @ self["weight"].T + x = mx.distributed.all_sum(x, group=group) + + # Add the bias if we have one + if "bias" in self: + x = x + self["bias"] + else: + # Normal linear layer as we are not in a distributed setting. + if "bias" in self: + x = mx.addmm(self["bias"], x, self["weight"].T) + else: + x = x @ self["weight"].T + return x From 060e1c9f92fa9451a9f227a9acb28bd712771880 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 15 Jul 2024 13:03:49 -0700 Subject: [PATCH 2/4] Add quantized distributed layers --- python/mlx/nn/layers/distributed.py | 295 +++++++++++++++++++++++++++- python/mlx/nn/layers/quantized.py | 2 +- 2 files changed, 294 insertions(+), 3 deletions(-) diff --git a/python/mlx/nn/layers/distributed.py b/python/mlx/nn/layers/distributed.py index 01686792d..c29cd81d0 100644 --- a/python/mlx/nn/layers/distributed.py +++ b/python/mlx/nn/layers/distributed.py @@ -71,8 +71,10 @@ class AllToShardedLinear(Module): ) def _extra_repr(self) -> str: + out_dims, in_dims = self.weight.shape N = self.group.size() - return f"input_dims={self.weight.shape[1]}, output_dims={N * self.weight.shape[0]}, bias={'bias' in self}" + out_dims *= N + return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}" def __call__(self, x: mx.array) -> mx.array: # Aggregate the gradients coming from each shard @@ -86,6 +88,25 @@ class AllToShardedLinear(Module): x = x @ self["weight"].T return x + @classmethod + def from_linear( + cls, linear_layer: Module, group: Optional[mx.distributed.Group] = None + ): + group = group or mx.distributed.init() + N = group.size() + r = group.rank() + output_dims, input_dims = linear_layer.weight.shape + step = output_dims // N + + sl = cls(input_dims, output_dims, False, group) + # The multiplication with 1.0 forces a copy, perhaps change to + # something better when available. + sl.weight = linear_layer.weight[r * step : (r + 1) * step] * 1 + if "bias" in linear_layer: + sl.bias = linear_layer.bias[r * step : (r + 1) * step] * 1 + + return sl + class ShardedToAllLinear(Module): """Each member of the group applies part of the affine transformation and @@ -93,6 +114,9 @@ class ShardedToAllLinear(Module): All nodes will have the same exact result after this layer. + :class:`ShardedToAllLinear` provides a classmethod :meth:`from_linear` to + convert linear layers to sharded :obj:`ShardedToAllLinear` layers. 
+
     Args:
         input_dims (int): The dimensionality of the input features
         output_dims (int): The dimensionality of the output features
@@ -136,7 +160,9 @@ class ShardedToAllLinear(Module):
 
     def _extra_repr(self) -> str:
         N = self.group.size()
-        return f"input_dims={N * self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
+        out_dims, in_dims = self.weight.shape
+        in_dims *= N
+        return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}"
 
     def __call__(self, x: mx.array) -> mx.array:
         if self.group.size() > 1:
@@ -154,3 +180,268 @@ class ShardedToAllLinear(Module):
             else:
                 x = x @ self["weight"].T
         return x
+
+    @classmethod
+    def from_linear(
+        cls, linear_layer: Module, group: Optional[mx.distributed.Group] = None
+    ):
+        group = group or mx.distributed.init()
+        N = group.size()
+        r = group.rank()
+        output_dims, input_dims = linear_layer.weight.shape
+        step = input_dims // N
+
+        sl = cls(input_dims, output_dims, False, group)
+        # The multiplication with 1.0 forces a copy, perhaps change to
+        # something better when available.
+        sl.weight = linear_layer.weight[:, r * step : (r + 1) * step] * 1
+        if "bias" in linear_layer:
+            sl.bias = linear_layer.bias
+
+        return sl
+
+
+class QuantizedAllToShardedLinear(Module):
+    """Each member of the group applies part of the affine transformation with
+    a quantized matrix such that the result is sharded across the group.
+
+    It is the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`.
+    Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and
+    will not be included in any gradient computation.
+
+    Args:
+        input_dims (int): The dimensionality of the input features.
+        output_dims (int): The dimensionality of the output features.
+        bias (bool, optional): If set to ``False`` then the layer will not use
+            a bias. Default: ``True``.
+        group_size (int, optional): The group size to use for the quantized
+            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
+        bits (int, optional): The bit width to use for the quantized weight.
+            See :func:`~mlx.core.quantize`. Default: ``4``.
+        group (mx.distributed.Group, optional): The sharding will happen across
+            this group. If not set then the global group is used. Default is
+            ``None``.
+    """
+
+    def __init__(
+        self,
+        input_dims: int,
+        output_dims: int,
+        bias: bool = True,
+        group_size: int = 64,
+        bits: int = 4,
+        group: Optional[mx.distributed.Group] = None,
+    ):
+        super().__init__()
+
+        # Quantization config
+        self.group_size = group_size
+        self.bits = bits
+
+        # Initialize the quantized weight
+        scale = math.sqrt(1.0 / input_dims)
+        self.group = group or mx.distributed.init()
+        N = self.group.size()
+
+        if (output_dims % N) != 0:
+            raise ValueError(
+                f"Cannot shard the output of size {output_dims} across {N} devices."
+            )
+
+        weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(output_dims // N, input_dims),
+        )
+        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
+
+        # And bias if needed
+        if bias:
+            self.bias = mx.zeros((output_dims // N,))
+
+        # Freeze this model's parameters
+        self.freeze()
+
+    def unfreeze(self, *args, **kwargs):
+        """Wrap unfreeze so that we unfreeze any layers we might contain but
+        our parameters will remain frozen."""
+        super().unfreeze(*args, **kwargs)
+        self.freeze(recurse=False)
+
+    def _extra_repr(self) -> str:
+        out_dims, in_dims = self.weight.shape
+        in_dims *= 32 // self.bits
+        out_dims *= self.group.size()
+        return (
+            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
+            f"group_size={self.group_size}, bits={self.bits}"
+        )
+
+    def __call__(self, x: mx.array) -> mx.array:
+        # Aggregate the gradients coming from each shard
+        if self.group.size() > 1:
+            x = sum_gradients(self.group)(x)
+
+        x = mx.quantized_matmul(
+            x,
+            self["weight"],
+            scales=self["scales"],
+            biases=self["biases"],
+            transpose=True,
+            group_size=self.group_size,
+            bits=self.bits,
+        )
+        if "bias" in self:
+            x = x + self["bias"]
+        return x
+
+    @classmethod
+    def from_quantized_linear(
+        cls,
+        quantized_linear_layer: Module,
+        group: Optional[mx.distributed.Group] = None,
+    ):
+        group = group or mx.distributed.init()
+        N = group.size()
+        r = group.rank()
+        output_dims, input_dims = quantized_linear_layer.weight.shape
+        input_dims *= 32 // quantized_linear_layer.bits
+        step = output_dims // N
+
+        sl = cls(
+            input_dims,
+            output_dims,
+            False,
+            group_size=quantized_linear_layer.group_size,
+            bits=quantized_linear_layer.bits,
+            group=group,
+        )
+        sl.weight = quantized_linear_layer.weight[r : step : (r + 1) * step] * 1
+        sl.scales = quantized_linear_layer.scales[r : step : (r + 1) * step] * 1
+        sl.biases = quantized_linear_layer.biases[r : step : (r + 1) * step] * 1
+        if "bias" in quantized_linear_layer:
+            sl.bias = quantized_linear_layer.bias[r * step : (r + 1) * step] * 1
+
+        return sl
+
+
+class QuantizedShardedToAllLinear(Module):
+    """Each member of the group applies part of the affine transformation using
+    the quantized matrix and then aggregates the results.
+
+    All nodes will have the same exact result after this layer.
+
+    It is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`.
+    Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and
+    will not be included in any gradient computation.
+
+    Args:
+        input_dims (int): The dimensionality of the input features.
+        output_dims (int): The dimensionality of the output features.
+        bias (bool, optional): If set to ``False`` then the layer will not use
+            a bias. Default: ``True``.
+        group_size (int, optional): The group size to use for the quantized
+            weight. See :func:`~mlx.core.quantize`. Default: ``64``.
+        bits (int, optional): The bit width to use for the quantized weight.
+            See :func:`~mlx.core.quantize`. Default: ``4``.
+        group (mx.distributed.Group, optional): The sharding will happen across
+            this group. If not set then the global group is used. Default is
+            ``None``.
+ """ + + def __init__( + self, + input_dims: int, + output_dims: int, + bias: bool = True, + group_size: int = 64, + bits: int = 4, + group: Optional[mx.distributed.Group] = None, + ): + super().__init__() + + # Quantization config + self.group_size = group_size + self.bits = bits + + # Initialize the quantized weight + scale = math.sqrt(1.0 / input_dims) + self.group = group or mx.distributed.init() + N = self.group.size() + + if (input_dims % N) != 0: + raise ValueError( + f"The input of size {input_dims} cannot be sharded across {N} devices." + ) + + weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims, input_dims // N), + ) + self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits) + + # And bias if needed + if bias: + self.bias = mx.zeros((output_dims,)) + + # Freeze this model's parameters + self.freeze() + + def unfreeze(self, *args, **kwargs): + """Wrap unfreeze so that we unfreeze any layers we might contain but + our parameters will remain frozen.""" + super().unfreeze(*args, **kwargs) + self.freeze(recurse=False) + + def _extra_repr(self) -> str: + out_dims, in_dims = self.weight.shape + in_dims *= (32 // self.bits) * self.group.size() + return ( + f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, " + f"group_size={self.group_size}, bits={self.bits}" + ) + + def __call__(self, x: mx.array) -> mx.array: + x = mx.quantized_matmul( + x, + self["weight"], + scales=self["scales"], + biases=self["biases"], + transpose=True, + group_size=self.group_size, + bits=self.bits, + ) + if self.group.size() > 1: + x = mx.distributed.sum_all(x, group=group) + if "bias" in self: + x = x + self["bias"] + return x + + @classmethod + def from_quantized_linear( + cls, + quantized_linear_layer: Module, + group: Optional[mx.distributed.Group] = None, + ): + group = group or mx.distributed.init() + N = group.size() + r = group.rank() + output_dims, input_dims = quantized_linear_layer.weight.shape + input_dims *= (32 // quantized_linear_layer.bits) * N + + sl = cls( + input_dims, + output_dims, + False, + group_size=quantized_linear_layer.group_size, + bits=quantized_linear_layer.bits, + group=group, + ) + sl.weight = quantized_linear_layer.weight[r : step : (r + 1) * step] * 1 + sl.scales = quantized_linear_layer.scales[r : step : (r + 1) * step] * 1 + sl.biases = quantized_linear_layer.biases[r : step : (r + 1) * step] * 1 + if "bias" in quantized_linear_layer: + sl.bias = quantized_linear_layer.bias + + return sl diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 48c2ce13a..84aa81364 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -197,7 +197,7 @@ class QuantizedLinear(Module): out_dims, in_dims = self.weight.shape in_dims *= 32 // self.bits return ( - f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}," + f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, " f"group_size={self.group_size}, bits={self.bits}" ) From a8b3da7946cb49ce32eba262e9ab321e524cbd2e Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 15 Jul 2024 13:23:51 -0700 Subject: [PATCH 3/4] Add distributed layers to nn top-level --- python/mlx/nn/layers/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index 3cf5e33a8..1dd530765 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -60,6 +60,12 @@ from 
     ConvTranspose2d,
     ConvTranspose3d,
 )
+from mlx.nn.layers.distributed import (
+    AllToShardedLinear,
+    QuantizedAllToShardedLinear,
+    QuantizedShardedToAllLinear,
+    ShardedToAllLinear,
+)
 from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Bilinear, Identity, Linear

From 16975815e94d87cc56edf977513122bc37b4e039 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 15 Jul 2024 13:49:38 -0700
Subject: [PATCH 4/4] Fixes in distributed layers

---
 python/mlx/nn/layers/distributed.py | 25 +++++++++++++++++--------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/python/mlx/nn/layers/distributed.py b/python/mlx/nn/layers/distributed.py
index c29cd81d0..f1cef52e6 100644
--- a/python/mlx/nn/layers/distributed.py
+++ b/python/mlx/nn/layers/distributed.py
@@ -1,5 +1,6 @@
 # Copyright © 2024 Apple Inc.
 
+import math
 from functools import lru_cache
 from typing import Optional
 
@@ -168,7 +169,7 @@ class ShardedToAllLinear(Module):
         if self.group.size() > 1:
             # Perform the local projection and aggregate the results
             x = x @ self["weight"].T
-            x = mx.distributed.all_sum(x, group=group)
+            x = mx.distributed.all_sum(x, group=self.group)
 
             # Add the bias if we have one
             if "bias" in self:
@@ -316,9 +317,9 @@ class QuantizedAllToShardedLinear(Module):
             bits=quantized_linear_layer.bits,
             group=group,
         )
-        sl.weight = quantized_linear_layer.weight[r : step : (r + 1) * step] * 1
-        sl.scales = quantized_linear_layer.scales[r : step : (r + 1) * step] * 1
-        sl.biases = quantized_linear_layer.biases[r : step : (r + 1) * step] * 1
+        sl.weight = quantized_linear_layer.weight[r * step : (r + 1) * step] * 1
+        sl.scales = quantized_linear_layer.scales[r * step : (r + 1) * step] * 1
+        sl.biases = quantized_linear_layer.biases[r * step : (r + 1) * step] * 1
         if "bias" in quantized_linear_layer:
             sl.bias = quantized_linear_layer.bias[r * step : (r + 1) * step] * 1
 
@@ -413,7 +414,7 @@ class QuantizedShardedToAllLinear(Module):
             bits=self.bits,
         )
         if self.group.size() > 1:
-            x = mx.distributed.sum_all(x, group=group)
+            x = mx.distributed.all_sum(x, group=self.group)
         if "bias" in self:
             x = x + self["bias"]
         return x
@@ -428,6 +429,8 @@ class QuantizedShardedToAllLinear(Module):
         N = group.size()
         r = group.rank()
         output_dims, input_dims = quantized_linear_layer.weight.shape
+        step = input_dims // N
+        step_grouped = quantized_linear_layer.scales.shape[1] // N
         input_dims *= (32 // quantized_linear_layer.bits) * N
 
         sl = cls(
@@ -438,9 +441,15 @@ class QuantizedShardedToAllLinear(Module):
             bits=quantized_linear_layer.bits,
             group=group,
         )
-        sl.weight = quantized_linear_layer.weight[r : step : (r + 1) * step] * 1
-        sl.scales = quantized_linear_layer.scales[r : step : (r + 1) * step] * 1
-        sl.biases = quantized_linear_layer.biases[r : step : (r + 1) * step] * 1
+        sl.weight = quantized_linear_layer.weight[:, r * step : (r + 1) * step] * 1
+        sl.scales = (
+            quantized_linear_layer.scales[:, r * step_grouped : (r + 1) * step_grouped]
+            * 1
+        )
+        sl.biases = (
+            quantized_linear_layer.biases[:, r * step_grouped : (r + 1) * step_grouped]
+            * 1
+        )
         if "bias" in quantized_linear_layer:
             sl.bias = quantized_linear_layer.bias
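
Usage sketch (illustrative only, not part of the patches): converting a plain two-layer
MLP to the sharded layers introduced by this series. It assumes an MPI-backed
mx.distributed backend and the default global group; the script name, the launch
command, and the 512/2048 dimensions are placeholders, and the hidden size must be
divisible by the group size.

# Run with something like: mpirun -np 2 python example.py
import mlx.core as mx
import mlx.nn as nn

group = mx.distributed.init()

# Seed so that every rank builds identical weights and inputs before sharding.
mx.random.seed(0)

# A plain two-layer MLP whose hidden dimension we want to shard across ranks.
mlp = nn.Sequential(
    nn.Linear(512, 2048),
    nn.GELU(),
    nn.Linear(2048, 512),
)

# Megatron-style tensor parallelism: the first projection keeps
# 2048 // group.size() rows of its weight on each rank (its output is sharded),
# and the second projection consumes the sharded activations and all-sums the
# partial products so every rank ends up with the full 512-dim output.
mlp.layers[0] = nn.AllToShardedLinear.from_linear(mlp.layers[0], group=group)
mlp.layers[2] = nn.ShardedToAllLinear.from_linear(mlp.layers[2], group=group)

x = mx.random.normal((4, 512))
y = mlp(x)  # identical on every rank, up to floating point error
print(group.rank(), y.shape)

Chaining the two layers this way costs one all_sum in the forward pass (inside
ShardedToAllLinear) and one in the backward pass (the sum_gradients vjp attached to
AllToShardedLinear's input), which is why the elementwise GELU in between needs no
communication at all.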