From 46d53ce110c8cbed767ed63ff3faa62c8ce0c487 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 6 Jan 2025 13:23:53 -0800 Subject: [PATCH] get fp16 working --- llms/mlx_lm/examples/pipeline_generate.py | 6 +++--- llms/mlx_lm/models/deepseek_v3.py | 7 +++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/llms/mlx_lm/examples/pipeline_generate.py b/llms/mlx_lm/examples/pipeline_generate.py index 7f6f34db..b98e757b 100644 --- a/llms/mlx_lm/examples/pipeline_generate.py +++ b/llms/mlx_lm/examples/pipeline_generate.py @@ -25,19 +25,19 @@ parser = argparse.ArgumentParser(description="LLM pipelined inference example") parser.add_argument( "--prompt", "-p", - default="Hello world", + default="Write a quicksort in C++.", help="Message to be processed by the model ('-' reads from stdin)", ) parser.add_argument( "--max-tokens", "-m", type=int, - default=128, + default=256, help="Maximum number of tokens to generate", ) args = parser.parse_args() -model_repo = "mlx-community/DeepSeek-V3-3bit-bf16" +model_repo = "mlx-community/DeepSeek-V3-3bit" model, tokenizer = load(model_repo, lazy=True) diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py index ee27a60e..f95949f9 100644 --- a/llms/mlx_lm/models/deepseek_v3.py +++ b/llms/mlx_lm/models/deepseek_v3.py @@ -303,7 +303,7 @@ class MoEGate(nn.Module): scores = scores / denominator scores = scores * self.routed_scaling_factor - return inds, scores.astype(x.dtype) + return inds, scores class DeepseekV3MoE(nn.Module): @@ -325,7 +325,7 @@ class DeepseekV3MoE(nn.Module): def __call__(self, x): inds, scores = self.gate(x) y = self.switch_mlp(x, inds) - y = (y * scores[..., None]).sum(axis=-2) + y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype) if self.config.n_shared_experts is not None: y = y + self.shared_experts(x) @@ -360,6 +360,9 @@ class DeepseekV3DecoderLayer(nn.Module): h = x + r r = self.mlp(self.post_attention_layernorm(h)) out = h + r + # Protect against overflow for fp16 + if out.dtype == mx.float16: + out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000) return out