Made llama and mistral files mypy compatible (#1359)

* Made mypy compatible

* reformatted

* Added more fixes

* Added fixes to speculative-decoding

* Fixes

* fix circle

* revert some stuff

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Param Thakkar
Date: 2025-04-24 02:53:46 +05:30
Committed by: GitHub
Parent: c52cc748f8
Commit: 4c9f9f9be7
10 changed files with 32 additions and 29 deletions
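The changes below follow a handful of recurring mypy patterns: implicit-Optional defaults, un-annotated empty containers, Union values that need narrowing before arithmetic, and attribute access the type stubs cannot verify. As a sketch, one way to reproduce the checks from Python is via mypy's API; the target path and the --ignore-missing-imports flag here are illustrative, not taken from this commit:

# Hypothetical invocation of mypy over the touched directories.
from mypy import api

stdout, stderr, exit_status = api.run(["--ignore-missing-imports", "llms/"])
print(stdout, end="")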


@@ -36,7 +36,5 @@ workflows:
           type: approval
       - apple/authenticate:
           context: pr-approval
-      - mlx_lm_build_and_test:
-          requires: [ hold ]
       - linux_build_and_test:
           requires: [ hold ]


@@ -40,7 +40,7 @@ def generate(
     if len(tokens) == 0:
         print("No tokens generated for this prompt")
         return
-    prompt_tps = prompt.size / prompt_time
+    prompt_tps = len(prompt) / prompt_time
     gen_tps = (len(tokens) - 1) / gen_time
     print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
     print(f"Generation: {gen_tps:.3f} tokens-per-sec")


@@ -19,10 +19,10 @@ class ModelArgs:
     rms_norm_eps: float
     vocab_size: int
     context_length: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    model_type: str = None
+    model_type: Optional[str] = None
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None

     def __post_init__(self):
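The int = None / str = None defaults are the classic implicit-Optional pattern: under mypy's default no-implicit-optional behaviour, a field annotated int may not default to None. Making the annotation explicit is the fix applied here; a minimal standalone sketch:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Args:
    num_key_value_heads: Optional[int] = None   # OK: None is a valid value
    # num_key_value_heads: int = None           # rejected by mypy: None is not an int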
@@ -54,7 +54,7 @@ class Attention(nn.Module):
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
-        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads or n_heads
         self.repeats = n_heads // n_kv_heads
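Once num_key_value_heads is Optional[int], downstream code has to account for the None case. The `or n_heads` fallback both narrows the type for mypy and defaults to the number of attention heads at runtime. The pattern in isolation:

from typing import Optional

def resolve_kv_heads(num_key_value_heads: Optional[int], n_heads: int) -> int:
    # `x or default` picks the default when x is None (or 0), so the result is
    # always an int and no Optional flows into the integer division below.
    return num_key_value_heads or n_heads

assert resolve_kv_heads(None, 32) == 32
assert resolve_kv_heads(8, 32) == 8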
@ -66,7 +66,7 @@ class Attention(nn.Module):
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = ( rope_scale = (
1 / args.rope_scaling["factor"] 1 / float(args.rope_scaling["factor"])
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1 else 1
) )
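rope_scaling is typed Dict[str, Union[float, str]], so rope_scaling["factor"] is a Union[float, str] as far as mypy is concerned and cannot be divided directly. Wrapping it in float() narrows the type and is a no-op when the value is already a float. A small sketch of the same idea:

from typing import Dict, Union

rope_scaling: Dict[str, Union[float, str]] = {"type": "linear", "factor": 2.0}

# float() accepts both float and numeric str, so the expression type-checks
# and still evaluates to 0.5 here.
rope_scale = 1 / float(rope_scaling["factor"])
assert rope_scale == 0.5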
@@ -254,7 +254,7 @@ def translate_weight_names(name):
     return name


-def load(gguf_file: str, repo: str = None):
+def load(gguf_file: str, repo: Optional[str] = None):
     # If the gguf_file exists, try to load model from it.
     # Otherwise try to download and cache from the HF repo
     if not Path(gguf_file).exists():


@@ -7,6 +7,7 @@ import glob
 import json
 import shutil
 from pathlib import Path
+from typing import Dict

 import mlx.core as mx
 import mlx.nn as nn
@@ -149,7 +150,8 @@ def quantize(weights, config, args):
 def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     max_file_size_bytes = max_file_size_gibibyte << 30
     shards = []
-    shard, shard_size = {}, 0
+    shard: Dict[str, mx.array] = {}
+    shard_size = 0
     for k, v in weights.items():
         if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
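An empty {} gives mypy nothing to infer a value type from, so it asks for an explicit annotation; splitting the tuple assignment makes room for one. Standalone (the weight name is illustrative):

from typing import Dict
import mlx.core as mx

# Annotating the empty dict tells mypy what will be stored in it later,
# so the assignment below type-checks against mx.array values.
shard: Dict[str, mx.array] = {}
shard_size = 0

v = mx.zeros((2, 2))
shard["layers.0.weight"] = v
shard_size += v.nbytes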


@@ -23,7 +23,7 @@ class ModelArgs:
     n_kv_heads: int
     norm_eps: float
     vocab_size: int
-    moe: dict = None
+    moe: dict


 class Attention(nn.Module):
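For moe the remedy goes the other way: instead of widening the annotation to Optional[dict], the None default is dropped, so moe becomes a required field and later unconditional lookups such as args.moe["num_experts"] stay well-typed. A reduced sketch:

from dataclasses import dataclass

@dataclass
class MoeArgs:
    vocab_size: int
    moe: dict          # required; no None default, so moe is never Optional

args = MoeArgs(vocab_size=32000, moe={"num_experts": 8, "num_experts_per_tok": 2})
num_experts = args.moe["num_experts"]   # fine for mypy: moe is always a dict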
@@ -91,7 +91,6 @@ class FeedForward(nn.Module):
 class MOEFeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-
         self.num_experts = args.moe["num_experts"]
         self.num_experts_per_tok = args.moe["num_experts_per_tok"]
         self.experts = [FeedForward(args) for _ in range(self.num_experts)]
@@ -115,7 +114,6 @@ class MOEFeedForward(nn.Module):
             yt = (yt * st).sum(axis=-1)
             y.append(yt[None, :])

         y = mx.concatenate(y)
-
         return y.reshape(orig_shape)


@@ -160,12 +160,12 @@ class SpeculativeDecoder:
             )
             n_accepted += num_to_accept
-            n_draft += draft_tokens.size
+            n_draft += len(draft_tokens)

             # Rewind the cache for unaccepted tokens:
-            if (n := draft_tokens.size) > num_to_accept:
-                self.draft_model.truncate_cache(n - new_tokens.size)
-                self.model.truncate_cache(n - new_tokens.size + 1)
+            if (n := len(draft_tokens)) > num_to_accept:
+                self.draft_model.truncate_cache(n - len(new_tokens))
+                self.model.truncate_cache(n - len(new_tokens) + 1)

             n_steps += 1
@@ -181,7 +181,7 @@ class SpeculativeDecoder:
             if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
                 break

-            draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :]
+            draft_inputs = new_tokens[max(len(new_tokens) - 2, 0) :]
             inputs = draft_inputs[-1:]

         print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)


@@ -213,10 +213,10 @@ class TransformerDecoderLayer(nn.Module):
         memory: mx.array,
         mask: mx.array,
         memory_mask: mx.array,
-        cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
-    ):
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
         y = self.ln1(x)
-        y, cache = self.self_attention(y, y, y, mask, cache)
+        y, new_cache = self.self_attention(y, y, y, mask, cache)
         x = x + y

         y = self.ln2(x)
@@ -227,7 +227,7 @@ class TransformerDecoderLayer(nn.Module):
         y = self.dense(y)
         x = x + y

-        return x, cache
+        return x, new_cache


 def create_additive_causal_mask(N: int, offset: int = 0):
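Here the cache parameter was annotated Optional[List[Tuple[...]]] while self_attention appears to return a plain tuple, so reassigning the parameter mixed incompatible types. The fix corrects the annotation, adds a return type, and binds the fresh state to a new name so each name keeps a single type. The general pattern, with an illustrative stand-in rather than the model's real API:

from typing import Optional, Tuple

def layer_step(
    x: float, cache: Optional[Tuple[float, float]] = None
) -> Tuple[float, Tuple[float, float]]:
    # Keep `cache` at its declared (Optional) type and bind the freshly computed
    # state to a new name instead of reassigning the parameter.
    prev = cache if cache is not None else (0.0, 0.0)
    new_cache = (prev[0] + x, prev[1] + x)
    return x * 2.0, new_cache

out, new_cache = layer_step(1.0)
assert new_cache == (1.0, 1.0)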


@@ -3,6 +3,7 @@
 import argparse
 import json
 import math
+import os
 import sys
 import time
 from pathlib import Path
@@ -16,7 +17,7 @@ from mlx.utils import tree_flatten
 from models import LoRALinear

 # Disable output buffering to see print statements in real-time
-sys.stdout.reconfigure(line_buffering=True)
+sys.stdout = os.fdopen(sys.stdout.fileno(), "w", buffering=1)


 def build_parser():
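sys.stdout is typed as TextIO in typeshed, and TextIO has no reconfigure() method (that lives on io.TextIOWrapper), so the original line fails type checking even though it works at runtime. Rebinding stdout through os.fdopen(..., buffering=1) gets line buffering in a way mypy accepts. An alternative that also type-checks, shown only as a sketch and not what this commit does, is to narrow the type first:

import io
import sys

# isinstance() narrows sys.stdout to TextIOWrapper, so mypy allows reconfigure().
if isinstance(sys.stdout, io.TextIOWrapper):
    sys.stdout.reconfigure(line_buffering=True)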


@@ -17,10 +17,10 @@ class ModelArgs:
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    model_type: str = None
+    model_type: Optional[str] = None
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None

     def __post_init__(self):
@@ -146,7 +146,7 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
         rope_scale = (
-            1 / args.rope_scaling["factor"]
+            1 / float(args.rope_scaling["factor"])
             if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
             else 1
         )


@@ -4,7 +4,7 @@ import glob
 import json
 import logging
 from pathlib import Path
-from typing import Generator
+from typing import Any, Dict, Generator, Union

 import mlx.core as mx
 import mlx.nn as nn
@@ -72,7 +72,8 @@ python generate.py --model {repo_id} --prompt "My name is"
 def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     max_file_size_bytes = max_file_size_gibibyte << 30
     shards = []
-    shard, shard_size = {}, 0
+    shard: Dict[str, mx.array] = {}
+    shard_size = 0
     for k, v in weights.items():
         if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
@@ -83,7 +84,7 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     return shards


-def save_model(save_dir: str, weights, tokenizer, config):
+def save_model(save_dir: Union[str, Path], weights, tokenizer, config):
     save_dir = Path(save_dir)
     save_dir.mkdir(parents=True, exist_ok=True)
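Widening save_dir to Union[str, Path] lets callers pass either spelling, and the immediate Path(save_dir) call normalises the value so the rest of the function only ever deals with a Path. A reduced sketch (the helper is a stand-in, not the real function):

from pathlib import Path
from typing import Union

def normalise_save_dir(save_dir: Union[str, Path]) -> Path:
    # Path() accepts both str and Path; after this assignment mypy narrows the
    # variable to Path, so path-only operations below type-check.
    return Path(save_dir)

assert normalise_save_dir("checkpoints") == normalise_save_dir(Path("checkpoints"))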
@@ -96,7 +97,10 @@ def save_model(save_dir: str, weights, tokenizer, config):
     )

     total_size = sum(v.nbytes for v in weights.values())
-    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
+    index_data: Dict[str, Any] = {
+        "metadata": {"total_size": total_size},
+        "weight_map": {},
+    }

     for i, shard in enumerate(shards):
         shard_name = shard_file_format.format(i + 1, shards_count)
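Without an annotation, mypy infers a narrow value type for index_data from the literal (nested dicts of two different shapes), and later updates that store file names into the weight map can fail to type-check against it. Dict[str, Any] keeps the index flexible; in isolation, with illustrative key and file names:

from typing import Any, Dict

total_size = 4096
index_data: Dict[str, Any] = {
    "metadata": {"total_size": total_size},
    "weight_map": {},
}

# With Dict[str, Any], storing str values into the weight map type-checks even
# though "metadata" maps str -> int.
index_data["weight_map"]["layers.0.weight"] = "model-00001-of-00002.safetensors"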