fix test_generate

This commit is contained in:
L Lllvvuu 2024-12-27 01:53:15 -08:00
parent a28ca03e04
commit cded14988c
No known key found for this signature in database
GPG Key ID: CFAD5A25056DDD0F

View File

@ -49,7 +49,7 @@ class TestGenerate(unittest.TestCase):
verbose=False,
logits_processors=[logits_processor],
)
self.assertEqual(len(all_toks), len(init_toks) + 5)
self.assertEqual(all_toks.shape[-1], len(init_toks) + 5)
if __name__ == "__main__":