From df1406735b1e140e44b29b302783fbc837da9269 Mon Sep 17 00:00:00 2001
From: Victor Nogueira
Date: Tue, 21 Jan 2025 23:12:43 +0100
Subject: [PATCH 1/2] Fix dataset variable name, in `datasets.py` (#1212)

---
 llms/mlx_lm/tuner/datasets.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 1b09c7e2..377e7cae 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -170,7 +170,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
     if prompt_feature and completion_feature:
         return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
     elif text_feature:
-        return Dataset(train_ds, tokenizer, text_key=text_feature)
+        return Dataset(ds, tokenizer, text_key=text_feature)
     else:
         raise ValueError(
             "Specify either a prompt and completion feature or a text "

From 9a3ddc3e656e2b8d43eba3576f0e72fd3a0b2681 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 21 Jan 2025 19:40:29 -0800
Subject: [PATCH 2/2] some fixes for pipeline parallel deep seek r1 (#1216)

---
 llms/mlx_lm/examples/pipeline_generate.py |  9 ++++++---
 llms/mlx_lm/models/deepseek_v3.py         | 11 ++++++++---
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/llms/mlx_lm/examples/pipeline_generate.py b/llms/mlx_lm/examples/pipeline_generate.py
index b98e757b..2970b986 100644
--- a/llms/mlx_lm/examples/pipeline_generate.py
+++ b/llms/mlx_lm/examples/pipeline_generate.py
@@ -22,6 +22,11 @@ import mlx.core as mx
 from mlx_lm import load, stream_generate
 
 parser = argparse.ArgumentParser(description="LLM pipelined inference example")
+parser.add_argument(
+    "--model",
+    default="mlx-community/DeepSeek-R1-3bit",
+    help="HF repo or path to local model.",
+)
 parser.add_argument(
     "--prompt",
     "-p",
@@ -37,9 +42,7 @@ parser.add_argument(
 )
 args = parser.parse_args()
 
-model_repo = "mlx-community/DeepSeek-V3-3bit"
-
-model, tokenizer = load(model_repo, lazy=True)
+model, tokenizer = load(args.model, lazy=True)
 
 messages = [{"role": "user", "content": args.prompt}]
 prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py
index f95949f9..46ee6ab3 100644
--- a/llms/mlx_lm/models/deepseek_v3.py
+++ b/llms/mlx_lm/models/deepseek_v3.py
@@ -400,6 +400,8 @@ class DeepseekV3Model(nn.Module):
 
         pipeline_rank = self.pipeline_rank
         pipeline_size = self.pipeline_size
+        # Hack to avoid time-outs during prompt-processing
+        dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
         if mask is None:
             mask = create_attention_mask(h, cache)
 
@@ -407,18 +409,21 @@ class DeepseekV3Model(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)
 
         # Receive from the previous process in the pipeline
+
         if pipeline_rank < pipeline_size - 1:
-            h = mx.distributed.recv_like(h, (pipeline_rank + 1))
+            h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
 
         for layer, c in zip(self.layers, cache):
             h = layer(h, mask, c)
 
         # Send to the next process in the pipeline
         if pipeline_rank != 0:
-            h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size)
+            h = mx.distributed.send(
+                h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
+            )
 
         # Broadcast h while keeping it in the graph
-        h = mx.distributed.all_gather(h)[: h.shape[0]]
+        h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
 
         return self.norm(h)
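
An aside on the second patch: the `dist_stream` selection routes the distributed send/recv/all_gather calls through the CPU stream while a multi-token prompt is being processed (`h.shape[1] > 1`) and through the GPU stream during single-token decoding. The sketch below is not part of the patches; `gather_hidden` is a hypothetical helper that isolates that selection around `mx.distributed.all_gather`, assuming the MLX distributed backend is initialized (e.g. when launched via mpirun; with a single process the collective is effectively a no-op).

# Illustrative sketch only -- not part of the patches above.
import mlx.core as mx


def gather_hidden(h: mx.array) -> mx.array:
    """Hypothetical helper mirroring the dist_stream selection in deepseek_v3.py."""
    # A sequence length > 1 means prompt processing; route the collective
    # through the CPU stream there, and through the GPU stream for decode.
    dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
    # all_gather concatenates along the first axis across ranks; the slice
    # keeps this rank's rows while leaving the op in the compute graph.
    return mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]


if __name__ == "__main__":
    mx.distributed.init()  # singleton group when not launched via mpirun
    h = mx.zeros((1, 8, 16))  # (batch, sequence, hidden)
    print(gather_hidden(h).shape)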