Fixes in distributed layers

This commit is contained in:
Angelos Katharopoulos 2024-07-15 13:49:38 -07:00
parent a8b3da7946
commit 16975815e9

View File

@ -1,5 +1,6 @@
# Copyright © 2024 Apple Inc.
import math
from functools import lru_cache
from typing import Optional
@ -168,7 +169,7 @@ class ShardedToAllLinear(Module):
if self.group.size() > 1:
# Perform the local projection and aggregate the results
x = x @ self["weight"].T
x = mx.distributed.all_sum(x, group=group)
x = mx.distributed.all_sum(x, group=self.group)
# Add the bias if we have one
if "bias" in self:
@ -316,9 +317,9 @@ class QuantizedAllToShardedLinear(Module):
bits=quantized_linear_layer.bits,
group=group,
)
sl.weight = quantized_linear_layer.weight[r : step : (r + 1) * step] * 1
sl.scales = quantized_linear_layer.scales[r : step : (r + 1) * step] * 1
sl.biases = quantized_linear_layer.biases[r : step : (r + 1) * step] * 1
sl.weight = quantized_linear_layer.weight[r * step : (r + 1) * step] * 1
sl.scales = quantized_linear_layer.scales[r * step : (r + 1) * step] * 1
sl.biases = quantized_linear_layer.biases[r * step : (r + 1) * step] * 1
if "bias" in quantized_linear_layer:
sl.bias = quantized_linear_layer.bias[r * step : (r + 1) * step] * 1
@ -413,7 +414,7 @@ class QuantizedShardedToAllLinear(Module):
bits=self.bits,
)
if self.group.size() > 1:
x = mx.distributed.sum_all(x, group=group)
x = mx.distributed.all_sum(x, group=self.group)
if "bias" in self:
x = x + self["bias"]
return x
@ -428,6 +429,8 @@ class QuantizedShardedToAllLinear(Module):
N = group.size()
r = group.rank()
output_dims, input_dims = quantized_linear_layer.weight.shape
step = input_dims // N
step_grouped = quantized_linear_layer.scales.shape[1] // N
input_dims *= (32 // quantized_linear_layer.bits) * N
sl = cls(
@ -438,9 +441,15 @@ class QuantizedShardedToAllLinear(Module):
bits=quantized_linear_layer.bits,
group=group,
)
sl.weight = quantized_linear_layer.weight[r : step : (r + 1) * step] * 1
sl.scales = quantized_linear_layer.scales[r : step : (r + 1) * step] * 1
sl.biases = quantized_linear_layer.biases[r : step : (r + 1) * step] * 1
sl.weight = quantized_linear_layer.weight[:, r * step : (r + 1) * step] * 1
sl.scales = (
quantized_linear_layer.scales[:, r * step_grouped : (r + 1) * step_grouped]
* 1
)
sl.biases = (
quantized_linear_layer.biases[:, r * step_grouped : (r + 1) * step_grouped]
* 1
)
if "bias" in quantized_linear_layer:
sl.bias = quantized_linear_layer.bias