From 4c9f9f9be798e6cf04fd0f74395a3b4420077aad Mon Sep 17 00:00:00 2001
From: Param Thakkar <128291516+ParamThakkar123@users.noreply.github.com>
Date: Thu, 24 Apr 2025 02:53:46 +0530
Subject: [PATCH] Made llama and mistral files mypy compatible (#1359)

* Made mypy compatible

* reformatted

* Added more fixes

* Added fixes to speculative-decoding

* Fixes

* fix circle

* revert some stuff

---------

Co-authored-by: Awni Hannun
---
 .circleci/config.yml                 |  2 --
 llms/gguf_llm/generate.py            |  2 +-
 llms/gguf_llm/models.py              | 10 +++++-----
 llms/llama/convert.py                |  4 +++-
 llms/mixtral/mixtral.py              |  4 +---
 llms/speculative_decoding/decoder.py | 10 +++++-----
 llms/speculative_decoding/model.py   |  8 ++++----
 lora/lora.py                         |  3 ++-
 lora/models.py                       |  6 +++---
 lora/utils.py                        | 12 ++++++++----
 10 files changed, 32 insertions(+), 29 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 42a39194..aec28e77 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -36,7 +36,5 @@ workflows:
           type: approval
       - apple/authenticate:
           context: pr-approval
-      - mlx_lm_build_and_test:
-          requires: [ hold ]
       - linux_build_and_test:
           requires: [ hold ]
diff --git a/llms/gguf_llm/generate.py b/llms/gguf_llm/generate.py
index 7215aa48..db327cda 100644
--- a/llms/gguf_llm/generate.py
+++ b/llms/gguf_llm/generate.py
@@ -40,7 +40,7 @@ def generate(
     if len(tokens) == 0:
         print("No tokens generated for this prompt")
         return
-    prompt_tps = prompt.size / prompt_time
+    prompt_tps = len(prompt) / prompt_time
     gen_tps = (len(tokens) - 1) / gen_time
     print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
     print(f"Generation: {gen_tps:.3f} tokens-per-sec")
diff --git a/llms/gguf_llm/models.py b/llms/gguf_llm/models.py
index 3b0afc65..9e1f9666 100644
--- a/llms/gguf_llm/models.py
+++ b/llms/gguf_llm/models.py
@@ -19,10 +19,10 @@ class ModelArgs:
     rms_norm_eps: float
     vocab_size: int
     context_length: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    model_type: str = None
+    model_type: Optional[str] = None
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
 
     def __post_init__(self):
@@ -54,7 +54,7 @@ class Attention(nn.Module):
 
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
-        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads or n_heads
 
         self.repeats = n_heads // n_kv_heads
 
@@ -66,7 +66,7 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
         rope_scale = (
-            1 / args.rope_scaling["factor"]
+            1 / float(args.rope_scaling["factor"])
             if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
             else 1
         )
@@ -254,7 +254,7 @@ def translate_weight_names(name):
     return name
 
 
-def load(gguf_file: str, repo: str = None):
+def load(gguf_file: str, repo: Optional[str] = None):
     # If the gguf_file exists, try to load model from it.
     # Otherwise try to download and cache from the HF repo
     if not Path(gguf_file).exists():
diff --git a/llms/llama/convert.py b/llms/llama/convert.py
index 04c10a5f..33610f44 100644
--- a/llms/llama/convert.py
+++ b/llms/llama/convert.py
@@ -7,6 +7,7 @@ import glob
 import json
 import shutil
 from pathlib import Path
+from typing import Dict
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -149,7 +150,8 @@ def quantize(weights, config, args):
 def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     max_file_size_bytes = max_file_size_gibibyte << 30
     shards = []
-    shard, shard_size = {}, 0
+    shard: Dict[str, mx.array] = {}
+    shard_size = 0
     for k, v in weights.items():
         if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
diff --git a/llms/mixtral/mixtral.py b/llms/mixtral/mixtral.py
index 4b45d066..653dad57 100644
--- a/llms/mixtral/mixtral.py
+++ b/llms/mixtral/mixtral.py
@@ -23,7 +23,7 @@ class ModelArgs:
     n_kv_heads: int
     norm_eps: float
     vocab_size: int
-    moe: dict = None
+    moe: dict
 
 
 class Attention(nn.Module):
@@ -91,7 +91,6 @@ class FeedForward(nn.Module):
 class MOEFeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-
         self.num_experts = args.moe["num_experts"]
         self.num_experts_per_tok = args.moe["num_experts_per_tok"]
         self.experts = [FeedForward(args) for _ in range(self.num_experts)]
@@ -115,7 +114,6 @@ class MOEFeedForward(nn.Module):
             yt = (yt * st).sum(axis=-1)
             y.append(yt[None, :])
 
         y = mx.concatenate(y)
-
         return y.reshape(orig_shape)
 
diff --git a/llms/speculative_decoding/decoder.py b/llms/speculative_decoding/decoder.py
index 3d547a7f..39cf5b92 100644
--- a/llms/speculative_decoding/decoder.py
+++ b/llms/speculative_decoding/decoder.py
@@ -160,12 +160,12 @@ class SpeculativeDecoder:
             )
 
             n_accepted += num_to_accept
-            n_draft += draft_tokens.size
+            n_draft += len(draft_tokens)
 
             # Rewind the cache for unaccepted tokens:
-            if (n := draft_tokens.size) > num_to_accept:
-                self.draft_model.truncate_cache(n - new_tokens.size)
-                self.model.truncate_cache(n - new_tokens.size + 1)
+            if (n := len(draft_tokens)) > num_to_accept:
+                self.draft_model.truncate_cache(n - len(new_tokens))
+                self.model.truncate_cache(n - len(new_tokens) + 1)
 
             n_steps += 1
 
@@ -181,7 +181,7 @@ class SpeculativeDecoder:
             if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
                 break
 
-            draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :]
+            draft_inputs = new_tokens[max(len(new_tokens) - 2, 0) :]
             inputs = draft_inputs[-1:]
 
         print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
