get fp16 working

This commit is contained in:
Awni Hannun 2025-01-06 13:23:53 -08:00
parent 7fed460146
commit 46d53ce110
2 changed files with 8 additions and 5 deletions

View File

@ -25,19 +25,19 @@ parser = argparse.ArgumentParser(description="LLM pipelined inference example")
parser.add_argument(
"--prompt",
"-p",
default="Hello world",
default="Write a quicksort in C++.",
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=128,
default=256,
help="Maximum number of tokens to generate",
)
args = parser.parse_args()
model_repo = "mlx-community/DeepSeek-V3-3bit-bf16"
model_repo = "mlx-community/DeepSeek-V3-3bit"
model, tokenizer = load(model_repo, lazy=True)

View File

@ -303,7 +303,7 @@ class MoEGate(nn.Module):
scores = scores / denominator
scores = scores * self.routed_scaling_factor
return inds, scores.astype(x.dtype)
return inds, scores
class DeepseekV3MoE(nn.Module):
@ -325,7 +325,7 @@ class DeepseekV3MoE(nn.Module):
def __call__(self, x):
inds, scores = self.gate(x)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(x)
@ -360,6 +360,9 @@ class DeepseekV3DecoderLayer(nn.Module):
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
# Protect against overflow for fp16
if out.dtype == mx.float16:
out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
return out