mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-03 06:54:33 +08:00
Encodec (#991)
* initial encodec * works * nits * use fast group norm * fix for rnn layer * fix mlx version * use custom LSTM kernel * audio encodec * fix example, support batched inference * nits
This commit is contained in:
37
encodec/example.py
Normal file
37
encodec/example.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
from utils import load, load_audio, save_audio
|
||||
|
||||
# Fetch the 48 kHz model together with its matching preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")

# Read the input clip, handing the loader the sampling rate and channel
# count reported by the model.
audio = load_audio(
    "/path/to/audio",
    model.sampling_rate,
    model.channels,
)

# Run the preprocessor on the clip (a list of arrays may be passed
# instead for batched processing).
feats, mask = processor(audio)
|
||||
|
||||
|
||||
# Encode at the given bandwidth. A lower bandwidth results in more
|
||||
# compression but lower reconstruction quality.
|
||||
@mx.compile
def encode(feats, mask):
    """Encode preprocessed features with the module-level ``model``.

    The bandwidth is fixed at 3; a lower bandwidth gives more compression
    at the cost of reconstruction quality.

    Args:
        feats: Features produced by the preprocessor.
        mask: Mask produced alongside ``feats`` by the preprocessor.

    Returns:
        Whatever ``model.encode`` returns; this script unpacks it as
        ``(codes, scales)``.
    """
    encoded = model.encode(feats, mask, bandwidth=3)
    return encoded
|
||||
|
||||
|
||||
# Decode to reconstruct the audio
|
||||
@mx.compile
def decode(codes, scales, mask):
    """Reconstruct audio from an encoding using the module-level ``model``.

    Args:
        codes: Codes returned by ``encode``.
        scales: Scales returned by ``encode``.
        mask: The mask that was used when encoding.

    Returns:
        Whatever ``model.decode`` returns for the given inputs.
    """
    decoded = model.decode(codes, scales, mask)
    return decoded
|
||||
|
||||
|
||||
# Round-trip the features: encode to codes and scales, then decode them.
codes, scales = encode(feats, mask)
reconstructed = decode(codes, scales, mask)

# Trim any padding: take index 0 along the leading axis and keep only the
# first len(audio) entries of the next one.
reconstructed = reconstructed[0, : len(audio)]

# Write the reconstruction out as a WAV file at the model's sampling rate.
save_audio(
    "reconstructed.wav",
    reconstructed,
    model.sampling_rate,
)
|
Reference in New Issue
Block a user