diff --git a/llms/speculative_decoding/model.py b/llms/speculative_decoding/model.py
index c310b943..d30daa97 100644
--- a/llms/speculative_decoding/model.py
+++ b/llms/speculative_decoding/model.py
@@ -213,10 +213,10 @@ class TransformerDecoderLayer(nn.Module):
         memory: mx.array,
         mask: mx.array,
         memory_mask: mx.array,
-        cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
-    ):
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
         y = self.ln1(x)
-        y, cache = self.self_attention(y, y, y, mask, cache)
+        y, new_cache = self.self_attention(y, y, y, mask, cache)
         x = x + y
 
         y = self.ln2(x)
@@ -227,7 +227,7 @@ class TransformerDecoderLayer(nn.Module):
         y = self.dense(y)
         x = x + y
 
-        return x, cache
+        return x, new_cache
 
 
 def create_additive_causal_mask(N: int, offset: int = 0):
diff --git a/lora/lora.py b/lora/lora.py
index 6f91ccca..7fc3998b 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -3,6 +3,7 @@
 import argparse
 import json
 import math
+import os
 import sys
 import time
 from pathlib import Path
@@ -16,7 +17,7 @@ from mlx.utils import tree_flatten
 from models import LoRALinear
 
 # Disable output buffering to see print statements in real-time
-sys.stdout.reconfigure(line_buffering=True)
+sys.stdout = os.fdopen(sys.stdout.fileno(), "w", buffering=1)
 
 
 def build_parser():
diff --git a/lora/models.py b/lora/models.py
index 3e85b135..acafbc61 100644
--- a/lora/models.py
+++ b/lora/models.py
@@ -17,10 +17,10 @@ class ModelArgs:
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
    rope_theta: float = 10000
     rope_traditional: bool = False
-    model_type: str = None
+    model_type: Optional[str] = None
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
 
     def __post_init__(self):
@@ -146,7 +146,7 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
         rope_scale = (
-            1 / args.rope_scaling["factor"]
+            1 / float(args.rope_scaling["factor"])
             if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
             else 1
         )
diff --git a/lora/utils.py b/lora/utils.py
index a334723c..db9c0876 100644
--- a/lora/utils.py
+++ b/lora/utils.py
@@ -4,7 +4,7 @@ import glob
 import json
 import logging
 from pathlib import Path
-from typing import Generator
+from typing import Any, Dict, Generator, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -72,7 +72,8 @@ python generate.py --model {repo_id} --prompt "My name is"
 def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     max_file_size_bytes = max_file_size_gibibyte << 30
     shards = []
-    shard, shard_size = {}, 0
+    shard: Dict[str, mx.array] = {}
+    shard_size = 0
     for k, v in weights.items():
         if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
@@ -83,7 +84,7 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     return shards
 
 
-def save_model(save_dir: str, weights, tokenizer, config):
+def save_model(save_dir: Union[str, Path], weights, tokenizer, config):
     save_dir = Path(save_dir)
     save_dir.mkdir(parents=True, exist_ok=True)
 
@@ -96,7 +97,10 @@ def save_model(save_dir: str, weights, tokenizer, config):
     )
 
     total_size = sum(v.nbytes for v in weights.values())
-    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
+    index_data: Dict[str, Any] = {
+        "metadata": {"total_size": total_size},
+        "weight_map": {},
+    }
 
     for i, shard in enumerate(shards):
         shard_name = shard_file_format.format(i + 1, shards_count)