mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
@@ -181,12 +181,15 @@ def generate_step(
|
||||
)
|
||||
|
||||
y = prompt
|
||||
kv_heads = (
|
||||
[model.n_kv_heads] * len(model.layers)
|
||||
if isinstance(model.n_kv_heads, int)
|
||||
else model.n_kv_heads
|
||||
)
|
||||
cache = [KVCache(model.head_dim, n) for n in kv_heads]
|
||||
if hasattr(model, "make_cache"):
|
||||
cache = model.make_cache()
|
||||
else:
|
||||
kv_heads = (
|
||||
[model.n_kv_heads] * len(model.layers)
|
||||
if isinstance(model.n_kv_heads, int)
|
||||
else model.n_kv_heads
|
||||
)
|
||||
cache = [KVCache(model.head_dim, n) for n in kv_heads]
|
||||
|
||||
repetition_context = prompt.tolist()
|
||||
|
||||
|
Reference in New Issue
Block a user