diff --git a/whisper/mlx_whisper/audio.py b/whisper/mlx_whisper/audio.py index 81fa41e3..e04309c1 100644 --- a/whisper/mlx_whisper/audio.py +++ b/whisper/mlx_whisper/audio.py @@ -151,8 +151,6 @@ def log_mel_spectrogram( mx.array, shape = (80, n_frames) An array that contains the Mel spectrogram """ - device = mx.default_device() - mx.set_default_device(mx.cpu) if isinstance(audio, str): audio = load_audio(audio) elif not isinstance(audio, mx.array): @@ -170,5 +168,4 @@ def log_mel_spectrogram( log_spec = mx.maximum(mel_spec, 1e-10).log10() log_spec = mx.maximum(log_spec, log_spec.max() - 8.0) log_spec = (log_spec + 4.0) / 4.0 - mx.set_default_device(device) return log_spec