Remove unnecessary changes

This commit is contained in:
Shunta Saito 2025-02-13 19:53:34 +09:00
parent 72269c306c
commit ebea6928a3
2 changed files with 3 additions and 5 deletions

View File

@ -26,10 +26,7 @@ def make_prompt_cache(
if hasattr(model, "make_cache"):
return model.make_cache()
if hasattr(model, "layers"):
num_layers = len(model.layers)
else:
num_layers = len(model.model.layers)
num_layers = len(model.layers)
if max_kv_size is not None:
return [
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)

View File

@ -19,6 +19,7 @@ from typing import (
Dict,
Generator,
List,
NamedTuple,
Optional,
Tuple,
Type,
@ -43,6 +44,7 @@ from transformers import PreTrainedTokenizer
# Local imports
from .models import cache
from .sample_utils import make_logits_processors, make_sampler
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model
from .tuner.utils import load_adapters, nparams
@ -1048,7 +1050,6 @@ def convert(
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
weights = dict(tree_flatten(model.parameters()))
dtype = getattr(mx, dtype)
weights = {k: v.astype(dtype) for k, v in weights.items()}