This commit is contained in:
Awni Hannun 2024-12-16 07:53:07 -08:00
parent 4aee86243e
commit dec2acface
2 changed files with 6 additions and 1 deletion

View File

@@ -191,6 +191,8 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
prompt_cache[i] = prompt_cache[i].to_quantized(
group_size=kv_group_size, bits=kv_bits
)
def generate_step(
prompt: mx.array,
model: nn.Module,

View File

@@ -863,7 +863,10 @@ class TestModels(unittest.TestCase):
sliding_window_pattern=4,
)
model = cohere2.Model(args)
self.model_test_runner(model, args.model_type, args.vocab_size, args.num_hidden_layers)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
if __name__ == "__main__":
unittest.main()