mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
deepseek v3 model with pipeline parallelism (#1191)
* deepseekv3 * use upload_large_file instead of deprecated multi comit * add pipeline generation and example * comment * get fp16 working * use mlx==0.22
This commit is contained in:
parent
40b88eff48
commit
5cae0a60e6
@ -1,3 +1,3 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
__version__ = "0.20.4"
|
||||
__version__ = "0.21.0"
|
||||
|
75
llms/mlx_lm/examples/pipeline_generate.py
Normal file
75
llms/mlx_lm/examples/pipeline_generate.py
Normal file
@ -0,0 +1,75 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
"""
|
||||
Run with:
|
||||
|
||||
```
|
||||
/path/to/mpirun \
|
||||
-np 2 \
|
||||
--hostfile /path/to/hosts.txt \
|
||||
python /path/to/pipeline_generate.py --prompt "hello world"
|
||||
```
|
||||
|
||||
Make sure you can run MLX over MPI on two hosts. For more information see the
|
||||
documentation:
|
||||
|
||||
https://ml-explore.github.io/mlx/build/html/usage/distributed.html).
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm import load, stream_generate
|
||||
|
||||
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
"-p",
|
||||
default="Write a quicksort in C++.",
|
||||
help="Message to be processed by the model ('-' reads from stdin)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
model_repo = "mlx-community/DeepSeek-V3-3bit"
|
||||
|
||||
model, tokenizer = load(model_repo, lazy=True)
|
||||
|
||||
messages = [{"role": "user", "content": args.prompt}]
|
||||
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||
|
||||
group = mx.distributed.init()
|
||||
rank = group.rank()
|
||||
model.model.pipeline(group)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
# Synchronize processes before generation to avoid timeout if downloading
|
||||
# model for the first time.
|
||||
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
|
||||
|
||||
|
||||
def rprint(*args, **kwargs):
|
||||
if rank == 0:
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens):
|
||||
rprint(response.text, end="", flush=True)
|
||||
|
||||
rprint()
|
||||
rprint("=" * 10)
|
||||
rprint(
|
||||
f"Prompt: {response.prompt_tokens} tokens, "
|
||||
f"{response.prompt_tps:.3f} tokens-per-sec"
|
||||
)
|
||||
rprint(
|
||||
f"Generation: {response.generation_tokens} tokens, "
|
||||
f"{response.generation_tps:.3f} tokens-per-sec"
|
||||
)
|
||||
rprint(f"Peak memory: {response.peak_memory:.3f} GB")
|
460
llms/mlx_lm/models/deepseek_v3.py
Normal file
460
llms/mlx_lm/models/deepseek_v3.py
Normal file
@ -0,0 +1,460 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str = "deepseek_v3"
|
||||
vocab_size: int = 102400
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 11008
|
||||
moe_intermediate_size: int = 1407
|
||||
num_hidden_layers: int = 30
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 32
|
||||
n_shared_experts: Optional[int] = None
|
||||
n_routed_experts: Optional[int] = None
|
||||
routed_scaling_factor: float = 1.0
|
||||
kv_lora_rank: int = 512
|
||||
q_lora_rank: int = 1536
|
||||
qk_rope_head_dim: int = 64
|
||||
v_head_dim: int = 128
|
||||
qk_nope_head_dim: int = 128
|
||||
topk_method: str = "noaux_tc"
|
||||
scoring_func: str = "sigmoid"
|
||||
norm_topk_prob: bool = True
|
||||
n_group: Optional[int] = None
|
||||
topk_group: Optional[int] = None
|
||||
num_experts_per_tok: Optional[int] = None
|
||||
moe_layer_freq: int = 1
|
||||
first_k_dense_replace: int = 0
|
||||
max_position_embeddings: int = 2048
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 10000.0
|
||||
rope_scaling: Dict = None
|
||||
attention_bias: bool = False
|
||||
|
||||
|
||||
def yarn_find_correction_dim(
|
||||
num_rotations, dim, base=10000, max_position_embeddings=2048
|
||||
):
|
||||
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
||||
2 * math.log(base)
|
||||
)
|
||||
|
||||
|
||||
def yarn_find_correction_range(
|
||||
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
|
||||
):
|
||||
low = math.floor(
|
||||
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
high = math.ceil(
|
||||
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
return max(low, 0), min(high, dim - 1)
|
||||
|
||||
|
||||
def yarn_get_mscale(scale=1, mscale=1):
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
def yarn_linear_ramp_mask(min_val, max_val, dim):
|
||||
if min_val == max_val:
|
||||
max_val += 0.001 # Prevent singularity
|
||||
|
||||
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
|
||||
return mx.clip(linear_func, 0, 1)
|
||||
|
||||
|
||||
class DeepseekV3YarnRotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
scaling_factor=1.0,
|
||||
original_max_position_embeddings=4096,
|
||||
beta_fast=32,
|
||||
beta_slow=1,
|
||||
mscale=1,
|
||||
mscale_all_dim=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
|
||||
scaling_factor, mscale_all_dim
|
||||
)
|
||||
freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
|
||||
freq_inter = scaling_factor * base ** (
|
||||
mx.arange(0, dim, 2, dtype=mx.float32) / dim
|
||||
)
|
||||
low, high = yarn_find_correction_range(
|
||||
beta_fast,
|
||||
beta_slow,
|
||||
dim,
|
||||
base,
|
||||
original_max_position_embeddings,
|
||||
)
|
||||
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
|
||||
self._freqs = (freq_inter * freq_extra) / (
|
||||
freq_inter * freq_mask + freq_extra * (1 - freq_mask)
|
||||
)
|
||||
|
||||
def __call__(self, x, offset=0):
|
||||
if self.mscale != 1.0:
|
||||
x = self.mscale * x
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
x.shape[-1],
|
||||
traditional=True,
|
||||
base=None,
|
||||
scale=1.0,
|
||||
offset=offset,
|
||||
freqs=self._freqs,
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV3Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.v_head_dim = config.v_head_dim
|
||||
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
|
||||
self.scale = self.q_head_dim**-0.5
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.q_head_dim, bias=False
|
||||
)
|
||||
else:
|
||||
self.q_a_proj = nn.Linear(
|
||||
self.hidden_size, self.q_lora_rank, bias=config.attention_bias
|
||||
)
|
||||
self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
|
||||
self.q_b_proj = nn.Linear(
|
||||
self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
|
||||
)
|
||||
|
||||
self.kv_a_proj_with_mqa = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
|
||||
self.kv_b_proj = nn.Linear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads
|
||||
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
|
||||
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
||||
self.scale = self.scale * mscale * mscale
|
||||
|
||||
rope_kwargs = {
|
||||
key: self.config.rope_scaling[key]
|
||||
for key in [
|
||||
"original_max_position_embeddings",
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
]
|
||||
if key in self.config.rope_scaling
|
||||
}
|
||||
self.rope = DeepseekV3YarnRotaryEmbedding(
|
||||
dim=self.qk_rope_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
**rope_kwargs,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(x)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
|
||||
|
||||
q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
|
||||
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
|
||||
compressed_kv = self.kv_a_proj_with_mqa(x)
|
||||
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
|
||||
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
|
||||
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
|
||||
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
|
||||
|
||||
if cache is not None:
|
||||
q_pe = self.rope(q_pe, cache.offset)
|
||||
k_pe = self.rope(k_pe, cache.offset)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys, values = cache.update_and_fetch(
|
||||
mx.concatenate([k_nope, k_pe], axis=-1), values
|
||||
)
|
||||
else:
|
||||
q_pe = self.rope(q_pe)
|
||||
k_pe = self.rope(k_pe)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys = mx.concatenate([k_nope, k_pe], axis=-1)
|
||||
|
||||
queries = mx.concatenate([q_nope, q_pe], axis=-1)
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class DeepseekV3MLP(nn.Module):
|
||||
def __init__(
|
||||
self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size if intermediate_size is None else intermediate_size
|
||||
)
|
||||
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.topk_method = config.topk_method
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
|
||||
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
|
||||
|
||||
def __call__(self, x):
|
||||
gates = x @ self.weight.T
|
||||
|
||||
scores = mx.sigmoid(gates.astype(mx.float32))
|
||||
|
||||
assert self.topk_method == "noaux_tc", "Unsupported topk method."
|
||||
bsz, seq_len = x.shape[:2]
|
||||
scores = scores + self.e_score_correction_bias
|
||||
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
|
||||
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1)
|
||||
k = self.n_group - self.topk_group
|
||||
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
|
||||
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
|
||||
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
|
||||
scores[batch_idx, seq_idx, group_idx] = 0.0
|
||||
scores = scores.reshape(bsz, seq_len, -1)
|
||||
|
||||
k = self.top_k
|
||||
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
|
||||
scores = mx.take_along_axis(scores, inds, axis=-1)
|
||||
if self.top_k > 1 and self.norm_topk_prob:
|
||||
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
|
||||
scores = scores / denominator
|
||||
scores = scores * self.routed_scaling_factor
|
||||
|
||||
return inds, scores
|
||||
|
||||
|
||||
class DeepseekV3MoE(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_experts_per_tok = config.num_experts_per_tok
|
||||
self.switch_mlp = SwitchGLU(
|
||||
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
|
||||
)
|
||||
|
||||
self.gate = MoEGate(config)
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepseekV3MLP(
|
||||
config=config, intermediate_size=intermediate_size
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
inds, scores = self.gate(x)
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
|
||||
if self.config.n_shared_experts is not None:
|
||||
y = y + self.shared_experts(x)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class DeepseekV3DecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs, layer_idx: int):
|
||||
super().__init__()
|
||||
self.self_attn = DeepseekV3Attention(config)
|
||||
self.mlp = (
|
||||
DeepseekV3MoE(config)
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0
|
||||
)
|
||||
else DeepseekV3MLP(config)
|
||||
)
|
||||
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
# Protect against overflow for fp16
|
||||
if out.dtype == mx.float16:
|
||||
out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
|
||||
return out
|
||||
|
||||
|
||||
class DeepseekV3Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [
|
||||
DeepseekV3DecoderLayer(config, idx)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pipeline_rank = 0
|
||||
self.pipeline_size = 1
|
||||
|
||||
def pipeline(self, group):
|
||||
# Split layers in reverse so rank=0 gets the last layers and
|
||||
# rank=pipeline_size-1 gets the first
|
||||
self.pipeline_rank = group.rank()
|
||||
self.pipeline_size = group.size()
|
||||
layers_per_rank = (
|
||||
len(self.layers) + self.pipeline_size - 1
|
||||
) // self.pipeline_size
|
||||
start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||
self.layers = self.layers[start : start + layers_per_rank]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(x)
|
||||
|
||||
pipeline_rank = self.pipeline_rank
|
||||
pipeline_size = self.pipeline_size
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
# Receive from the previous process in the pipeline
|
||||
if pipeline_rank < pipeline_size - 1:
|
||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1))
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
# Send to the next process in the pipeline
|
||||
if pipeline_rank != 0:
|
||||
h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size)
|
||||
|
||||
# Broadcast h while keeping it in the graph
|
||||
h = mx.distributed.all_gather(h)[: h.shape[0]]
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = config
|
||||
self.model_type = config.model_type
|
||||
self.model = DeepseekV3Model(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
):
|
||||
out = self.model(inputs, cache, mask)
|
||||
return self.lm_head(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
|
||||
for e in range(self.args.n_routed_experts)
|
||||
]
|
||||
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
|
||||
|
||||
# Remove multi-token prediction layer
|
||||
return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
@ -1,4 +1,4 @@
|
||||
mlx>=0.19.2
|
||||
mlx>=0.22.0
|
||||
numpy
|
||||
transformers[sentencepiece]>=4.39.3
|
||||
protobuf
|
||||
|
@ -561,7 +561,7 @@ def load(
|
||||
Defaults to an empty dictionary.
|
||||
adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
|
||||
to the model. Default: ``None``.
|
||||
lazy (bool): If False eval the model parameters to make sure they are
|
||||
lazy (bool): If ``False`` eval the model parameters to make sure they are
|
||||
loaded in memory before returning, otherwise they will be loaded
|
||||
when needed. Default: ``False``
|
||||
Returns:
|
||||
@ -655,7 +655,7 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
|
||||
|
||||
model, tokenizer = load("{upload_repo}")
|
||||
|
||||
prompt="hello"
|
||||
prompt = "hello"
|
||||
|
||||
if tokenizer.chat_template is not None:
|
||||
messages = [{{"role": "user", "content": prompt}}]
|
||||
|
@ -682,6 +682,43 @@ class TestModels(unittest.TestCase):
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_deepseek_v3(self):
|
||||
from mlx_lm.models import deepseek_v3
|
||||
|
||||
args = deepseek_v3.ModelArgs(
|
||||
model_type="deepseek_v3",
|
||||
vocab_size=1024,
|
||||
hidden_size=128,
|
||||
intermediate_size=256,
|
||||
moe_intermediate_size=256,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
n_routed_experts=4,
|
||||
n_group=2,
|
||||
topk_group=1,
|
||||
num_experts_per_tok=2,
|
||||
n_shared_experts=1,
|
||||
kv_lora_rank=4,
|
||||
q_lora_rank=4,
|
||||
qk_rope_head_dim=32,
|
||||
v_head_dim=16,
|
||||
qk_nope_head_dim=32,
|
||||
rope_scaling={
|
||||
"beta_fast": 32,
|
||||
"beta_slow": 1,
|
||||
"factor": 40,
|
||||
"mscale": 1.0,
|
||||
"mscale_all_dim": 1.0,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"type": "yarn",
|
||||
},
|
||||
)
|
||||
model = deepseek_v3.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_gemma2(self):
|
||||
from mlx_lm.models import gemma2
|
||||
|
||||
|
@ -17,7 +17,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
|
||||
self.config = args
|
||||
self.custom_attribute = "This is a custom model"
|
||||
|
||||
def load_weights(self, weights):
|
||||
def load_weights(self, weights, **kwargs):
|
||||
self.qwenWeights = weights
|
||||
|
||||
class CustomQwenConfig:
|
||||
|
Loading…
Reference in New Issue
Block a user