diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py index f2345394..b069edef 100644 --- a/llms/tests/test_generate.py +++ b/llms/tests/test_generate.py @@ -49,7 +49,7 @@ class TestGenerate(unittest.TestCase): verbose=False, logits_processors=[logits_processor], ) - self.assertEqual(len(all_toks), len(init_toks) + 5) + self.assertEqual(all_toks.shape[-1], len(init_toks) + 5) if __name__ == "__main__":