mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 20:04:38 +08:00
work with tuple shape (#393)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
mlx
|
||||
mlx>=0.1
|
||||
numba
|
||||
numpy
|
||||
torch
|
||||
|
@@ -303,7 +303,7 @@ class TestWhisper(unittest.TestCase):
|
||||
def check_segment(seg, expected):
|
||||
for k, v in expected.items():
|
||||
if isinstance(v, float):
|
||||
self.assertAlmostEqual(seg[k], v, places=3)
|
||||
self.assertAlmostEqual(seg[k], v, places=2)
|
||||
else:
|
||||
self.assertEqual(seg[k], v)
|
||||
|
||||
|
@@ -50,7 +50,7 @@ def detect_language(
|
||||
mel = mel[None]
|
||||
|
||||
# skip encoder forward pass if already-encoded audio features were given
|
||||
if mel.shape[-2:] != [model.dims.n_audio_ctx, model.dims.n_audio_state]:
|
||||
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
|
||||
mel = model.encoder(mel)
|
||||
|
||||
# forward pass using a single token, startoftranscript
|
||||
|
Reference in New Issue
Block a user