Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 01:17:28 +08:00)
Made llama and mistral files mypy compatible (#1359)
* Made mypy compatible
* reformatted
* Added more fixes
* Added fixes to speculative-decoding
* Fixes
* fix circle
* revert some stuff

---------

Co-authored-by: Awni Hannun <awni@apple.com>
parent c52cc748f8
commit 4c9f9f9be7
@@ -36,7 +36,5 @@ workflows:
           type: approval
-      - apple/authenticate:
-          context: pr-approval
       - mlx_lm_build_and_test:
           requires: [ hold ]
       - linux_build_and_test:
           requires: [ hold ]
@@ -40,7 +40,7 @@ def generate(
     if len(tokens) == 0:
         print("No tokens generated for this prompt")
         return
-    prompt_tps = prompt.size / prompt_time
+    prompt_tps = len(prompt) / prompt_time
     gen_tps = (len(tokens) - 1) / gen_time
     print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
     print(f"Generation: {gen_tps:.3f} tokens-per-sec")
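Note: the switch to `len()` suggests `prompt` is handled as a plain Python sequence of token ids here; `.size` is an `mx.array` attribute that mypy cannot verify on a list. A minimal sketch of the throughput math, with hypothetical names:

from typing import List

def tokens_per_sec(tokens: List[int], seconds: float) -> float:
    # len() is defined for any sequence; .size is specific to mx.array.
    return len(tokens) / seconds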
@@ -19,10 +19,10 @@ class ModelArgs:
     rms_norm_eps: float
     vocab_size: int
     context_length: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    model_type: str = None
+    model_type: Optional[str] = None
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None

     def __post_init__(self):
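This is the standard mypy fix for dataclass fields like `num_key_value_heads` and `model_type`: a default of `None` is not a valid `int` or `str`, so the field type must admit `None` explicitly. A standalone sketch, with illustrative names:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Example:
    # `heads: int = None` fails type checking because None is not an int;
    # declaring the field Optional makes the None default legitimate.
    heads: Optional[int] = None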
@@ -54,7 +54,7 @@ class Attention(nn.Module):

         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
-        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads or n_heads

         self.repeats = n_heads // n_kv_heads

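The `or n_heads` fallback also narrows the type: `num_key_value_heads` is now `Optional[int]`, and the `or` expression yields a plain `int` that the later `n_heads // n_kv_heads` division can use. A hypothetical reduction of the pattern:

from typing import Optional

def resolve_kv_heads(num_key_value_heads: Optional[int], n_heads: int) -> int:
    # `or` substitutes n_heads when the value is None (or 0), so the
    # result is always a concrete int.
    return num_key_value_heads or n_heads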
@@ -66,7 +66,7 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
         rope_scale = (
-            1 / args.rope_scaling["factor"]
+            1 / float(args.rope_scaling["factor"])
             if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
             else 1
         )
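Indexing `rope_scaling`, typed `Optional[Dict[str, Union[float, str]]]`, yields `Union[float, str]`, and mypy rejects dividing by a possibly-str value even though the `type == "linear"` guard means it is numeric at runtime; the explicit `float(...)` narrows it. The same pattern in isolation:

from typing import Dict, Union

def linear_rope_scale(rope_scaling: Dict[str, Union[float, str]]) -> float:
    # rope_scaling["factor"] has static type Union[float, str]; float()
    # gives mypy a concrete numeric type to divide by.
    return 1 / float(rope_scaling["factor"])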
@@ -254,7 +254,7 @@ def translate_weight_names(name):
     return name


-def load(gguf_file: str, repo: str = None):
+def load(gguf_file: str, repo: Optional[str] = None):
    # If the gguf_file exists, try to load model from it.
    # Otherwise try to download and cache from the HF repo
    if not Path(gguf_file).exists():
@@ -7,6 +7,7 @@ import glob
 import json
 import shutil
 from pathlib import Path
+from typing import Dict

 import mlx.core as mx
 import mlx.nn as nn
@@ -149,7 +150,8 @@ def quantize(weights, config, args):
 def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     max_file_size_bytes = max_file_size_gibibyte << 30
     shards = []
-    shard, shard_size = {}, 0
+    shard: Dict[str, mx.array] = {}
+    shard_size = 0
     for k, v in weights.items():
         if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
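Splitting `shard, shard_size = {}, 0` into two statements is what allows the annotation: mypy accepts a variable annotation only on a single-target assignment, not on one element of a tuple unpacking. Reduced to its essentials:

from typing import Dict

import mlx.core as mx

# Each variable now carries its own, independently checkable type.
shard: Dict[str, mx.array] = {}
shard_size = 0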
@@ -23,7 +23,7 @@ class ModelArgs:
     n_kv_heads: int
     norm_eps: float
     vocab_size: int
-    moe: dict = None
+    moe: dict


 class Attention(nn.Module):
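Here the fix runs the other way: instead of becoming `Optional`, `moe` simply loses its invalid `None` default and becomes a required field, since the Mixtral code reads keys out of it unconditionally. A minimal sketch:

from dataclasses import dataclass

@dataclass
class MoESketch:
    # No default: omitting the MoE config is now a constructor error
    # rather than a latent None that fails on the first subscript.
    moe: dict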
@@ -91,7 +91,6 @@ class FeedForward(nn.Module):
 class MOEFeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()

         self.num_experts = args.moe["num_experts"]
         self.num_experts_per_tok = args.moe["num_experts_per_tok"]
         self.experts = [FeedForward(args) for _ in range(self.num_experts)]
@@ -115,7 +114,6 @@ class MOEFeedForward(nn.Module):
             yt = (yt * st).sum(axis=-1)
             y.append(yt[None, :])
         y = mx.concatenate(y)

         return y.reshape(orig_shape)

-
@@ -160,12 +160,12 @@ class SpeculativeDecoder:
             )

             n_accepted += num_to_accept
-            n_draft += draft_tokens.size
+            n_draft += len(draft_tokens)

             # Rewind the cache for unaccepted tokens:
-            if (n := draft_tokens.size) > num_to_accept:
-                self.draft_model.truncate_cache(n - new_tokens.size)
-                self.model.truncate_cache(n - new_tokens.size + 1)
+            if (n := len(draft_tokens)) > num_to_accept:
+                self.draft_model.truncate_cache(n - len(new_tokens))
+                self.model.truncate_cache(n - len(new_tokens) + 1)

             n_steps += 1

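Same list-versus-array change as in `generate`: `draft_tokens` and `new_tokens` are treated as plain token lists, so `len()` replaces `.size`. The rewind arithmetic itself is unchanged; a hedged sketch of the bookkeeping, with hypothetical names:

from typing import List, Tuple

def rewind_amounts(draft_tokens: List[int], new_tokens: List[int],
                   num_to_accept: int) -> Tuple[int, int]:
    # If the n drafted tokens outnumber the accepted ones, both caches
    # roll back the surplus; the main model rewinds one extra position,
    # mirroring the `+ 1` in the diff above.
    n = len(draft_tokens)
    if n > num_to_accept:
        return n - len(new_tokens), n - len(new_tokens) + 1
    return 0, 0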
@@ -181,7 +181,7 @@ class SpeculativeDecoder:

             if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
                 break
-            draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :]
+            draft_inputs = new_tokens[max(len(new_tokens) - 2, 0) :]
             inputs = draft_inputs[-1:]

             print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
@@ -213,10 +213,10 @@ class TransformerDecoderLayer(nn.Module):
         memory: mx.array,
         mask: mx.array,
         memory_mask: mx.array,
-        cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
-    ):
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
         y = self.ln1(x)
-        y, cache = self.self_attention(y, y, y, mask, cache)
+        y, new_cache = self.self_attention(y, y, y, mask, cache)
         x = x + y

         y = self.ln2(x)
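Two fixes in one hunk: the per-layer cache is annotated as a single `(keys, values)` pair rather than a list of pairs, and the attention result is bound to a fresh name `new_cache`, so the `Optional` parameter is never conflated with the definitely-present result and the non-Optional return annotation checks. A hypothetical reduction, with `attn` standing in for the attention call:

from typing import Optional, Tuple

import mlx.core as mx

CacheT = Tuple[mx.array, mx.array]

def layer_sketch(attn, x: mx.array,
                 cache: Optional[CacheT] = None) -> Tuple[mx.array, CacheT]:
    # new_cache is a distinct name from the Optional parameter, so the
    # annotated non-Optional return type checks cleanly.
    y, new_cache = attn(x, x, x, None, cache)
    return x + y, new_cache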
@@ -227,7 +227,7 @@ class TransformerDecoderLayer(nn.Module):
         y = self.dense(y)
         x = x + y

-        return x, cache
+        return x, new_cache


 def create_additive_causal_mask(N: int, offset: int = 0):
@@ -3,6 +3,7 @@
 import argparse
 import json
 import math
+import os
 import sys
 import time
 from pathlib import Path
@@ -16,7 +17,7 @@ from mlx.utils import tree_flatten
 from models import LoRALinear

 # Disable output buffering to see print statements in real-time
-sys.stdout.reconfigure(line_buffering=True)
+sys.stdout = os.fdopen(sys.stdout.fileno(), "w", buffering=1)


 def build_parser():
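`reconfigure()` is a method of `io.TextIOWrapper`, but typeshed declares `sys.stdout` as the broader `TextIO`, so mypy rejects the call; rebuilding the stream from its file descriptor gets the same line buffering with types mypy accepts. A runnable sketch:

import os
import sys

# buffering=1 means line-buffered: output is flushed at each newline.
sys.stdout = os.fdopen(sys.stdout.fileno(), "w", buffering=1)
print("appears immediately, even when piped")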
@@ -17,10 +17,10 @@ class ModelArgs:
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    model_type: str = None
+    model_type: Optional[str] = None
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None

     def __post_init__(self):
@@ -146,7 +146,7 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
         rope_scale = (
-            1 / args.rope_scaling["factor"]
+            1 / float(args.rope_scaling["factor"])
             if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
             else 1
         )
@@ -4,7 +4,7 @@ import glob
 import json
 import logging
 from pathlib import Path
-from typing import Generator
+from typing import Any, Dict, Generator, Union

 import mlx.core as mx
 import mlx.nn as nn
@@ -72,7 +72,8 @@ python generate.py --model {repo_id} --prompt "My name is"
 def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     max_file_size_bytes = max_file_size_gibibyte << 30
     shards = []
-    shard, shard_size = {}, 0
+    shard: Dict[str, mx.array] = {}
+    shard_size = 0
     for k, v in weights.items():
         if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
@@ -83,7 +84,7 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     return shards


-def save_model(save_dir: str, weights, tokenizer, config):
+def save_model(save_dir: Union[str, Path], weights, tokenizer, config):
     save_dir = Path(save_dir)
     save_dir.mkdir(parents=True, exist_ok=True)

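Widening the parameter to `Union[str, Path]` documents that callers may pass either; the first line already normalizes with `Path(save_dir)`, and `Path()` applied to a `Path` is a no-op. The pattern in isolation:

from pathlib import Path
from typing import Union

def ensure_dir(save_dir: Union[str, Path]) -> Path:
    # Path() accepts both members of the union, collapsing it to one type.
    path = Path(save_dir)
    path.mkdir(parents=True, exist_ok=True)
    return path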
@@ -96,7 +97,10 @@ def save_model(save_dir: str, weights, tokenizer, config):
     )

     total_size = sum(v.nbytes for v in weights.values())
-    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
+    index_data: Dict[str, Any] = {
+        "metadata": {"total_size": total_size},
+        "weight_map": {},
+    }

     for i, shard in enumerate(shards):
         shard_name = shard_file_format.format(i + 1, shards_count)
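Without the annotation, mypy infers a narrow value type for `index_data` from the dict literal and then rejects later writes, such as filename strings into `weight_map`; `Dict[str, Any]` keeps the nested, heterogeneous structure checkable. A small sketch with hypothetical entries:

from typing import Any, Dict

total_size = 0  # placeholder for the real byte count
index_data: Dict[str, Any] = {
    "metadata": {"total_size": total_size},
    "weight_map": {},
}
# Under Dict[str, Any] the str assignment below is fine; under the
# inferred literal type it would be an incompatible-type error.
index_data["weight_map"]["some.weight"] = "shard-00001"  # hypothetical names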