From cded14988ca0c1bd88f7b9f19f8e7b902ed50504 Mon Sep 17 00:00:00 2001 From: L Lllvvuu Date: Fri, 27 Dec 2024 01:53:15 -0800 Subject: [PATCH] fix test_generate --- llms/tests/test_generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py index f2345394..b069edef 100644 --- a/llms/tests/test_generate.py +++ b/llms/tests/test_generate.py @@ -49,7 +49,7 @@ class TestGenerate(unittest.TestCase): verbose=False, logits_processors=[logits_processor], ) - self.assertEqual(len(all_toks), len(init_toks) + 5) + self.assertEqual(all_toks.shape[-1], len(init_toks) + 5) if __name__ == "__main__":