mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 04:25:06 +08:00
stereo musicgen model output properly encoded/decoded and output as stereo wav
This commit is contained in:
parent
07f88f8057
commit
1c0e33615a
@ -460,8 +460,9 @@ class EncodecResidualVectorQuantizer(nn.Module):
|
|||||||
def decode(self, codes: mx.array) -> mx.array:
|
def decode(self, codes: mx.array) -> mx.array:
|
||||||
"""Decode the given codes to the quantized representation."""
|
"""Decode the given codes to the quantized representation."""
|
||||||
quantized_out = None
|
quantized_out = None
|
||||||
|
total_layers = len(self.layers)
|
||||||
for i, indices in enumerate(codes.split(codes.shape[1], axis=1)):
|
for i, indices in enumerate(codes.split(codes.shape[1], axis=1)):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i % total_layers]
|
||||||
quantized = layer.decode(indices.squeeze(1))
|
quantized = layer.decode(indices.squeeze(1))
|
||||||
if quantized_out is None:
|
if quantized_out is None:
|
||||||
quantized_out = quantized
|
quantized_out = quantized
|
||||||
|
@ -6,10 +6,20 @@ import numpy as np
|
|||||||
|
|
||||||
def save_audio(file: str, audio: mx.array, sampling_rate: int):
|
def save_audio(file: str, audio: mx.array, sampling_rate: int):
|
||||||
"""
|
"""
|
||||||
Save audio to a wave (.wav) file.
|
Save audio to a wave (.wav) file, supporting both mono and stereo.
|
||||||
"""
|
"""
|
||||||
from scipy.io.wavfile import write
|
from scipy.io.wavfile import write
|
||||||
|
|
||||||
|
# Clip and scale audio
|
||||||
audio = mx.clip(audio, -1, 1)
|
audio = mx.clip(audio, -1, 1)
|
||||||
audio = (audio * 32767).astype(mx.int16)
|
audio = (audio * 32767).astype(mx.int16)
|
||||||
write(file, sampling_rate, np.array(audio))
|
|
||||||
|
# Convert to numpy array
|
||||||
|
audio_np = np.array(audio)
|
||||||
|
|
||||||
|
# Handle stereo by reshaping interleaved audio
|
||||||
|
if audio_np.shape[1] == 1: # Single column
|
||||||
|
# Reshape to (samples, 2) for stereo
|
||||||
|
audio_np = audio_np.reshape(-1, 1).repeat(2, axis=1)
|
||||||
|
|
||||||
|
write(file, sampling_rate, audio_np)
|
||||||
|
Loading…
Reference in New Issue
Block a user