Added more fixes

paramthakkar123 2025-04-07 08:59:07 +05:30
parent 298178d669
commit 4304f5aaf5
4 changed files with 24 additions and 14 deletions

View File

@@ -160,16 +160,16 @@ class SpeculativeDecoder:
             )
             n_accepted += num_to_accept
-            n_draft += draft_tokens.size
+            n_draft += len(draft_tokens)

             # Rewind the cache for unaccepted tokens:
-            if (n := draft_tokens.size) > num_to_accept:
-                self.draft_model.truncate_cache(n - new_tokens.size)
-                self.model.truncate_cache(n - new_tokens.size + 1)
+            if (n := len(draft_tokens)) > num_to_accept:
+                self.draft_model.truncate_cache(n - len(new_tokens))
+                self.model.truncate_cache(n - len(new_tokens) + 1)

             n_steps += 1

-            for t in new_tokens.tolist():
+            for t in list(new_tokens):
                 if t == self.tokenizer.eos_id or ntoks >= max_tokens:
                     break
                 outputs.append(t)
@@ -181,7 +181,7 @@ class SpeculativeDecoder:
             if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
                 break
-            draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :]
+            draft_inputs = new_tokens[max(len(new_tokens) - 2, 0) :]
             inputs = draft_inputs[-1:]
         print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
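
For context on the truncation arithmetic in the hunk above, here is a minimal, self-contained sketch of the same bookkeeping using plain Python lists. The DummyCache class and the token values are hypothetical stand-ins (not the mlx models or their truncate_cache API); they only illustrate how many cached entries get rewound when some draft tokens are rejected.

    # Minimal sketch of the rewind bookkeeping (hypothetical DummyCache, not the real models).
    class DummyCache:
        def __init__(self):
            self.entries = []

        def append(self, toks):
            self.entries.extend(toks)

        def truncate(self, n):
            # Drop the last n cached entries.
            if n > 0:
                del self.entries[-n:]

    draft_tokens = [5, 7, 9, 11]   # tokens proposed by the draft model
    new_tokens = [5, 7, 3]         # accepted prefix plus one corrected token
    num_to_accept = 2              # how many draft tokens matched

    cache = DummyCache()
    cache.append(draft_tokens)

    n = len(draft_tokens)
    if n > num_to_accept:
        # Same arithmetic as the diff: rewind everything past the tokens that were kept.
        cache.truncate(n - len(new_tokens))

    print(cache.entries)  # [5, 7, 9]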

View File

@@ -3,6 +3,7 @@
 import argparse
 import json
 import math
+import os
 import sys
 import time
 from pathlib import Path
@@ -16,7 +17,7 @@ from mlx.utils import tree_flatten
 from models import LoRALinear

 # Disable output buffering to see print statements in real-time
-sys.stdout.reconfigure(line_buffering=True)
+sys.stdout = os.fdopen(sys.stdout.fileno(), "w", buffering=1)

 def build_parser():
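
A small standalone sketch of the buffering behaviour the replaced line is after: rebinding sys.stdout to a line-buffered file object makes each print flush on its newline, even when output is piped or redirected. The loop and sleep are illustrative only.

    import os
    import sys
    import time

    # Rebind stdout to a line-buffered wrapper around the same file descriptor,
    # so prints appear immediately even when output goes to a pipe.
    sys.stdout = os.fdopen(sys.stdout.fileno(), "w", buffering=1)

    for step in range(3):
        print(f"step {step} done")  # visible right away, no explicit flush needed
        time.sleep(0.5)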

View File

@@ -17,10 +17,10 @@ class ModelArgs:
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    model_type: str = None
+    model_type: Optional[str] = None
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None

     def __post_init__(self):
@@ -136,6 +136,11 @@ class Attention(nn.Module):
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads

+        if n_heads is None or n_kv_heads is None:
+            raise ValueError(
+                "num_attention_heads and num_key_value_heads must not be None"
+            )
+
         self.repeats = n_heads // n_kv_heads

         head_dim = args.hidden_size // n_heads
@@ -146,7 +151,7 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
         rope_scale = (
-            1 / args.rope_scaling["factor"]
+            1 / float(args.rope_scaling["factor"])
             if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
             else 1
         )
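
The Optional defaults, the explicit None guard, and the float cast above can be exercised in isolation. The sketch below uses a trimmed-down, hypothetical Args dataclass (not the full ModelArgs) to show the same three patterns end to end.

    from dataclasses import dataclass
    from typing import Dict, Optional, Union

    @dataclass
    class Args:  # trimmed-down stand-in for ModelArgs
        num_attention_heads: int
        num_key_value_heads: Optional[int] = None
        rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    args = Args(num_attention_heads=32, num_key_value_heads=8,
                rope_scaling={"type": "linear", "factor": 2.0})

    # Guard against fields left at their None defaults before doing integer math on them.
    if args.num_attention_heads is None or args.num_key_value_heads is None:
        raise ValueError("num_attention_heads and num_key_value_heads must not be None")

    # Same pattern as the diff: cast the factor so the division is float / float.
    rope_scale = (
        1 / float(args.rope_scaling["factor"])
        if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
        else 1
    )
    print(rope_scale)  # 0.5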

View File

@@ -4,7 +4,7 @@ import glob
 import json
 import logging
 from pathlib import Path
-from typing import Generator
+from typing import Any, Dict, Generator, Union

 import mlx.core as mx
 import mlx.nn as nn
@@ -72,7 +72,8 @@ python generate.py --model {repo_id} --prompt "My name is"
 def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     max_file_size_bytes = max_file_size_gibibyte << 30
     shards = []
-    shard, shard_size = {}, 0
+    shard: Dict[str, mx.array] = {}
+    shard_size = 0
     for k, v in weights.items():
         if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
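
As a rough usage sketch of the sharding loop above, the snippet below reimplements the same greedy strategy with numpy arrays (instead of mx.array) and a tiny size cap so it runs without mlx; fake_make_shards and the byte limit are made up for illustration.

    import numpy as np

    def fake_make_shards(weights, max_bytes):
        # Same greedy strategy as make_shards, with an explicit byte limit.
        shards, shard, shard_size = [], {}, 0
        for k, v in weights.items():
            if shard_size + v.nbytes > max_bytes:
                shards.append(shard)
                shard, shard_size = {}, 0
            shard[k] = v
            shard_size += v.nbytes
        shards.append(shard)
        return shards

    weights = {f"layer{i}.weight": np.zeros((64, 64), dtype=np.float32) for i in range(5)}
    shards = fake_make_shards(weights, max_bytes=40_000)  # each tensor is 16 KiB
    print([list(s) for s in shards])  # two tensors per shard, last shard holds one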
@@ -83,7 +84,7 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     return shards

-def save_model(save_dir: str, weights, tokenizer, config):
+def save_model(save_dir: Union[str, Path], weights, tokenizer, config):
     save_dir = Path(save_dir)
     save_dir.mkdir(parents=True, exist_ok=True)
@@ -96,7 +97,10 @@ def save_model(save_dir: str, weights, tokenizer, config):
     )

     total_size = sum(v.nbytes for v in weights.values())
-    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
+    index_data: Dict[str, Any] = {
+        "metadata": {"total_size": total_size},
+        "weight_map": {},
+    }

     for i, shard in enumerate(shards):
         shard_name = shard_file_format.format(i + 1, shards_count)
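
For reference, the index_data dict built here ends up as a sharded-weights index in the style of Hugging Face's model.safetensors.index.json; below is a hand-written example of roughly what it contains for two shards (tensor names, file names, and sizes are made up):

    import json

    # Illustrative example of the index structure the loop above populates;
    # the shard file names and byte count are not taken from the repository.
    index_data = {
        "metadata": {"total_size": 13_476_839_424},
        "weight_map": {
            "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
            "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
            "lm_head.weight": "model-00002-of-00002.safetensors",
        },
    }
    print(json.dumps(index_data, indent=2))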