mirror of https://github.com/ml-explore/mlx-examples.git

Start distributed inference for llama models
commit d77840207c (parent e2e5478da5)
@@ -239,8 +239,8 @@ def main():
         tokenizer,
         prompt,
         max_tokens=args.max_tokens,
-        verbose=args.verbose,
         sampler=sampler,
+        verbose=args.verbose and mx.distributed.init().rank() == 0,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
         kv_bits=args.kv_bits,
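The hunk above (from main() in the generation script) gates verbose output on the distributed rank: in a multi-process run every rank executes the same generation loop, so without the check each rank would print an identical copy of the prompt and response. A minimal sketch of the same gating pattern, using only MLX's mx.distributed API; the rank0_print helper is illustrative and not part of the commit:

import mlx.core as mx

# A single-process run still works: init() returns a group of size 1.
group = mx.distributed.init()

def rank0_print(*args, **kwargs):
    # Every rank computes the same tokens; only rank 0 writes them to stdout.
    if group.rank() == 0:
        print(*args, **kwargs)

rank0_print(f"generating on {group.size()} rank(s)")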
@@ -200,6 +200,36 @@ class Model(nn.Module):
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }

+    def shard(self, group: Optional[mx.distributed.Group] = None):
+        group = group or mx.distributed.init()
+
+        def all_to_sharded(l):
+            if isinstance(l, nn.QuantizedLinear):
+                return nn.QuantizedAllToShardedLinear.from_quantized_linear(l, group)
+            else:
+                return nn.AllToShardedLinear.from_linear(l, group)
+
+        def sharded_to_all(l):
+            if isinstance(l, nn.QuantizedLinear):
+                return nn.QuantizedShardedToAllLinear.from_quantized_linear(l, group)
+            else:
+                return nn.ShardedToAllLinear.from_linear(l, group)
+
+        N = group.size()
+        for layer in self.model.layers:
+            # Shard the self attention
+            layer.self_attn.q_proj = all_to_sharded(layer.self_attn.q_proj)
+            layer.self_attn.k_proj = all_to_sharded(layer.self_attn.k_proj)
+            layer.self_attn.v_proj = all_to_sharded(layer.self_attn.v_proj)
+            layer.self_attn.o_proj = sharded_to_all(layer.self_attn.o_proj)
+            layer.self_attn.n_heads //= N
+            layer.self_attn.n_kv_heads //= N
+
+            # Shard the MLP
+            layer.mlp.gate_proj = all_to_sharded(layer.mlp.gate_proj)
+            layer.mlp.down_proj = sharded_to_all(layer.mlp.down_proj)
+            layer.mlp.up_proj = all_to_sharded(layer.mlp.up_proj)
+
     @property
     def layers(self):
         return self.model.layers
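The new shard() method applies a standard tensor-parallel split to every transformer block: the q/k/v and gate/up projections become column-parallel (all_to_sharded), o_proj and down_proj become row-parallel (sharded_to_all), and the per-rank head counts are divided by the group size N. A shapes-only sketch of that layout, assuming an illustrative hidden size of 4096 and a 2-rank group (these numbers are not from the commit):

import mlx.core as mx
import mlx.nn as nn

group = mx.distributed.init()
N = group.size()  # e.g. 2 when launched with two ranks

hidden = 4096
q_proj = nn.Linear(hidden, hidden, bias=False)
o_proj = nn.Linear(hidden, hidden, bias=False)

# Column-parallel: the full activation goes in, each rank keeps hidden // N
# of the output features, which is why n_heads and n_kv_heads shrink by N.
q_shard = nn.AllToShardedLinear.from_linear(q_proj, group)

# Row-parallel: each rank consumes its hidden // N slice and the partial
# products are combined across ranks, restoring the full hidden size.
o_shard = nn.ShardedToAllLinear.from_linear(o_proj, group)

Pairing the two splits this way gives the familiar Megatron-style layout: one cross-rank reduction per attention block and one per MLP block.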
@@ -699,6 +699,11 @@ def load_model(

     model.load_weights(list(weights.items()), strict=strict)

+    if mx.distributed.init().size() > 1:
+        if not hasattr(model, "shard"):
+            raise RuntimeError("Model doesn't support distributed inference.")
+        model.shard()
+
     if not lazy:
         mx.eval(model.parameters())
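With the load_model hook above, sharding happens transparently whenever a model that defines shard() is loaded in a run with more than one rank, and models without it fail fast with a RuntimeError instead of silently running unsharded. A hedged end-to-end sketch of how a caller sees the feature; the mlx_lm load/generate helpers and the model repo name are assumptions for illustration, not taken from this commit:

import mlx.core as mx
from mlx_lm import load, generate

group = mx.distributed.init()  # size() > 1 under a distributed launch
# load() goes through load_model, so in a distributed run the weights are
# assumed to be sharded by the time it returns (based on the hunk above).
model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-4bit")

text = generate(
    model,
    tokenizer,
    prompt="Explain tensor parallelism in one sentence.",
    max_tokens=64,
    verbose=group.rank() == 0,  # mirrors the change in main()
)
if group.rank() == 0:
    print(text)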