From d77840207cb144867a5ffcd4bced65e1b615849e Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 15 Jul 2024 13:24:50 -0700
Subject: [PATCH] Start distributed inference for llama models

---
 llms/mlx_lm/generate.py     |  2 +-
 llms/mlx_lm/models/llama.py | 30 ++++++++++++++++++++++++++++++
 llms/mlx_lm/utils.py        |  5 +++++
 3 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 0d286c75..aa7a4a2f 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -239,8 +239,8 @@ def main():
         tokenizer,
         prompt,
         max_tokens=args.max_tokens,
-        verbose=args.verbose,
         sampler=sampler,
+        verbose=args.verbose and mx.distributed.init().rank() == 0,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
         kv_bits=args.kv_bits,
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index 7b452ea4..343dc091 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -200,6 +200,36 @@ class Model(nn.Module):
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }
 
+    def shard(self, group: Optional[mx.distributed.Group] = None):
+        group = group or mx.distributed.init()
+
+        def all_to_sharded(l):
+            if isinstance(l, nn.QuantizedLinear):
+                return nn.QuantizedAllToShardedLinear.from_quantized_linear(l, group)
+            else:
+                return nn.AllToShardedLinear.from_linear(l, group)
+
+        def sharded_to_all(l):
+            if isinstance(l, nn.QuantizedLinear):
+                return nn.QuantizedShardedToAllLinear.from_quantized_linear(l, group)
+            else:
+                return nn.ShardedToAllLinear.from_linear(l, group)
+
+        N = group.size()
+        for layer in self.model.layers:
+            # Shard the self attention
+            layer.self_attn.q_proj = all_to_sharded(layer.self_attn.q_proj)
+            layer.self_attn.k_proj = all_to_sharded(layer.self_attn.k_proj)
+            layer.self_attn.v_proj = all_to_sharded(layer.self_attn.v_proj)
+            layer.self_attn.o_proj = sharded_to_all(layer.self_attn.o_proj)
+            layer.self_attn.n_heads //= N
+            layer.self_attn.n_kv_heads //= N
+
+            # Shard the MLP
+            layer.mlp.gate_proj = all_to_sharded(layer.mlp.gate_proj)
+            layer.mlp.down_proj = sharded_to_all(layer.mlp.down_proj)
+            layer.mlp.up_proj = all_to_sharded(layer.mlp.up_proj)
+
     @property
     def layers(self):
         return self.model.layers
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index b2e89a13..557c4316 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -699,6 +699,11 @@ def load_model(
 
     model.load_weights(list(weights.items()), strict=strict)
 
+    if mx.distributed.init().size() > 1:
+        if not hasattr(model, "shard"):
+            raise RuntimeError("Model doesn't support distributed inference.")
+        model.shard()
+
     if not lazy:
         mx.eval(model.parameters())
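
Note (not part of the patch): shard() applies the usual tensor-parallel split. q_proj, k_proj, v_proj, gate_proj, and up_proj become "all-to-sharded" (replicated input, each rank computes its slice of the output), while o_proj and down_proj become "sharded-to-all" (sharded input, partial results combined across ranks), which is why n_heads and n_kv_heads are divided by the group size. The sketch below is only an illustrative re-implementation of that pattern using plain mlx layers and collectives; the ColumnParallelLinear/RowParallelLinear names are hypothetical and do not appear in mlx or in this patch.

    import mlx.core as mx
    import mlx.nn as nn

    class ColumnParallelLinear(nn.Module):
        # Replicated input -> locally sharded output: each rank holds
        # out_dims // group.size() output features, no communication needed.
        def __init__(self, in_dims: int, out_dims: int, group: mx.distributed.Group):
            super().__init__()
            self.proj = nn.Linear(in_dims, out_dims // group.size(), bias=False)

        def __call__(self, x):
            return self.proj(x)

    class RowParallelLinear(nn.Module):
        # Locally sharded input -> replicated output: each rank multiplies its
        # in_dims // group.size() slice and the partial products are summed
        # across ranks with an all-reduce.
        def __init__(self, in_dims: int, out_dims: int, group: mx.distributed.Group):
            super().__init__()
            self.group = group
            self.proj = nn.Linear(in_dims // group.size(), out_dims, bias=False)

        def __call__(self, x):
            return mx.distributed.all_sum(self.proj(x), group=self.group)

Usage-wise, assuming an MPI setup, the patched CLI could be launched with something like `mpirun -np 2 --hostfile hosts python -m mlx_lm.generate --model <model> --prompt "..."`; mx.distributed.init() then returns a group of size 2, load_model() calls model.shard(), and only rank 0 prints the verbose output.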