Made llama and mistral files mypy compatible (#1359)

* Made mypy compatible

* reformatted

* Added more fixes

* Added fixes to speculative-decoding

* Fixes

* fix circle

* revert some stuff

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Param Thakkar
Date:   2025-04-24 02:53:46 +05:30
Committed by: GitHub
Parent: c52cc748f8
Commit: 4c9f9f9be7

10 changed files with 32 additions and 29 deletions


@@ -40,7 +40,7 @@ def generate(
     if len(tokens) == 0:
         print("No tokens generated for this prompt")
         return
-    prompt_tps = prompt.size / prompt_time
+    prompt_tps = len(prompt) / prompt_time
     gen_tps = (len(tokens) - 1) / gen_time
     print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
     print(f"Generation: {gen_tps:.3f} tokens-per-sec")


@@ -19,10 +19,10 @@ class ModelArgs:
     rms_norm_eps: float
     vocab_size: int
     context_length: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    model_type: str = None
+    model_type: Optional[str] = None
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None

     def __post_init__(self):
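Both field changes address the same mypy rule: a default of `None` is only valid when `None` is part of the declared type, so `int = None` must become `Optional[int] = None` (PEP 484 dropped implicit Optional). A minimal sketch with a hypothetical field:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Args:
        # n_heads: int = None   <- mypy: None is not compatible with "int"
        n_heads: Optional[int] = None  # OK: None is now part of the type

    print(Args())  # Args(n_heads=None)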
@@ -54,7 +54,7 @@ class Attention(nn.Module):
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
-        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads or n_heads
         self.repeats = n_heads // n_kv_heads
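Because `num_key_value_heads` is now `Optional[int]`, using it directly in `n_heads // n_kv_heads` would divide by a possibly-`None` value. The `or n_heads` fallback both supplies the usual default (no grouped-query attention means one KV head per attention head) and narrows the type to `int` for mypy. A small sketch of the narrowing:

    from typing import Optional

    def kv_repeats(n_heads: int, n_kv_heads: Optional[int]) -> int:
        # "x or default" narrows Optional[int] to int for mypy; it also
        # treats 0 as missing, which is fine for a head count.
        return n_heads // (n_kv_heads or n_heads)

    print(kv_repeats(32, 8))     # 4
    print(kv_repeats(32, None))  # 1 -- falls back to n_heads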
@@ -66,7 +66,7 @@
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
         rope_scale = (
-            1 / args.rope_scaling["factor"]
+            1 / float(args.rope_scaling["factor"])
             if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
             else 1
         )
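`rope_scaling` is typed `Dict[str, Union[float, str]]`, so indexing it yields `Union[float, str]` and mypy rejects `1 / value`. Wrapping the lookup in `float()` accepts both branches of the union and returns a `float`. Roughly:

    from typing import Dict, Union

    rope_scaling: Dict[str, Union[float, str]] = {"type": "linear", "factor": 4.0}

    # 1 / rope_scaling["factor"]  <- mypy: unsupported operand, value may be str
    scale = 1 / float(rope_scaling["factor"])  # float() accepts float or str
    print(scale)  # 0.25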
@@ -254,7 +254,7 @@ def translate_weight_names(name):
     return name


-def load(gguf_file: str, repo: str = None):
+def load(gguf_file: str, repo: Optional[str] = None):
     # If the gguf_file exists, try to load model from it.
     # Otherwise try to download and cache from the HF repo
     if not Path(gguf_file).exists():


@@ -7,6 +7,7 @@ import glob
 import json
 import shutil
 from pathlib import Path
+from typing import Dict

 import mlx.core as mx
 import mlx.nn as nn
@@ -149,7 +150,8 @@ def quantize(weights, config, args):
 def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     max_file_size_bytes = max_file_size_gibibyte << 30
     shards = []
-    shard, shard_size = {}, 0
+    shard: Dict[str, mx.array] = {}
+    shard_size = 0
     for k, v in weights.items():
         if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
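The unpacking assignment `shard, shard_size = {}, 0` cannot carry an annotation, and an empty `{}` gives mypy nothing to infer a value type from, hence the split into two statements with an explicit `Dict[str, mx.array]`. The same pattern in isolation:

    from typing import Dict

    import mlx.core as mx

    shard: Dict[str, mx.array] = {}  # empty dict needs an explicit annotation
    shard_size = 0                   # inferred as int

    shard["layer.weight"] = mx.zeros((4, 4))
    shard_size += shard["layer.weight"].nbytes
    print(shard_size)  # 64 -- 16 float32 values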


@@ -23,7 +23,7 @@ class ModelArgs:
     n_kv_heads: int
     norm_eps: float
     vocab_size: int
-    moe: dict = None
+    moe: dict


 class Attention(nn.Module):
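Here the fix goes the other way: instead of wrapping `dict = None` in `Optional`, the default is dropped, since every Mixtral config must supply an `moe` section; making the field required is the honest type. A hypothetical usage:

    from dataclasses import dataclass

    @dataclass
    class ModelArgs:
        vocab_size: int
        moe: dict  # required: no default, so callers must supply it

    args = ModelArgs(vocab_size=32000, moe={"num_experts": 8, "num_experts_per_tok": 2})
    print(args.moe["num_experts"])  # 8
    # ModelArgs(vocab_size=32000) would raise TypeError: missing 'moe'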
@@ -91,7 +91,6 @@ class FeedForward(nn.Module):
 class MOEFeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-
         self.num_experts = args.moe["num_experts"]
         self.num_experts_per_tok = args.moe["num_experts_per_tok"]
         self.experts = [FeedForward(args) for _ in range(self.num_experts)]
@@ -115,7 +114,6 @@ class MOEFeedForward(nn.Module):
             yt = (yt * st).sum(axis=-1)
             y.append(yt[None, :])
         y = mx.concatenate(y)
-
         return y.reshape(orig_shape)


@@ -160,12 +160,12 @@ class SpeculativeDecoder:
             )
             n_accepted += num_to_accept
-            n_draft += draft_tokens.size
+            n_draft += len(draft_tokens)

             # Rewind the cache for unaccepted tokens:
-            if (n := draft_tokens.size) > num_to_accept:
-                self.draft_model.truncate_cache(n - new_tokens.size)
-                self.model.truncate_cache(n - new_tokens.size + 1)
+            if (n := len(draft_tokens)) > num_to_accept:
+                self.draft_model.truncate_cache(n - len(new_tokens))
+                self.model.truncate_cache(n - len(new_tokens) + 1)

             n_steps += 1
@@ -181,7 +181,7 @@ class SpeculativeDecoder:
             if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
                 break

-            draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :]
+            draft_inputs = new_tokens[max(len(new_tokens) - 2, 0) :]
             inputs = draft_inputs[-1:]
             print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
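The token buffers in this decoder are Python lists, so the `mx.array`-only `.size` attribute is swapped for `len()` throughout; the rewind arithmetic itself is unchanged. A toy walkthrough with made-up numbers, not the repo's actual bookkeeping:

    draft_tokens = [11, 12, 13, 14]  # 4 tokens proposed by the draft model
    new_tokens = [11, 12]            # the verifier kept 2 of them

    if (n := len(draft_tokens)) > len(new_tokens):
        print(n - len(new_tokens))      # 2 -> rewind the draft model cache
        print(n - len(new_tokens) + 1)  # 3 -> main model rewinds one extra
                                        #      position, mirroring the code above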


@@ -213,10 +213,10 @@ class TransformerDecoderLayer(nn.Module):
         memory: mx.array,
         mask: mx.array,
         memory_mask: mx.array,
-        cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
-    ):
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
         y = self.ln1(x)
-        y, cache = self.self_attention(y, y, y, mask, cache)
+        y, new_cache = self.self_attention(y, y, y, mask, cache)
         x = x + y
         y = self.ln2(x)
@@ -227,7 +227,7 @@ class TransformerDecoderLayer(nn.Module):
         y = self.dense(y)
         x = x + y
-        return x, cache
+        return x, new_cache


 def create_additive_causal_mask(N: int, offset: int = 0):
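Reusing the name `cache` for both the `Optional[...]` parameter and the tuple coming back from attention muddies the declared types; binding the result to `new_cache` keeps the parameter optional while letting the new return annotation promise a real tuple. The same shape in miniature (toy types, not the layer's actual math):

    from typing import Optional, Tuple

    def step(x: int, cache: Optional[Tuple[int, int]] = None) -> Tuple[int, Tuple[int, int]]:
        # A fresh name for the updated state keeps "cache" at its declared
        # Optional type and makes the non-optional return type checkable.
        new_cache = (x, x + 1) if cache is None else (cache[0] + x, cache[1] + x)
        return x, new_cache

    out, c = step(3)
    print(out, c)      # 3 (3, 4)
    print(step(2, c))  # (2, (5, 6))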