mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-05 16:34:34 +08:00
Made llama and mistral files mypy compatible (#1359)
* Made mypy compatible * reformatted * Added more fixes * Added fixes to speculative-decoding * Fixes * fix circle * revert some stuff --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -4,7 +4,7 @@ import glob
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
from typing import Any, Dict, Generator, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -72,7 +72,8 @@ python generate.py --model {repo_id} --prompt "My name is"
|
||||
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
||||
max_file_size_bytes = max_file_size_gibibyte << 30
|
||||
shards = []
|
||||
shard, shard_size = {}, 0
|
||||
shard: Dict[str, mx.array] = {}
|
||||
shard_size = 0
|
||||
for k, v in weights.items():
|
||||
if shard_size + v.nbytes > max_file_size_bytes:
|
||||
shards.append(shard)
|
||||
@@ -83,7 +84,7 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
||||
return shards
|
||||
|
||||
|
||||
def save_model(save_dir: str, weights, tokenizer, config):
|
||||
def save_model(save_dir: Union[str, Path], weights, tokenizer, config):
|
||||
save_dir = Path(save_dir)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -96,7 +97,10 @@ def save_model(save_dir: str, weights, tokenizer, config):
|
||||
)
|
||||
|
||||
total_size = sum(v.nbytes for v in weights.values())
|
||||
index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
|
||||
index_data: Dict[str, Any] = {
|
||||
"metadata": {"total_size": total_size},
|
||||
"weight_map": {},
|
||||
}
|
||||
|
||||
for i, shard in enumerate(shards):
|
||||
shard_name = shard_file_format.format(i + 1, shards_count)
|
||||
|
Reference in New Issue
Block a user