Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-29 06:22:46 +08:00)
get fp16 working

Commit: 46d53ce110
Parent: 7fed460146
@@ -25,19 +25,19 @@ parser = argparse.ArgumentParser(description="LLM pipelined inference example")
 parser.add_argument(
     "--prompt",
     "-p",
-    default="Hello world",
+    default="Write a quicksort in C++.",
     help="Message to be processed by the model ('-' reads from stdin)",
 )
 parser.add_argument(
     "--max-tokens",
     "-m",
     type=int,
-    default=128,
+    default=256,
     help="Maximum number of tokens to generate",
 )
 args = parser.parse_args()

-model_repo = "mlx-community/DeepSeek-V3-3bit-bf16"
+model_repo = "mlx-community/DeepSeek-V3-3bit"

 model, tokenizer = load(model_repo, lazy=True)
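For context, a minimal non-pipelined sketch of exercising these defaults. This is not part of the commit; it assumes the mlx_lm load/generate API and that max_tokens is accepted as a keyword:

from mlx_lm import load, generate

# Sketch only: load the quantized repo referenced above and generate with the
# new default prompt and token budget. Keyword names can differ between
# mlx-lm versions; max_tokens is an assumption here.
model, tokenizer = load("mlx-community/DeepSeek-V3-3bit", lazy=True)
text = generate(
    model,
    tokenizer,
    prompt="Write a quicksort in C++.",
    max_tokens=256,
)
print(text)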
@@ -303,7 +303,7 @@ class MoEGate(nn.Module):
         scores = scores / denominator
         scores = scores * self.routed_scaling_factor

-        return inds, scores.astype(x.dtype)
+        return inds, scores


 class DeepseekV3MoE(nn.Module):
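The gate now returns the routing scores in whatever (typically higher-precision) dtype the gating arithmetic produced, instead of rounding them back to the activation dtype. A minimal sketch of the MLX dtype behavior this relies on, not the model's actual gate code:

import mlx.core as mx

# Illustration only: if the gate math runs in float32, the per-expert weights
# keep float32 precision; casting them to the float16 activation dtype (the
# removed line) would round them before they are used to mix expert outputs.
x = mx.random.normal((2, 16)).astype(mx.float16)   # hidden states in fp16
w = mx.random.normal((16, 8))                       # float32 gate weight
logits = x.astype(mx.float32) @ w                   # compute in float32
scores = mx.softmax(logits, axis=-1)
print(scores.dtype)   # mlx.core.float32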
@@ -325,7 +325,7 @@ class DeepseekV3MoE(nn.Module):
     def __call__(self, x):
         inds, scores = self.gate(x)
         y = self.switch_mlp(x, inds)
-        y = (y * scores[..., None]).sum(axis=-2)
+        y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
         if self.config.n_shared_experts is not None:
             y = y + self.shared_experts(x)
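Because the scores may now be float32 (previous hunk), the weighted sum of expert outputs is promoted to float32; the trailing .astype(y.dtype) (where y on the right-hand side is still the float16/bfloat16 switch_mlp output) casts the accumulated result back so the rest of the layer keeps its original dtype. A minimal sketch of that promotion, not the model code:

import mlx.core as mx

# Illustration only: float16 expert outputs weighted by float32 scores are
# accumulated in float32, then cast back to the activation dtype.
y = mx.random.normal((2, 4, 8)).astype(mx.float16)      # (tokens, experts, dim)
scores = mx.softmax(mx.random.normal((2, 4)), axis=-1)  # float32 weights
mixed = (y * scores[..., None]).sum(axis=-2)            # promoted to float32
print(mixed.dtype)              # mlx.core.float32
mixed = mixed.astype(y.dtype)   # back to float16
print(mixed.dtype)              # mlx.core.float16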
@@ -360,6 +360,9 @@ class DeepseekV3DecoderLayer(nn.Module):
         h = x + r
         r = self.mlp(self.post_attention_layernorm(h))
         out = h + r
+        # Protect against overflow for fp16
+        if out.dtype == mx.float16:
+            out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
         return out
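The residual stream is where a float16 forward pass typically overflows first (the largest finite float16 value is 65504), so the layer output is capped just below that limit; note that only the upper end is clipped here. A small illustration of the failure mode and the cap, not the repository code:

import mlx.core as mx

# Illustration only: adding two large float16 residuals overflows to inf,
# while the clip keeps the result finite. The "- 1000" margin matches the
# diff above and leaves some headroom for subsequent additions.
fp16_max = mx.finfo(mx.float16).max   # 65504.0
a = mx.array(60000.0, dtype=mx.float16)
b = mx.array(30000.0, dtype=mx.float16)
print(a + b)                          # inf: 90000 exceeds the float16 range
clipped = mx.clip(a + b, a_min=None, a_max=fp16_max - 1000)
print(clipped)                        # finite value near the cap (~6.45e4)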