mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-26 04:21:17 +08:00
Fixes in distributed layers
This commit is contained in:
parent
a8b3da7946
commit
16975815e9
@ -1,5 +1,6 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import math
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -168,7 +169,7 @@ class ShardedToAllLinear(Module):
|
|||||||
if self.group.size() > 1:
|
if self.group.size() > 1:
|
||||||
# Perform the local projection and aggregate the results
|
# Perform the local projection and aggregate the results
|
||||||
x = x @ self["weight"].T
|
x = x @ self["weight"].T
|
||||||
x = mx.distributed.all_sum(x, group=group)
|
x = mx.distributed.all_sum(x, group=self.group)
|
||||||
|
|
||||||
# Add the bias if we have one
|
# Add the bias if we have one
|
||||||
if "bias" in self:
|
if "bias" in self:
|
||||||
@ -316,9 +317,9 @@ class QuantizedAllToShardedLinear(Module):
|
|||||||
bits=quantized_linear_layer.bits,
|
bits=quantized_linear_layer.bits,
|
||||||
group=group,
|
group=group,
|
||||||
)
|
)
|
||||||
sl.weight = quantized_linear_layer.weight[r : step : (r + 1) * step] * 1
|
sl.weight = quantized_linear_layer.weight[r * step : (r + 1) * step] * 1
|
||||||
sl.scales = quantized_linear_layer.scales[r : step : (r + 1) * step] * 1
|
sl.scales = quantized_linear_layer.scales[r * step : (r + 1) * step] * 1
|
||||||
sl.biases = quantized_linear_layer.biases[r : step : (r + 1) * step] * 1
|
sl.biases = quantized_linear_layer.biases[r * step : (r + 1) * step] * 1
|
||||||
if "bias" in quantized_linear_layer:
|
if "bias" in quantized_linear_layer:
|
||||||
sl.bias = quantized_linear_layer.bias[r * step : (r + 1) * step] * 1
|
sl.bias = quantized_linear_layer.bias[r * step : (r + 1) * step] * 1
|
||||||
|
|
||||||
@ -413,7 +414,7 @@ class QuantizedShardedToAllLinear(Module):
|
|||||||
bits=self.bits,
|
bits=self.bits,
|
||||||
)
|
)
|
||||||
if self.group.size() > 1:
|
if self.group.size() > 1:
|
||||||
x = mx.distributed.sum_all(x, group=group)
|
x = mx.distributed.all_sum(x, group=self.group)
|
||||||
if "bias" in self:
|
if "bias" in self:
|
||||||
x = x + self["bias"]
|
x = x + self["bias"]
|
||||||
return x
|
return x
|
||||||
@ -428,6 +429,8 @@ class QuantizedShardedToAllLinear(Module):
|
|||||||
N = group.size()
|
N = group.size()
|
||||||
r = group.rank()
|
r = group.rank()
|
||||||
output_dims, input_dims = quantized_linear_layer.weight.shape
|
output_dims, input_dims = quantized_linear_layer.weight.shape
|
||||||
|
step = input_dims // N
|
||||||
|
step_grouped = quantized_linear_layer.scales.shape[1] // N
|
||||||
input_dims *= (32 // quantized_linear_layer.bits) * N
|
input_dims *= (32 // quantized_linear_layer.bits) * N
|
||||||
|
|
||||||
sl = cls(
|
sl = cls(
|
||||||
@ -438,9 +441,15 @@ class QuantizedShardedToAllLinear(Module):
|
|||||||
bits=quantized_linear_layer.bits,
|
bits=quantized_linear_layer.bits,
|
||||||
group=group,
|
group=group,
|
||||||
)
|
)
|
||||||
sl.weight = quantized_linear_layer.weight[r : step : (r + 1) * step] * 1
|
sl.weight = quantized_linear_layer.weight[:, r * step : (r + 1) * step] * 1
|
||||||
sl.scales = quantized_linear_layer.scales[r : step : (r + 1) * step] * 1
|
sl.scales = (
|
||||||
sl.biases = quantized_linear_layer.biases[r : step : (r + 1) * step] * 1
|
quantized_linear_layer.scales[:, r * step_grouped : (r + 1) * step_grouped]
|
||||||
|
* 1
|
||||||
|
)
|
||||||
|
sl.biases = (
|
||||||
|
quantized_linear_layer.biases[:, r * step_grouped : (r + 1) * step_grouped]
|
||||||
|
* 1
|
||||||
|
)
|
||||||
if "bias" in quantized_linear_layer:
|
if "bias" in quantized_linear_layer:
|
||||||
sl.bias = quantized_linear_layer.bias
|
sl.bias = quantized_linear_layer.bias
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user