diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py index 14026f0c..3083723a 100644 --- a/llms/mlx_lm/models/cache.py +++ b/llms/mlx_lm/models/cache.py @@ -26,7 +26,10 @@ def make_prompt_cache( if hasattr(model, "make_cache"): return model.make_cache() - num_layers = len(model.layers) + if hasattr(model, "layers"): + num_layers = len(model.layers) + else: + num_layers = len(model.model.layers) if max_kv_size is not None: return [ RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)