mirror of https://github.com/ml-explore/mlx-examples.git

Start distributed inference for llama models

committed by Awni Hannun
parent e2e5478da5
commit d77840207c
@@ -200,6 +200,36 @@ class Model(nn.Module):
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }
 
+    def shard(self, group: Optional[mx.distributed.Group] = None):
+        group = group or mx.distributed.init()
+
+        def all_to_sharded(l):
+            if isinstance(l, nn.QuantizedLinear):
+                return nn.QuantizedAllToShardedLinear.from_quantized_linear(l, group)
+            else:
+                return nn.AllToShardedLinear.from_linear(l, group)
+
+        def sharded_to_all(l):
+            if isinstance(l, nn.QuantizedLinear):
+                return nn.QuantizedShardedToAllLinear.from_quantized_linear(l, group)
+            else:
+                return nn.ShardedToAllLinear.from_linear(l, group)
+
+        N = group.size()
+        for layer in self.model.layers:
+            # Shard the self attention
+            layer.self_attn.q_proj = all_to_sharded(layer.self_attn.q_proj)
+            layer.self_attn.k_proj = all_to_sharded(layer.self_attn.k_proj)
+            layer.self_attn.v_proj = all_to_sharded(layer.self_attn.v_proj)
+            layer.self_attn.o_proj = sharded_to_all(layer.self_attn.o_proj)
+            layer.self_attn.n_heads //= N
+            layer.self_attn.n_kv_heads //= N
+
+            # Shard the MLP
+            layer.mlp.gate_proj = all_to_sharded(layer.mlp.gate_proj)
+            layer.mlp.down_proj = sharded_to_all(layer.mlp.down_proj)
+            layer.mlp.up_proj = all_to_sharded(layer.mlp.up_proj)
+
     @property
     def layers(self):
         return self.model.layers
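For context: the diff follows the standard tensor-parallel split. The q/k/v and gate/up projections become "all to sharded" (column-parallel: every rank keeps the full activation and computes its own slice of heads or hidden features), while o_proj and down_proj become "sharded to all" (row-parallel: each rank consumes its slice and the partial outputs are summed across ranks, which is why n_heads and n_kv_heads are divided by the group size). Below is a minimal sketch of that pattern, an assumption about the layers' semantics rather than the mlx implementation; it only relies on mx.distributed.all_sum:

import mlx.core as mx

def column_parallel_matmul(x, w_cols):
    # Each rank holds the full input x and a slice of the weight's output
    # columns, so it produces a disjoint slice of the output features
    # (the role played by the "all to sharded" projections above).
    return x @ w_cols

def row_parallel_matmul(x_slice, w_rows, group):
    # Each rank holds a slice of the input features and the matching rows
    # of the weight; the partial products are summed across ranks to
    # recover the full output (the role of the "sharded to all" projections).
    return mx.distributed.all_sum(x_slice @ w_rows, group=group)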
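A minimal usage sketch, not part of the commit: it assumes the mlx_lm load/generate helpers, a placeholder model path, and one process per shard launched via MPI, e.g. mpirun -np 2 -- python sharded_generate.py.

import mlx.core as mx
from mlx_lm import load, generate  # helper entry points; assumed available here

# One process per shard; init() picks up the distributed backend (e.g. MPI).
group = mx.distributed.init()

model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-4bit")  # placeholder model
model.shard(group)  # split attention heads and MLP features across ranks

# Every rank runs the same forward pass; only rank 0 prints the result.
response = generate(model, tokenizer, prompt="Write a haiku about GPUs.", max_tokens=64)
if group.rank() == 0:
    print(response)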