mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
@@ -32,6 +32,10 @@ page](https://huggingface.co/deepseek-ai) to see a list of available models.
|
||||
By default, the conversion script will save the converted `weights.npz`,
|
||||
tokenizer, and `config.json` in the `mlx_model` directory.
|
||||
|
||||
> [!TIP] Alternatively, you can also download a few converted checkpoints from
|
||||
> the [MLX Community](https://huggingface.co/mlx-community) organization on
|
||||
> Hugging Face and skip the conversion step.
|
||||
|
||||
### Run
|
||||
|
||||
Once you've converted the weights, you can interact with the Deepseek coder
|
||||
|
@@ -14,11 +14,13 @@ import torch
|
||||
from llama import Llama, ModelArgs, sanitize_config
|
||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
|
||||
def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
|
||||
# bfloat16 is not numpy convertible. Upcast to float32 to avoid precision loss
|
||||
a = a.to(torch.float32) if dtype == 'bfloat16' else a.to(getattr(torch, dtype))
|
||||
a = a.to(torch.float32) if dtype == "bfloat16" else a.to(getattr(torch, dtype))
|
||||
return mx.array(a.numpy(), getattr(mx, dtype))
|
||||
|
||||
|
||||
def llama(model_path, *, dtype: str):
|
||||
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
|
||||
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
|
||||
@@ -48,7 +50,7 @@ def llama(model_path, *, dtype: str):
|
||||
state = torch.load(wf, map_location=torch.device("cpu"))
|
||||
for k, v in state.items():
|
||||
v = torch_to_mx(v, dtype=dtype)
|
||||
state[k] = None # free memory
|
||||
state[k] = None # free memory
|
||||
if shard_key(k) in SHARD_WEIGHTS:
|
||||
weights[k].append(v)
|
||||
else:
|
||||
@@ -204,7 +206,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
help="dtype for loading the torch model and input for quantization or saving the converted model. "
|
||||
"The original weights are stored in bfloat16.",
|
||||
"The original weights are stored in bfloat16.",
|
||||
type=str,
|
||||
default="float16",
|
||||
)
|
||||
|
Reference in New Issue
Block a user