mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 18:11:17 +08:00
Fixes in distributed layers
This commit is contained in:
parent
a8b3da7946
commit
16975815e9
@ -1,5 +1,6 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
@ -168,7 +169,7 @@ class ShardedToAllLinear(Module):
|
||||
if self.group.size() > 1:
|
||||
# Perform the local projection and aggregate the results
|
||||
x = x @ self["weight"].T
|
||||
x = mx.distributed.all_sum(x, group=group)
|
||||
x = mx.distributed.all_sum(x, group=self.group)
|
||||
|
||||
# Add the bias if we have one
|
||||
if "bias" in self:
|
||||
@ -316,9 +317,9 @@ class QuantizedAllToShardedLinear(Module):
|
||||
bits=quantized_linear_layer.bits,
|
||||
group=group,
|
||||
)
|
||||
sl.weight = quantized_linear_layer.weight[r : step : (r + 1) * step] * 1
|
||||
sl.scales = quantized_linear_layer.scales[r : step : (r + 1) * step] * 1
|
||||
sl.biases = quantized_linear_layer.biases[r : step : (r + 1) * step] * 1
|
||||
sl.weight = quantized_linear_layer.weight[r * step : (r + 1) * step] * 1
|
||||
sl.scales = quantized_linear_layer.scales[r * step : (r + 1) * step] * 1
|
||||
sl.biases = quantized_linear_layer.biases[r * step : (r + 1) * step] * 1
|
||||
if "bias" in quantized_linear_layer:
|
||||
sl.bias = quantized_linear_layer.bias[r * step : (r + 1) * step] * 1
|
||||
|
||||
@ -413,7 +414,7 @@ class QuantizedShardedToAllLinear(Module):
|
||||
bits=self.bits,
|
||||
)
|
||||
if self.group.size() > 1:
|
||||
x = mx.distributed.sum_all(x, group=group)
|
||||
x = mx.distributed.all_sum(x, group=self.group)
|
||||
if "bias" in self:
|
||||
x = x + self["bias"]
|
||||
return x
|
||||
@ -428,6 +429,8 @@ class QuantizedShardedToAllLinear(Module):
|
||||
N = group.size()
|
||||
r = group.rank()
|
||||
output_dims, input_dims = quantized_linear_layer.weight.shape
|
||||
step = input_dims // N
|
||||
step_grouped = quantized_linear_layer.scales.shape[1] // N
|
||||
input_dims *= (32 // quantized_linear_layer.bits) * N
|
||||
|
||||
sl = cls(
|
||||
@ -438,9 +441,15 @@ class QuantizedShardedToAllLinear(Module):
|
||||
bits=quantized_linear_layer.bits,
|
||||
group=group,
|
||||
)
|
||||
sl.weight = quantized_linear_layer.weight[r : step : (r + 1) * step] * 1
|
||||
sl.scales = quantized_linear_layer.scales[r : step : (r + 1) * step] * 1
|
||||
sl.biases = quantized_linear_layer.biases[r : step : (r + 1) * step] * 1
|
||||
sl.weight = quantized_linear_layer.weight[:, r * step : (r + 1) * step] * 1
|
||||
sl.scales = (
|
||||
quantized_linear_layer.scales[:, r * step_grouped : (r + 1) * step_grouped]
|
||||
* 1
|
||||
)
|
||||
sl.biases = (
|
||||
quantized_linear_layer.biases[:, r * step_grouped : (r + 1) * step_grouped]
|
||||
* 1
|
||||
)
|
||||
if "bias" in quantized_linear_layer:
|
||||
sl.bias = quantized_linear_layer.bias
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user