Add recurrent gemma (#856)

* add recurrent gemma

* fix window cache
This commit is contained in:
Awni Hannun
2024-07-07 12:10:04 -07:00
committed by GitHub
parent 1e05aef344
commit 20e221f7f7
2 changed files with 514 additions and 6 deletions

View File

@@ -181,12 +181,15 @@ def generate_step(
)
y = prompt
kv_heads = (
[model.n_kv_heads] * len(model.layers)
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
)
cache = [KVCache(model.head_dim, n) for n in kv_heads]
if hasattr(model, "make_cache"):
cache = model.make_cache()
else:
kv_heads = (
[model.n_kv_heads] * len(model.layers)
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
)
cache = [KVCache(model.head_dim, n) for n in kv_heads]
repetition_context = prompt.tolist()