Added more fixes

paramthakkar123 2025-04-07 08:59:07 +05:30
parent 298178d669
commit 4304f5aaf5
4 changed files with 24 additions and 14 deletions

View File

@@ -160,16 +160,16 @@ class SpeculativeDecoder:
             )
             n_accepted += num_to_accept
-            n_draft += draft_tokens.size
+            n_draft += len(draft_tokens)
             # Rewind the cache for unaccepted tokens:
-            if (n := draft_tokens.size) > num_to_accept:
-                self.draft_model.truncate_cache(n - new_tokens.size)
-                self.model.truncate_cache(n - new_tokens.size + 1)
+            if (n := len(draft_tokens)) > num_to_accept:
+                self.draft_model.truncate_cache(n - len(new_tokens))
+                self.model.truncate_cache(n - len(new_tokens) + 1)
             n_steps += 1
-            for t in new_tokens.tolist():
+            for t in list(new_tokens):
                 if t == self.tokenizer.eos_id or ntoks >= max_tokens:
                     break
                 outputs.append(t)
@@ -181,7 +181,7 @@ class SpeculativeDecoder:
             if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
                 break
-            draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :]
+            draft_inputs = new_tokens[max(len(new_tokens) - 2, 0) :]
             inputs = draft_inputs[-1:]
             print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
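
The switch from `.size`/`.tolist()` to `len()`/`list()` keeps the counts identical for 1-D token arrays while also working on plain Python lists, which is presumably what the type checker wanted. One plausible reading of the rewind arithmetic, with made-up numbers (the names only mirror the diff; this is not the repository's decoder):

# Made-up numbers for one speculation step, not the real decoder.
draft_tokens_len = 4    # stands in for len(draft_tokens): proposals from the draft model
num_to_accept = 2       # proposals the main model agreed with
new_tokens_len = 3      # stands in for len(new_tokens): accepted tokens plus a correction

n = draft_tokens_len
if n > num_to_accept:
    draft_rewind = n - new_tokens_len       # entries the draft cache rolls back
    main_rewind = n - new_tokens_len + 1    # main cache rolls back one extra position
    assert (draft_rewind, main_rewind) == (1, 2)
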

View File

@@ -3,6 +3,7 @@
 import argparse
 import json
 import math
+import os
 import sys
 import time
 from pathlib import Path
@@ -16,7 +17,7 @@ from mlx.utils import tree_flatten
 from models import LoRALinear
 # Disable output buffering to see print statements in real-time
-sys.stdout.reconfigure(line_buffering=True)
+sys.stdout = os.fdopen(sys.stdout.fileno(), "w", buffering=1)
 def build_parser():
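
For context on the buffering change: a standalone sketch (not tied to the repo) of the fdopen-based replacement. It gives the same line-buffered behavior as reconfigure(line_buffering=True), and it likely also sidesteps type-checker complaints, since sys.stdout is annotated as TextIO in the stubs and TextIO has no reconfigure method.

import os
import sys

# Rebuild a line-buffered text stream around the same file descriptor;
# buffering=1 means "flush on every newline" for text-mode streams.
sys.stdout = os.fdopen(sys.stdout.fileno(), "w", buffering=1)

print("appears immediately, even when stdout is piped")
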

View File

@@ -17,10 +17,10 @@ class ModelArgs:
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    model_type: str = None
+    model_type: Optional[str] = None
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
     def __post_init__(self):
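
The Optional changes follow the standard dataclass pattern; a minimal sketch with a hypothetical ArgsSketch class (not the repo's ModelArgs):

from dataclasses import dataclass
from typing import Optional

@dataclass
class ArgsSketch:
    # `int = None` is rejected by strict type checkers because None is not an int;
    # Optional[int] spells out that the field may legitimately be absent.
    num_key_value_heads: Optional[int] = None
    model_type: Optional[str] = None
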
@@ -136,6 +136,11 @@ class Attention(nn.Module):
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+        if n_heads is None or n_kv_heads is None:
+            raise ValueError(
+                "num_attention_heads and num_key_value_heads must not be None"
+            )
         self.repeats = n_heads // n_kv_heads
         head_dim = args.hidden_size // n_heads
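
The new guard narrows the Optional fields before the integer division, so a missing config value fails loudly instead of raising a TypeError later. A self-contained sketch of the same pattern (hypothetical helper, made-up values):

from typing import Optional

def head_repeats(n_heads: Optional[int], n_kv_heads: Optional[int]) -> int:
    # After this check both names are plain ints, so `//` is well-typed.
    if n_heads is None or n_kv_heads is None:
        raise ValueError("num_attention_heads and num_key_value_heads must not be None")
    return n_heads // n_kv_heads

assert head_repeats(32, 8) == 4
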
@@ -146,7 +151,7 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
         rope_scale = (
-            1 / args.rope_scaling["factor"]
+            1 / float(args.rope_scaling["factor"])
             if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
             else 1
         )
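
Since rope_scaling is declared as Dict[str, Union[float, str]], indexing it yields Union[float, str], and a bare division on that union does not type-check; float() accepts both branches. A small sketch with a made-up config dict:

from typing import Dict, Union

rope_scaling: Dict[str, Union[float, str]] = {"type": "linear", "factor": 2.0}

# float() handles either a float or a numeric string, so the expression is
# well-typed and unchanged for the common float case.
rope_scale = (
    1 / float(rope_scaling["factor"]) if rope_scaling["type"] == "linear" else 1.0
)
assert rope_scale == 0.5
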

View File

@@ -4,7 +4,7 @@ import glob
 import json
 import logging
 from pathlib import Path
-from typing import Generator
+from typing import Any, Dict, Generator, Union
 import mlx.core as mx
 import mlx.nn as nn
@@ -72,7 +72,8 @@ python generate.py --model {repo_id} --prompt "My name is"
 def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     max_file_size_bytes = max_file_size_gibibyte << 30
     shards = []
-    shard, shard_size = {}, 0
+    shard: Dict[str, mx.array] = {}
+    shard_size = 0
     for k, v in weights.items():
         if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
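
Annotating the empty shard dict matters because mypy cannot infer a value type for a bare {} under strict settings. A self-contained sketch of the same sharding bookkeeping, with plain ints standing in for mx.array weights (hypothetical make_shards_sketch, not the repo function):

from typing import Dict, List

def make_shards_sketch(sizes: Dict[str, int], max_bytes: int) -> List[Dict[str, int]]:
    shards: List[Dict[str, int]] = []
    shard: Dict[str, int] = {}  # annotated up front, as in the diff
    shard_size = 0
    for name, nbytes in sizes.items():
        if shard_size + nbytes > max_bytes:
            shards.append(shard)      # current shard is full, start a new one
            shard, shard_size = {}, 0
        shard[name] = nbytes
        shard_size += nbytes
    shards.append(shard)
    return shards

assert make_shards_sketch({"a": 3, "b": 3, "c": 3}, max_bytes=5) == [{"a": 3}, {"b": 3}, {"c": 3}]
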
@@ -83,7 +84,7 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     return shards
-def save_model(save_dir: str, weights, tokenizer, config):
+def save_model(save_dir: Union[str, Path], weights, tokenizer, config):
     save_dir = Path(save_dir)
     save_dir.mkdir(parents=True, exist_ok=True)
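
Widening save_dir to Union[str, Path] matches the immediate Path(save_dir) normalization on the next line; both spellings are then valid at the call site. A minimal sketch of the pattern (hypothetical ensure_dir helper and directory name):

from pathlib import Path
from typing import Union

def ensure_dir(save_dir: Union[str, Path]) -> Path:
    save_dir = Path(save_dir)                    # Path(...) accepts either spelling
    save_dir.mkdir(parents=True, exist_ok=True)  # creates the directory as a side effect
    return save_dir

# Both call styles satisfy the type checker and behave identically at runtime.
assert ensure_dir("mlx_model") == ensure_dir(Path("mlx_model"))
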
@@ -96,7 +97,10 @@ def save_model(save_dir: str, weights, tokenizer, config):
     )
     total_size = sum(v.nbytes for v in weights.values())
-    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
+    index_data: Dict[str, Any] = {
+        "metadata": {"total_size": total_size},
+        "weight_map": {},
+    }
     for i, shard in enumerate(shards):
         shard_name = shard_file_format.format(i + 1, shards_count)
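
index_data mixes value types (a nested metadata dict plus a str-to-str weight map), so Dict[str, Any] is the practical annotation. A sketch of how such a weight map gets filled, with made-up shard contents and file names (the real shard_file_format is not shown here):

from typing import Any, Dict, List

shards: List[Dict[str, int]] = [{"layer.0.weight": 1}, {"layer.1.weight": 2}]
total_size = 3

index_data: Dict[str, Any] = {
    "metadata": {"total_size": total_size},
    "weight_map": {},
}
for i, shard in enumerate(shards):
    shard_name = f"weights.{i + 1:02d}.safetensors"  # made-up naming scheme
    for weight_name in shard:
        index_data["weight_map"][weight_name] = shard_name
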