mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 13:00:00 +08:00
Encode/decode stereo MusicGen model output correctly and write it as a stereo WAV file
This commit is contained in:
parent
07f88f8057
commit
1c0e33615a
@ -460,8 +460,9 @@ class EncodecResidualVectorQuantizer(nn.Module):
|
||||
def decode(self, codes: mx.array) -> mx.array:
|
||||
"""Decode the given codes to the quantized representation."""
|
||||
quantized_out = None
|
||||
total_layers = len(self.layers)
|
||||
for i, indices in enumerate(codes.split(codes.shape[1], axis=1)):
|
||||
layer = self.layers[i]
|
||||
layer = self.layers[i % total_layers]
|
||||
quantized = layer.decode(indices.squeeze(1))
|
||||
if quantized_out is None:
|
||||
quantized_out = quantized
|
||||
|
@ -6,10 +6,20 @@ import numpy as np
|
||||
|
||||
def save_audio(file: str, audio: mx.array, sampling_rate: int):
|
||||
"""
|
||||
Save audio to a wave (.wav) file.
|
||||
Save audio to a wave (.wav) file, supporting both mono and stereo.
|
||||
"""
|
||||
from scipy.io.wavfile import write
|
||||
|
||||
# Clip and scale audio
|
||||
audio = mx.clip(audio, -1, 1)
|
||||
audio = (audio * 32767).astype(mx.int16)
|
||||
write(file, sampling_rate, np.array(audio))
|
||||
|
||||
# Convert to numpy array
|
||||
audio_np = np.array(audio)
|
||||
|
||||
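# Handle mono input: a single-column array is duplicated into two identical channels so the file is written as stereo. NOTE: assumes a 2-D (samples, channels) array; 1-D input would raise IndexError at shape[1].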
|
||||
if audio_np.shape[1] == 1: # Single column
|
||||
# Duplicate the single column to get shape (samples, 2)
|
||||
audio_np = audio_np.reshape(-1, 1).repeat(2, axis=1)
|
||||
|
||||
write(file, sampling_rate, audio_np)
|
||||
|
Loading…
Reference in New Issue
Block a user