stereo musicgen model output properly encoded/decoded and output as stereo wav

This commit is contained in:
David Van de Ven 2025-01-20 15:41:40 -08:00
parent 07f88f8057
commit 1c0e33615a
2 changed files with 14 additions and 3 deletions

View File

@ -460,8 +460,9 @@ class EncodecResidualVectorQuantizer(nn.Module):
def decode(self, codes: mx.array) -> mx.array:
"""Decode the given codes to the quantized representation."""
quantized_out = None
total_layers = len(self.layers)
for i, indices in enumerate(codes.split(codes.shape[1], axis=1)):
layer = self.layers[i]
layer = self.layers[i % total_layers]
quantized = layer.decode(indices.squeeze(1))
if quantized_out is None:
quantized_out = quantized

View File

@ -6,10 +6,20 @@ import numpy as np
def save_audio(file: str, audio: mx.array, sampling_rate: int):
"""
Save audio to a wave (.wav) file.
Save audio to a wave (.wav) file, supporting both mono and stereo.
"""
from scipy.io.wavfile import write
# Clip and scale audio
audio = mx.clip(audio, -1, 1)
audio = (audio * 32767).astype(mx.int16)
write(file, sampling_rate, np.array(audio))
# Convert to numpy array
audio_np = np.array(audio)
# Handle stereo by reshaping interleaved audio
if audio_np.shape[1] == 1: # Single column
# Reshape to (samples, 2) for stereo
audio_np = audio_np.reshape(-1, 1).repeat(2, axis=1)
write(file, sampling_rate, audio_np)