Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
Added more fixes

commit 4304f5aaf5 (parent 298178d669)
@@ -160,16 +160,16 @@ class SpeculativeDecoder:
             )
 
             n_accepted += num_to_accept
-            n_draft += draft_tokens.size
+            n_draft += len(draft_tokens)
 
             # Rewind the cache for unaccepted tokens:
-            if (n := draft_tokens.size) > num_to_accept:
-                self.draft_model.truncate_cache(n - new_tokens.size)
-                self.model.truncate_cache(n - new_tokens.size + 1)
+            if (n := len(draft_tokens)) > num_to_accept:
+                self.draft_model.truncate_cache(n - len(new_tokens))
+                self.model.truncate_cache(n - len(new_tokens) + 1)
 
             n_steps += 1
 
-            for t in new_tokens.tolist():
+            for t in list(new_tokens):
                 if t == self.tokenizer.eos_id or ntoks >= max_tokens:
                     break
                 outputs.append(t)
@@ -181,7 +181,7 @@ class SpeculativeDecoder:
 
             if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
                 break
-            draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :]
+            draft_inputs = new_tokens[max(len(new_tokens) - 2, 0) :]
             inputs = draft_inputs[-1:]
 
         print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
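Note on the .size -> len() and .tolist() -> list() changes above: for a 1-D token buffer the two forms count the same elements, but the built-ins also work when the tokens are held in a plain Python list rather than an mx.array. A minimal sketch (not from the diff), assuming a 1-D mx.array supports len() and iteration the way a NumPy array does:

    import mlx.core as mx

    def count_tokens(tokens):
        # len() works for a plain Python list and for a 1-D mx.array alike;
        # .size only exists on the array type, so the len()/list() forms keep
        # the decoding loop agnostic to which container holds the tokens.
        return len(tokens)

    print(count_tokens([1, 2, 3]))            # 3
    print(count_tokens(mx.array([1, 2, 3])))  # 3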
@@ -3,6 +3,7 @@
 import argparse
 import json
 import math
+import os
 import sys
 import time
 from pathlib import Path
@@ -16,7 +17,7 @@ from mlx.utils import tree_flatten
 from models import LoRALinear
 
 # Disable output buffering to see print statements in real-time
-sys.stdout.reconfigure(line_buffering=True)
+sys.stdout = os.fdopen(sys.stdout.fileno(), "w", buffering=1)
 
 
 def build_parser():
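Note on the buffering change above: both lines force line-buffered stdout so progress prints show up immediately, but reconfigure() is a TextIOWrapper method and can fail when stdout has been wrapped or replaced, whereas re-opening the underlying file descriptor does not depend on the wrapper type. A small sketch of the replacement approach:

    import os
    import sys
    import time

    # Re-open stdout on the same OS file descriptor with buffering=1 (line
    # buffering in text mode), so each print() is flushed at its newline even
    # when output is piped or redirected.
    sys.stdout = os.fdopen(sys.stdout.fileno(), "w", buffering=1)

    for step in range(3):
        print(f"step {step}: running...")  # appears immediately, not at exit
        time.sleep(0.2)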
@@ -17,10 +17,10 @@ class ModelArgs:
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    model_type: str = None
+    model_type: Optional[str] = None
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
 
     def __post_init__(self):
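Note on the Optional[...] annotations above: `num_key_value_heads: int = None` runs fine but is rejected by strict type checkers (implicit Optional), so `Optional[int] = None` makes the "missing from the config" case explicit. A sketch of the usual pattern, using an illustrative stand-in class rather than the repo's ModelArgs:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ExampleArgs:  # illustrative stand-in, not the repo's ModelArgs
        num_attention_heads: int
        num_key_value_heads: Optional[int] = None  # may be absent from a config

        def __post_init__(self):
            # Assumed fallback: treat a missing KV-head count as ordinary
            # multi-head attention (one KV head per query head).
            if self.num_key_value_heads is None:
                self.num_key_value_heads = self.num_attention_heads

    print(ExampleArgs(num_attention_heads=32).num_key_value_heads)  # 32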
@@ -136,6 +136,11 @@ class Attention(nn.Module):
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
+        if n_heads is None or n_kv_heads is None:
+            raise ValueError(
+                "num_attention_heads and num_key_value_heads must not be None"
+            )
+
         self.repeats = n_heads // n_kv_heads
 
         head_dim = args.hidden_size // n_heads
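Note on the new guard above: without it, a config that leaves either head count unset only fails later at `n_heads // n_kv_heads`, with a less helpful message. A tiny illustration of the failure mode the ValueError replaces:

    # What the integer division raises when one of the head counts is still None:
    n_heads, n_kv_heads = 32, None
    try:
        repeats = n_heads // n_kv_heads
    except TypeError as err:
        print(err)  # unsupported operand type(s) for //: 'int' and 'NoneType'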
@@ -146,7 +151,7 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
         rope_scale = (
-            1 / args.rope_scaling["factor"]
+            1 / float(args.rope_scaling["factor"])
             if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
             else 1
         )
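Note on the float(...) cast above: `rope_scaling` is annotated `Dict[str, Union[float, str]]`, so the "factor" entry may be a string (for example when read straight from a config JSON); the explicit cast normalizes it before the division and satisfies the type checker. An illustrative sketch with made-up values:

    # Illustrative config fragment; real values come from the model's config.json.
    rope_scaling = {"type": "linear", "factor": "8.0"}

    rope_scale = (
        1 / float(rope_scaling["factor"])
        if rope_scaling is not None and rope_scaling["type"] == "linear"
        else 1
    )
    print(rope_scale)  # 0.125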
@@ -4,7 +4,7 @@ import glob
 import json
 import logging
 from pathlib import Path
-from typing import Generator
+from typing import Any, Dict, Generator, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -72,7 +72,8 @@ python generate.py --model {repo_id} --prompt "My name is"
 def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     max_file_size_bytes = max_file_size_gibibyte << 30
     shards = []
-    shard, shard_size = {}, 0
+    shard: Dict[str, mx.array] = {}
+    shard_size = 0
     for k, v in weights.items():
         if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
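Note on the make_shards change above: splitting `shard, shard_size = {}, 0` into two statements is what allows the `Dict[str, mx.array]` annotation on `shard`; the greedy sharding logic itself is unchanged. A self-contained sketch of that strategy using plain byte counts instead of arrays (the tail handling is assumed, since the hunk cuts off before the end of the loop):

    from typing import Dict, List

    def shard_by_size(sizes: Dict[str, int], max_bytes: int) -> List[Dict[str, int]]:
        # Greedily pack weights into shards: start a new shard as soon as the
        # next tensor would push the current one past max_bytes.
        shards: List[Dict[str, int]] = []
        shard: Dict[str, int] = {}
        shard_size = 0
        for name, nbytes in sizes.items():
            if shard_size + nbytes > max_bytes:
                shards.append(shard)
                shard, shard_size = {}, 0
            shard[name] = nbytes
            shard_size += nbytes
        shards.append(shard)  # assumed: flush the final partial shard
        return shards

    print(shard_by_size({"a": 6, "b": 5, "c": 4}, max_bytes=10))
    # [{'a': 6}, {'b': 5, 'c': 4}]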
@@ -83,7 +84,7 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
     return shards
 
 
-def save_model(save_dir: str, weights, tokenizer, config):
+def save_model(save_dir: Union[str, Path], weights, tokenizer, config):
     save_dir = Path(save_dir)
     save_dir.mkdir(parents=True, exist_ok=True)
 
@@ -96,7 +97,10 @@ def save_model(save_dir: str, weights, tokenizer, config):
     )
 
     total_size = sum(v.nbytes for v in weights.values())
-    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
+    index_data: Dict[str, Any] = {
+        "metadata": {"total_size": total_size},
+        "weight_map": {},
+    }
 
     for i, shard in enumerate(shards):
         shard_name = shard_file_format.format(i + 1, shards_count)
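Note on the index_data change above: the dict is the in-memory form of a safetensors index file that maps every weight name to the shard that stores it, and the `Dict[str, Any]` annotation accommodates the mixed value types. A hedged sketch of how the weight_map is presumably filled in (weight names and the shard-name format here are illustrative, not taken from the repo):

    import json
    from typing import Any, Dict

    total_size = 24  # illustrative byte count
    index_data: Dict[str, Any] = {
        "metadata": {"total_size": total_size},
        "weight_map": {},
    }

    # Stand-in shards; in save_model these hold mx.array values keyed by weight name.
    shards = [{"model.embed_tokens.weight": None}, {"lm_head.weight": None}]
    for i, shard in enumerate(shards):
        shard_name = f"model-{i + 1:05d}-of-{len(shards):05d}.safetensors"
        for weight_name in shard:
            index_data["weight_map"][weight_name] = shard_name

    print(json.dumps(index_data, indent=2))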