mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 01:12:24 +08:00
get fp16 working
This commit is contained in:
parent
7fed460146
commit
46d53ce110
@@ -25,19 +25,19 @@ parser = argparse.ArgumentParser(description="LLM pipelined inference example")
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prompt",
|
"--prompt",
|
||||||
"-p",
|
"-p",
|
||||||
default="Hello world",
|
default="Write a quicksort in C++.",
|
||||||
help="Message to be processed by the model ('-' reads from stdin)",
|
help="Message to be processed by the model ('-' reads from stdin)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-tokens",
|
"--max-tokens",
|
||||||
"-m",
|
"-m",
|
||||||
type=int,
|
type=int,
|
||||||
default=128,
|
default=256,
|
||||||
help="Maximum number of tokens to generate",
|
help="Maximum number of tokens to generate",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_repo = "mlx-community/DeepSeek-V3-3bit-bf16"
|
model_repo = "mlx-community/DeepSeek-V3-3bit"
|
||||||
|
|
||||||
model, tokenizer = load(model_repo, lazy=True)
|
model, tokenizer = load(model_repo, lazy=True)
|
||||||
|
|
||||||
|
@@ -303,7 +303,7 @@ class MoEGate(nn.Module):
|
|||||||
scores = scores / denominator
|
scores = scores / denominator
|
||||||
scores = scores * self.routed_scaling_factor
|
scores = scores * self.routed_scaling_factor
|
||||||
|
|
||||||
return inds, scores.astype(x.dtype)
|
return inds, scores
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3MoE(nn.Module):
|
class DeepseekV3MoE(nn.Module):
|
||||||
@@ -325,7 +325,7 @@ class DeepseekV3MoE(nn.Module):
|
|||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
inds, scores = self.gate(x)
|
inds, scores = self.gate(x)
|
||||||
y = self.switch_mlp(x, inds)
|
y = self.switch_mlp(x, inds)
|
||||||
y = (y * scores[..., None]).sum(axis=-2)
|
y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
|
||||||
if self.config.n_shared_experts is not None:
|
if self.config.n_shared_experts is not None:
|
||||||
y = y + self.shared_experts(x)
|
y = y + self.shared_experts(x)
|
||||||
|
|
||||||
@@ -360,6 +360,9 @@ class DeepseekV3DecoderLayer(nn.Module):
|
|||||||
h = x + r
|
h = x + r
|
||||||
r = self.mlp(self.post_attention_layernorm(h))
|
r = self.mlp(self.post_attention_layernorm(h))
|
||||||
out = h + r
|
out = h + r
|
||||||
|
# Protect against overflow for fp16
|
||||||
|
if out.dtype == mx.float16:
|
||||||
|
out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user