get fp16 working

Awni Hannun 2025-01-06 13:23:53 -08:00
parent 7fed460146
commit 46d53ce110
2 changed files with 8 additions and 5 deletions

[File 1 of 2: the pipelined inference example script]

@@ -25,19 +25,19 @@ parser = argparse.ArgumentParser(description="LLM pipelined inference example")
 parser.add_argument(
     "--prompt",
     "-p",
-    default="Hello world",
+    default="Write a quicksort in C++.",
     help="Message to be processed by the model ('-' reads from stdin)",
 )
 parser.add_argument(
     "--max-tokens",
     "-m",
     type=int,
-    default=128,
+    default=256,
     help="Maximum number of tokens to generate",
 )
 args = parser.parse_args()

-model_repo = "mlx-community/DeepSeek-V3-3bit-bf16"
+model_repo = "mlx-community/DeepSeek-V3-3bit"

 model, tokenizer = load(model_repo, lazy=True)

[File 2 of 2: the DeepSeek V3 model definition]

@@ -303,7 +303,7 @@ class MoEGate(nn.Module):
         scores = scores / denominator
         scores = scores * self.routed_scaling_factor

-        return inds, scores.astype(x.dtype)
+        return inds, scores


 class DeepseekV3MoE(nn.Module):
@@ -325,7 +325,7 @@ class DeepseekV3MoE(nn.Module):
     def __call__(self, x):
         inds, scores = self.gate(x)
         y = self.switch_mlp(x, inds)
-        y = (y * scores[..., None]).sum(axis=-2)
+        y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
         if self.config.n_shared_experts is not None:
             y = y + self.shared_experts(x)
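
These two MoE hunks work together: the gate stops down-casting its routing
scores to the activation dtype, so they stay in float32, and the weighted sum
over experts is therefore accumulated in float32 before the final cast back to
fp16. A minimal sketch of that pattern, with made-up shapes rather than the
model's real ones:

    import mlx.core as mx

    # Hypothetical shapes: (batch, seq, top-k experts, hidden)
    y = mx.random.normal((2, 4, 2, 8)).astype(mx.float16)  # expert outputs, fp16
    scores = mx.random.uniform(shape=(2, 4, 2))            # gate scores, float32

    # fp16 * float32 promotes to float32, so the reduction runs at full
    # precision; the final cast returns the mix to the fp16 residual stream.
    mixed = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
    print(mixed.dtype)  # mlx.core.float16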
@@ -360,6 +360,9 @@ class DeepseekV3DecoderLayer(nn.Module):
         h = x + r
         r = self.mlp(self.post_attention_layernorm(h))
         out = h + r
+        # Protect against overflow for fp16
+        if out.dtype == mx.float16:
+            out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
         return out
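
The last hunk clips the residual stream because fp16 saturates to inf above
mx.finfo(mx.float16).max (65504), and the sum of attention and MLP residuals in
a deep stack can drift past that. A small standalone illustration with made-up
values:

    import mlx.core as mx

    out = mx.array([60000.0, -1.5], dtype=mx.float16)
    out = out + out                  # first element overflows: [inf, -3]

    # Clipping a little below the fp16 max pulls overflowed values back into
    # range; a_min=None leaves the lower (negative) end untouched.
    out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
    print(out)                       # array([64504, -3], dtype=float16)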