diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index 3083723a..14026f0c 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -26,10 +26,7 @@ def make_prompt_cache(
     if hasattr(model, "make_cache"):
         return model.make_cache()
 
-    if hasattr(model, "layers"):
-        num_layers = len(model.layers)
-    else:
-        num_layers = len(model.model.layers)
+    num_layers = len(model.layers)
     if max_kv_size is not None:
         return [
             RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index ff793ee5..78a2e802 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -19,6 +19,7 @@ from typing import (
     Dict,
     Generator,
     List,
+    NamedTuple,
     Optional,
     Tuple,
     Type,
@@ -43,6 +44,7 @@ from transformers import PreTrainedTokenizer
 
 # Local imports
 from .models import cache
+from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
 from .tuner.utils import load_adapters, nparams
@@ -1048,7 +1050,6 @@ def convert(
     model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
 
     weights = dict(tree_flatten(model.parameters()))
-    dtype = getattr(mx, dtype)
     weights = {k: v.astype(dtype) for k, v in weights.items()}
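
For orientation, a minimal usage sketch of the pieces this patch touches, under two assumptions not stated in the diff: the model name and the parameter values are placeholders, and the loaded top-level model exposes a .layers property (which the simplified make_prompt_cache now requires unless the model defines its own make_cache()).

# Hypothetical usage sketch; model repo and settings are assumptions for illustration.
from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.sample_utils import make_logits_processors, make_sampler

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# After this change, make_prompt_cache reads model.layers directly
# (or defers to a model-provided make_cache()); the model.model.layers
# fallback is gone.
prompt_cache = make_prompt_cache(model, max_kv_size=1024)

# The helpers newly imported into utils.py build the sampling callables.
sampler = make_sampler(temp=0.7, top_p=0.9)
logits_processors = make_logits_processors(repetition_penalty=1.1)

The cache, sampler, and logits processors are then consumed by the generation path in utils.py; the exact parameter names on that side are outside the hunks shown here.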