diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 10292d75..4d69115e 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -191,6 +191,8 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_ prompt_cache[i] = prompt_cache[i].to_quantized( group_size=kv_group_size, bits=kv_bits ) + + def generate_step( prompt: mx.array, model: nn.Module, diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index d6decb3f..3097c522 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -863,7 +863,10 @@ class TestModels(unittest.TestCase): sliding_window_pattern=4, ) model = cohere2.Model(args) - self.model_test_runner(model, args.model_type, args.vocab_size, args.num_hidden_layers) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + if __name__ == "__main__": unittest.main()