From 1c0e33615a08b460753ab0552443c18ace2df7b4 Mon Sep 17 00:00:00 2001 From: David Van de Ven Date: Mon, 20 Jan 2025 15:41:40 -0800 Subject: [PATCH] stereo musicgen model output properly encoded/decoded and output as stereo wav --- encodec/encodec.py | 3 ++- musicgen/utils.py | 14 ++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/encodec/encodec.py b/encodec/encodec.py index 4b85dfdd..7dcf14fe 100644 --- a/encodec/encodec.py +++ b/encodec/encodec.py @@ -460,8 +460,9 @@ class EncodecResidualVectorQuantizer(nn.Module): def decode(self, codes: mx.array) -> mx.array: """Decode the given codes to the quantized representation.""" quantized_out = None + total_layers = len(self.layers) for i, indices in enumerate(codes.split(codes.shape[1], axis=1)): - layer = self.layers[i] + layer = self.layers[i % total_layers] quantized = layer.decode(indices.squeeze(1)) if quantized_out is None: quantized_out = quantized diff --git a/musicgen/utils.py b/musicgen/utils.py index 78e92571..dc4ed0a5 100644 --- a/musicgen/utils.py +++ b/musicgen/utils.py @@ -6,10 +6,20 @@ import numpy as np def save_audio(file: str, audio: mx.array, sampling_rate: int): """ - Save audio to a wave (.wav) file. + Save audio to a wave (.wav) file, supporting both mono and stereo. """ from scipy.io.wavfile import write + # Clip and scale audio audio = mx.clip(audio, -1, 1) audio = (audio * 32767).astype(mx.int16) - write(file, sampling_rate, np.array(audio)) + + # Convert to numpy array + audio_np = np.array(audio) + + # Handle stereo by reshaping interleaved audio + if audio_np.shape[1] == 1: # Single column + # Reshape to (samples, 2) for stereo + audio_np = audio_np.reshape(-1, 1).repeat(2, axis=1) + + write(file, sampling_rate, audio_np)