mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 02:33:23 +08:00
Merge branch 'ml-explore:main' into adding-dpo-training
This commit is contained in:
commit
69a8f11f7b
@ -22,6 +22,11 @@ import mlx.core as mx
|
||||
from mlx_lm import load, stream_generate
|
||||
|
||||
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="mlx-community/DeepSeek-R1-3bit",
|
||||
help="HF repo or path to local model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
"-p",
|
||||
@ -37,9 +42,7 @@ parser.add_argument(
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
model_repo = "mlx-community/DeepSeek-V3-3bit"
|
||||
|
||||
model, tokenizer = load(model_repo, lazy=True)
|
||||
model, tokenizer = load(args.model, lazy=True)
|
||||
|
||||
messages = [{"role": "user", "content": args.prompt}]
|
||||
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||
|
@ -400,6 +400,8 @@ class DeepseekV3Model(nn.Module):
|
||||
|
||||
pipeline_rank = self.pipeline_rank
|
||||
pipeline_size = self.pipeline_size
|
||||
# Hack to avoid time-outs during prompt-processing
|
||||
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
@ -407,18 +409,21 @@ class DeepseekV3Model(nn.Module):
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
# Receive from the previous process in the pipeline
|
||||
|
||||
if pipeline_rank < pipeline_size - 1:
|
||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1))
|
||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
# Send to the next process in the pipeline
|
||||
if pipeline_rank != 0:
|
||||
h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size)
|
||||
h = mx.distributed.send(
|
||||
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
|
||||
)
|
||||
|
||||
# Broadcast h while keeping it in the graph
|
||||
h = mx.distributed.all_gather(h)[: h.shape[0]]
|
||||
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
@ -219,7 +219,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
if prompt_feature and completion_feature:
|
||||
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
|
||||
elif text_feature:
|
||||
return Dataset(train_ds, tokenizer, text_key=text_feature)
|
||||
return Dataset(ds, tokenizer, text_key=text_feature)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Specify either a prompt and completion feature or a text "
|
||||
|
Loading…
Reference in New Issue
Block a user