From 4304f5aaf586870399b3590c28062647777b4a7c Mon Sep 17 00:00:00 2001
From: paramthakkar123
Date: Mon, 7 Apr 2025 08:59:07 +0530
Subject: [PATCH] Added more fixes

---
 llms/speculative_decoding/decoder.py | 12 ++++++------
 lora/lora.py                         |  3 ++-
 lora/models.py                       | 11 ++++++++---
 lora/utils.py                        | 12 ++++++++----
 4 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/llms/speculative_decoding/decoder.py b/llms/speculative_decoding/decoder.py
index 3d547a7f..d2b97716 100644
--- a/llms/speculative_decoding/decoder.py
+++ b/llms/speculative_decoding/decoder.py
@@ -160,16 +160,16 @@ class SpeculativeDecoder:
             )
 
             n_accepted += num_to_accept
-            n_draft += draft_tokens.size
+            n_draft += len(draft_tokens)
 
             # Rewind the cache for unaccepted tokens:
-            if (n := draft_tokens.size) > num_to_accept:
-                self.draft_model.truncate_cache(n - new_tokens.size)
-                self.model.truncate_cache(n - new_tokens.size + 1)
+            if (n := len(draft_tokens)) > num_to_accept:
+                self.draft_model.truncate_cache(n - len(new_tokens))
+                self.model.truncate_cache(n - len(new_tokens) + 1)
 
             n_steps += 1
 
-            for t in new_tokens.tolist():
+            for t in list(new_tokens):
                 if t == self.tokenizer.eos_id or ntoks >= max_tokens:
                     break
                 outputs.append(t)
@@ -181,7 +181,7 @@
             if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
                 break
 
-            draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :]
+            draft_inputs = new_tokens[max(len(new_tokens) - 2, 0) :]
             inputs = draft_inputs[-1:]
 
         print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
diff --git a/lora/lora.py b/lora/lora.py
index 6f91ccca..7fc3998b 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -3,6 +3,7 @@
 import argparse
 import json
 import math
+import os
 import sys
 import time
 from pathlib import Path
@@ -16,7 +17,7 @@ from mlx.utils import tree_flatten
 from models import LoRALinear
 
 # Disable output buffering to see print statements in real-time
-sys.stdout.reconfigure(line_buffering=True)
+sys.stdout = os.fdopen(sys.stdout.fileno(), "w", buffering=1)
 
 
 def build_parser():
diff --git a/lora/models.py b/lora/models.py
index 3e85b135..ddec473c 100644
--- a/lora/models.py
+++ b/lora/models.py
@@ -17,10 +17,10 @@ class ModelArgs:
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    model_type: str = None
+    model_type: Optional[str] = None
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
 
     def __post_init__(self):
@@ -136,6 +136,11 @@ class Attention(nn.Module):
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
+        if n_heads is None or n_kv_heads is None:
+            raise ValueError(
+                "num_attention_heads and num_key_value_heads must not be None"
+            )
+
         self.repeats = n_heads // n_kv_heads
 
         head_dim = args.hidden_size // n_heads
@@ -146,7 +151,7 @@
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
         rope_scale = (
-            1 / args.rope_scaling["factor"]
+            1 / float(args.rope_scaling["factor"])
             if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
             else 1
         )
diff --git a/lora/utils.py b/lora/utils.py
index a334723c..db9c0876 100644
--- a/lora/utils.py
+++ b/lora/utils.py
@@ -4,7 +4,7 @@ import glob
 import json
 import logging
 from pathlib import Path
-from typing import Generator
+from typing import Any, Dict, Generator, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -72,7 +72,8 @@ python generate.py --model {repo_id} --prompt "My name is"
 def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     max_file_size_bytes = max_file_size_gibibyte << 30
     shards = []
-    shard, shard_size = {}, 0
+    shard: Dict[str, mx.array] = {}
+    shard_size = 0
     for k, v in weights.items():
         if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
@@ -83,7 +84,7 @@
     return shards
 
 
-def save_model(save_dir: str, weights, tokenizer, config):
+def save_model(save_dir: Union[str, Path], weights, tokenizer, config):
     save_dir = Path(save_dir)
     save_dir.mkdir(parents=True, exist_ok=True)
 
@@ -96,7 +97,10 @@ def save_model(save_dir: str, weights, tokenizer, config):
     )
 
     total_size = sum(v.nbytes for v in weights.values())
-    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
+    index_data: Dict[str, Any] = {
+        "metadata": {"total_size": total_size},
+        "weight_map": {},
+    }
 
     for i, shard in enumerate(shards):
         shard_name = shard_file_format.format(i + 1, shards_count)
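
Note on the buffering change in lora/lora.py: sys.stdout.reconfigure(line_buffering=True)
is only defined on io.TextIOWrapper, so it can fail when stdout has been replaced by some
other file-like object (as some test runners and IDEs do). Re-opening the underlying file
descriptor makes no such assumption. A minimal standalone sketch of the technique, not
part of the patch:

    import os
    import sys

    # Re-open stdout on the same file descriptor with line buffering
    # (buffering=1 is only valid in text mode); every newline flushes.
    sys.stdout = os.fdopen(sys.stdout.fileno(), "w", buffering=1)

    print("appears immediately, no explicit flush needed")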
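
Note on the rope_scale change in lora/models.py: rope_scaling is annotated
Dict[str, Union[float, str]], so rope_scaling["factor"] is Union[float, str] to a type
checker; wrapping it in float(...) narrows the type and also tolerates configs that store
the factor as a string. A small sketch of the resulting behavior, using a made-up config
dict for illustration:

    from typing import Dict, Optional, Union

    def rope_scale(rope_scaling: Optional[Dict[str, Union[float, str]]]) -> float:
        # Mirrors the patched expression: cast before dividing, so a
        # factor stored as either 2.0 or "2.0" is handled the same way.
        if rope_scaling is not None and rope_scaling["type"] == "linear":
            return 1 / float(rope_scaling["factor"])
        return 1.0

    print(rope_scale({"type": "linear", "factor": "2.0"}))  # 0.5
    print(rope_scale(None))  # 1.0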