work with tuple shape (#393)

This commit is contained in:
Awni Hannun
2024-02-01 13:03:47 -08:00
committed by GitHub
parent 0340113e02
commit ec14583c2a
6 changed files with 5 additions and 22 deletions

View File

@@ -1,4 +1,4 @@
mlx
mlx>=0.1
numba
numpy
torch

View File

@@ -303,7 +303,7 @@ class TestWhisper(unittest.TestCase):
def check_segment(seg, expected):
for k, v in expected.items():
if isinstance(v, float):
self.assertAlmostEqual(seg[k], v, places=3)
self.assertAlmostEqual(seg[k], v, places=2)
else:
self.assertEqual(seg[k], v)

View File

@@ -50,7 +50,7 @@ def detect_language(
mel = mel[None]
# skip encoder forward pass if already-encoded audio features were given
if mel.shape[-2:] != [model.dims.n_audio_ctx, model.dims.n_audio_state]:
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
mel = model.encoder(mel)
# forward pass using a single token, startoftranscript