Start distributed inference for llama models

Angelos Katharopoulos 2024-07-15 13:24:50 -07:00 committed by Awni Hannun
parent e2e5478da5
commit d77840207c
3 changed files with 36 additions and 1 deletion

View File

@@ -239,8 +239,8 @@ def main():
         tokenizer,
         prompt,
         max_tokens=args.max_tokens,
-        verbose=args.verbose,
         sampler=sampler,
+        verbose=args.verbose and mx.distributed.init().rank() == 0,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
         kv_bits=args.kv_bits,
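In a distributed run every rank executes the same script, so leaving verbose on everywhere would print one copy of the output per process; gating it on rank() == 0 keeps a single stream. A minimal standalone sketch of the pattern (not code from this commit):

import mlx.core as mx

group = mx.distributed.init()  # trivial 1-rank group in a normal single-process run
if group.rank() == 0:
    # Printed exactly once, no matter how many ranks are running.
    print(f"world size: {group.size()}")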

View File

@@ -200,6 +200,36 @@ class Model(nn.Module):
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }
 
+    def shard(self, group: Optional[mx.distributed.Group] = None):
+        group = group or mx.distributed.init()
+
+        def all_to_sharded(l):
+            if isinstance(l, nn.QuantizedLinear):
+                return nn.QuantizedAllToShardedLinear.from_quantized_linear(l, group)
+            else:
+                return nn.AllToShardedLinear.from_linear(l, group)
+
+        def sharded_to_all(l):
+            if isinstance(l, nn.QuantizedLinear):
+                return nn.QuantizedShardedToAllLinear.from_quantized_linear(l, group)
+            else:
+                return nn.ShardedToAllLinear.from_linear(l, group)
+
+        N = group.size()
+        for layer in self.model.layers:
+            # Shard the self attention
+            layer.self_attn.q_proj = all_to_sharded(layer.self_attn.q_proj)
+            layer.self_attn.k_proj = all_to_sharded(layer.self_attn.k_proj)
+            layer.self_attn.v_proj = all_to_sharded(layer.self_attn.v_proj)
+            layer.self_attn.o_proj = sharded_to_all(layer.self_attn.o_proj)
+            layer.self_attn.n_heads //= N
+            layer.self_attn.n_kv_heads //= N
+
+            # Shard the MLP
+            layer.mlp.gate_proj = all_to_sharded(layer.mlp.gate_proj)
+            layer.mlp.down_proj = sharded_to_all(layer.mlp.down_proj)
+            layer.mlp.up_proj = all_to_sharded(layer.mlp.up_proj)
+
     @property
     def layers(self):
         return self.model.layers
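The shard() method applies the standard tensor-parallel split: projections whose outputs can be partitioned across ranks (q/k/v and gate/up) become "all to sharded" layers, while the projections that consume those shards (o_proj and down_proj) become "sharded to all" layers that sum partial results across ranks, so each attention or MLP block needs only one communication step. Note that n_heads //= N and n_kv_heads //= N implicitly assume both head counts are divisible by the group size. A minimal sketch of the same split on a toy MLP (the module and its dimensions are illustrative, not from the commit):

import mlx.core as mx
import mlx.nn as nn

class ToyMLP(nn.Module):
    def __init__(self, dims: int = 512, hidden: int = 2048):
        super().__init__()
        self.up = nn.Linear(dims, hidden)
        self.down = nn.Linear(hidden, dims)

    def __call__(self, x):
        return self.down(nn.gelu(self.up(x)))

group = mx.distributed.init()

mlp = ToyMLP()
# Column-parallel: the full input goes in, each rank keeps hidden // size()
# of the output features (assumes hidden is divisible by the group size).
mlp.up = nn.AllToShardedLinear.from_linear(mlp.up, group)
# Row-parallel: each rank multiplies its local shard, and the partial
# outputs are summed across ranks so every rank ends with the full result.
mlp.down = nn.ShardedToAllLinear.from_linear(mlp.down, group)

y = mlp(mx.zeros((1, 512)))  # identical on every rank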

View File

@@ -699,6 +699,11 @@ def load_model(
     model.load_weights(list(weights.items()), strict=strict)
 
+    if mx.distributed.init().size() > 1:
+        if not hasattr(model, "shard"):
+            raise RuntimeError("Model doesn't support distributed inference.")
+        model.shard()
+
     if not lazy:
         mx.eval(model.parameters())
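Hooking the shard into load_model() makes distributed inference transparent to callers: any script that loads through it gets a sharded model whenever it runs with more than one rank, and a clear error for architectures without a shard() method. A hedged usage sketch via the package's public helpers (the checkpoint name is illustrative; launch with one process per device, e.g. mpirun -np 2 python this_script.py):

import mlx.core as mx
from mlx_lm import load, generate

group = mx.distributed.init()

# load() goes through load_model(), so with size() > 1 the weights come
# back already sharded (or a RuntimeError is raised for unsupported models).
model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-4bit")

text = generate(
    model,
    tokenizer,
    prompt="Write a haiku about distributed inference.",
    verbose=group.rank() == 0,  # stream tokens from rank 0 only, as in the diff above
)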