Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00).
Encodec (#991)

* initial encodec
* works
* nits
* use fast group norm
* fix for rnn layer
* fix mlx version
* use custom LSTM kernel
* audio encodec
* fix example, support batched inference
* nits

This commit is contained in:
parent 796d5e40e4
commit 9bb2dd62f3
@@ -27,6 +27,7 @@ Some more useful examples are listed below.

 ### Audio Models

 - Speech recognition with [OpenAI's Whisper](whisper).
+- Audio compression and generation with [Meta's EnCodec](encodec).

 ### Multimodal models
83 encodec/README.md Normal file
@@ -0,0 +1,83 @@
# EnCodec

An example of Meta's EnCodec model in MLX.[^1] EnCodec is used to compress and
generate audio.

### Setup

Install the requirements:

```
pip install -r requirements.txt
```

Optionally install FFmpeg and SciPy for loading and saving audio files,
respectively.

Install [FFmpeg](https://ffmpeg.org/):

```
# on macOS using Homebrew (https://brew.sh/)
brew install ffmpeg
```

Install SciPy:

```
pip install scipy
```

### Example

An example using the model:

```python
import mlx.core as mx
from utils import load, load_audio, save_audio

# Load the 48 KHz model and preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")

# Load an audio file
audio = load_audio("path/to/audio", model.sampling_rate, model.channels)

# Preprocess the audio (this can also be a list of arrays for batched
# processing).
feats, mask = processor(audio)


# Encode at the given bandwidth. A lower bandwidth results in more
# compression but lower reconstruction quality.
@mx.compile
def encode(feats, mask):
    return model.encode(feats, mask, bandwidth=3)


# Decode to reconstruct the audio
@mx.compile
def decode(codes, scales, mask):
    return model.decode(codes, scales, mask)


codes, scales = encode(feats, mask)
reconstructed = decode(codes, scales, mask)

# Trim any padding:
reconstructed = reconstructed[0, : len(audio)]

# Save the audio as a wave file
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
```

The 24 KHz, 32 KHz, and 48 KHz MLX formatted models are available in the
[Hugging Face MLX Community](https://huggingface.co/collections/mlx-community/encodec-66e62334038300b07a43b164)
in several data types.

### Optional

To convert models, use the `convert.py` script. To see the options, run:

```bash
python convert.py -h
```
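For instance, a conversion of the 48 kHz model to bfloat16 might look like the
following (a sketch; `--model` and `--dtype` are the flags defined in
`convert.py` below, and `--upload` is only needed to push the result to the
Hub):

```bash
python convert.py --model 48khz --dtype bfloat16
```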
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2210.13438) and
[code](https://github.com/facebookresearch/encodec) for more details.
30 encodec/benchmarks/bench_mx.py Normal file
@@ -0,0 +1,30 @@
# Copyright © 2024 Apple Inc.

import time

import mlx.core as mx
from utils import load

model, processor = load("mlx-community/encodec-48khz-float32")

audio = mx.random.uniform(shape=(288000, 2))
feats, mask = processor(audio)
mx.eval(model, feats, mask)


@mx.compile
def fun():
    codes, scales = model.encode(feats, mask, bandwidth=3)
    reconstructed = model.decode(codes, scales, mask)
    return reconstructed


for _ in range(5):
    mx.eval(fun())

tic = time.time()
for _ in range(10):
    mx.eval(fun())
toc = time.time()
ms = 1000 * (toc - tic) / 10
print(f"Time per it: {ms:.3f} ms")
34 encodec/benchmarks/bench_pt.py Normal file
@@ -0,0 +1,34 @@
# Copyright © 2024 Apple Inc.

import time

import numpy as np
import torch
from transformers import AutoProcessor, EncodecModel

processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
audio = np.random.uniform(size=(2, 288000)).astype(np.float32)

pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz").to("mps")
pt_inputs = processor(
    raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
).to("mps")


def fun():
    pt_encoded = pt_model.encode(pt_inputs["input_values"], pt_inputs["padding_mask"])
    pt_audio = pt_model.decode(
        pt_encoded.audio_codes, pt_encoded.audio_scales, pt_inputs["padding_mask"]
    )
    torch.mps.synchronize()


for _ in range(5):
    fun()

tic = time.time()
for _ in range(10):
    fun()
toc = time.time()
ms = 1000 * (toc - tic) / 10
print(f"Time per it: {ms:.3f} ms")
213 encodec/convert.py Normal file
@@ -0,0 +1,213 @@
# Copyright © 2024 Apple Inc.

import argparse
import json
from pathlib import Path
from textwrap import dedent
from types import SimpleNamespace
from typing import Any, Dict, Union

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten

import encodec


def fetch_from_hub(hf_repo: str) -> Path:
    model_path = Path(
        snapshot_download(
            repo_id=hf_repo,
            allow_patterns=["*.json", "*.safetensors"],
        )
    )
    return model_path


def upload_to_hub(path: str, upload_repo: str, hf_path: str):
    """
    Uploads the model to Hugging Face hub.

    Args:
        path (str): Local path to the model.
        upload_repo (str): Name of the HF repo to upload to.
        hf_path (str): Path to the original Hugging Face model.
    """
    import os

    from huggingface_hub import HfApi, ModelCard, logging

    content = dedent(
        f"""
        ---
        language: en
        license: other
        library: mlx
        tags:
        - mlx
        ---

        The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
        converted to MLX format from
        [{hf_path}](https://huggingface.co/{hf_path}).

        This model is intended to be used with the [EnCodec MLX
        example](https://github.com/ml-explore/mlx-examples/tree/main/encodec).
        """
    )

    card = ModelCard(content)
    card.save(os.path.join(path, "README.md"))

    logging.set_verbosity_info()

    api = HfApi()
    api.create_repo(repo_id=upload_repo, exist_ok=True)
    api.upload_folder(
        folder_path=path,
        repo_id=upload_repo,
        repo_type="model",
        multi_commits=True,
        multi_commits_verbose=True,
    )
    print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")


def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
    if isinstance(save_path, str):
        save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    total_size = sum(v.nbytes for v in weights.values())
    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
    mx.save_safetensors(
        str(save_path / "model.safetensors"), weights, metadata={"format": "mlx"}
    )

    for weight_name in weights.keys():
        index_data["weight_map"][weight_name] = "model.safetensors"

    index_data["weight_map"] = {
        k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
    }

    with open(save_path / "model.safetensors.index.json", "w") as f:
        json.dump(index_data, f, indent=4)


def save_config(
    config: dict,
    config_path: Union[str, Path],
) -> None:
    """Save the model configuration to the ``config_path``.

    The final configuration will be sorted before saving for better readability.

    Args:
        config (dict): The model configuration.
        config_path (Union[str, Path]): Model configuration file path.
    """
    # Clean unused keys
    config.pop("_name_or_path", None)

    # sort the config for better readability
    config = dict(sorted(config.items()))

    # write the updated config to the config_path (if provided)
    with open(config_path, "w") as fid:
        json.dump(config, fid, indent=4)


def convert(
    upload: bool,
    model: str,
    dtype: str = None,
):
    hf_repo = f"facebook/encodec_{model}"
    mlx_repo = f"mlx-community/encodec-{model}-{dtype}"
    path = fetch_from_hub(hf_repo)
    save_path = Path("mlx_models")

    weights = mx.load(str(Path(path) / "model.safetensors"))

    with open(path / "config.json", "r") as fid:
        config = SimpleNamespace(**json.load(fid))

    model = encodec.EncodecModel(config)

    new_weights = {}
    for k, v in weights.items():
        basename, pname = k.rsplit(".", 1)
        if pname == "weight_v":
            g = weights[basename + ".weight_g"]
            v = g * (v / mx.linalg.norm(v, axis=(1, 2), keepdims=True))
            k = basename + ".weight"
        elif pname in ["weight_g", "embed_avg", "cluster_size", "inited"]:
            continue
        elif "lstm" in basename:
            w_or_b, ih_or_hh, ln = pname.split("_")
            if w_or_b == "weight":
                new_pname = "Wx" if ih_or_hh == "ih" else "Wh"
            elif w_or_b == "bias" and ih_or_hh == "ih":
                continue
            else:
                v = v + weights[k.replace("_hh_", "_ih_")]
                new_pname = "bias"
            k = basename + "." + ln[1:] + "." + new_pname
        if "conv.weight" in k:
            # Possibly a transposed conv which has a different order
            if "decoder" in k:
                ln = int(k.split(".")[2])
                if "conv" in model.decoder.layers[ln] and isinstance(
                    model.decoder.layers[ln].conv, nn.ConvTranspose1d
                ):
                    v = mx.moveaxis(v, 0, 2)
                else:
                    v = mx.moveaxis(v, 1, 2)
            else:
                v = mx.moveaxis(v, 1, 2)

        new_weights[k] = v
    weights = new_weights

    model.load_weights(list(weights.items()))

    if dtype is not None:
        t = getattr(mx, dtype)
        weights = {k: v.astype(t) for k, v in weights.items()}

    if isinstance(save_path, str):
        save_path = Path(save_path)

    save_weights(save_path, weights)

    save_config(vars(config), config_path=save_path / "config.json")

    if upload:
        upload_to_hub(save_path, mlx_repo, hf_repo)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert EnCodec weights to MLX.")
    parser.add_argument(
        "--model",
        type=str,
        default="48khz",
        help="The EnCodec model to convert.",
        choices=["24khz", "32khz", "48khz"],
    )
    parser.add_argument(
        "--upload",
        action="store_true",
        help="Upload the weights to Hugging Face.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        help="Data type to convert the model to.",
        default="float32",
        choices=["float32", "bfloat16", "float16"],
    )
    args = parser.parse_args()
    convert(upload=args.upload, model=args.model, dtype=args.dtype)
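One detail worth calling out in `convert()` above: PyTorch weight
normalization stores each convolution weight as a direction (`weight_v`) and a
per-channel magnitude (`weight_g`), and the loop folds the two into a single
`weight`. A quick standalone check of that identity (a sketch with made-up
shapes, not part of the script):

```python
import mlx.core as mx

# Hypothetical MLX Conv1d weight layout: (out_channels, kernel_size, in_channels)
v = mx.random.normal((8, 3, 4))          # direction (weight_v)
g = mx.random.uniform(shape=(8, 1, 1))   # magnitude (weight_g), positive here

w = g * (v / mx.linalg.norm(v, axis=(1, 2), keepdims=True))

# Each output channel of the folded weight has norm equal to its g entry
# (exactly, since g was chosen positive).
print(mx.allclose(mx.linalg.norm(w, axis=(1, 2), keepdims=True), g, atol=1e-5))
```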
671 encodec/encodec.py Normal file
@@ -0,0 +1,671 @@
# Copyright © 2024 Apple Inc.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np

_lstm_kernel = mx.fast.metal_kernel(
    name="lstm",
    input_names=["x", "h_in", "cell", "hidden_size", "time_step", "num_time_steps"],
    output_names=["hidden_state", "cell_state"],
    header="""
    template <typename T>
    T sigmoid(T x) {
        auto y = 1 / (1 + metal::exp(-metal::abs(x)));
        return (x < 0) ? 1 - y : y;
    }
    """,
    source="""
        uint b = thread_position_in_grid.x;
        uint d = hidden_size * 4;

        uint elem = b * d + thread_position_in_grid.y;
        uint index = elem;
        uint x_index = b * num_time_steps * d + time_step * d + index;

        auto i = sigmoid(h_in[index] + x[x_index]);
        index += hidden_size;
        x_index += hidden_size;
        auto f = sigmoid(h_in[index] + x[x_index]);
        index += hidden_size;
        x_index += hidden_size;
        auto g = metal::precise::tanh(h_in[index] + x[x_index]);
        index += hidden_size;
        x_index += hidden_size;
        auto o = sigmoid(h_in[index] + x[x_index]);

        cell_state[elem] = f * cell[elem] + i * g;
        hidden_state[elem] = o * metal::precise::tanh(cell_state[elem]);
    """,
)


def lstm_custom(x, h_in, cell, time_step):
    assert x.ndim == 3, "Input to LSTM must have 3 dimensions."
    out_shape = cell.shape
    return _lstm_kernel(
        inputs=[x, h_in, cell, out_shape[-1], time_step, x.shape[-2]],
        output_shapes=[out_shape, out_shape],
        output_dtypes=[h_in.dtype, h_in.dtype],
        grid=(x.shape[0], h_in.size // 4, 1),
        threadgroup=(256, 1, 1),
    )
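For readers who would rather not parse Metal: the kernel fuses one LSTM time
step, reading the four gate pre-activations (input, forget, candidate, output,
laid out contiguously along the last axis) and updating the cell and hidden
state. A rough pure-MLX equivalent of a single step (an illustrative sketch,
not used by the model):

```python
import mlx.core as mx

def lstm_step_reference(gates, cell):
    # gates: (B, 4 * hidden_size), the sum of the input and recurrent
    # projections (x_t @ Wx.T + h_{t-1} @ Wh.T + bias), in i, f, g, o order.
    i, f, g, o = mx.split(gates, 4, axis=-1)
    i, f, o = mx.sigmoid(i), mx.sigmoid(f), mx.sigmoid(o)
    g = mx.tanh(g)
    cell = f * cell + i * g
    hidden = o * mx.tanh(cell)
    return hidden, cell
```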

class LSTM(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.Wx = mx.zeros((4 * hidden_size, input_size))
        self.Wh = mx.zeros((4 * hidden_size, hidden_size))
        self.bias = mx.zeros((4 * hidden_size,)) if bias else None

    def __call__(self, x, hidden=None, cell=None):
        if self.bias is not None:
            x = mx.addmm(self.bias, x, self.Wx.T)
        else:
            x = x @ self.Wx.T

        all_hidden = []

        B = x.shape[0]
        cell = cell or mx.zeros((B, self.hidden_size), x.dtype)
        for t in range(x.shape[-2]):
            if hidden is None:
                hidden = mx.zeros((B, self.hidden_size * 4), x.dtype)
            else:
                hidden = hidden @ self.Wh.T
            hidden, cell = lstm_custom(x, hidden, cell, t)
            all_hidden.append(hidden)

        return mx.stack(all_hidden, axis=-2)

class EncodecConv1d(nn.Module):
    """Conv1d with asymmetric or causal padding and normalization."""

    def __init__(
        self,
        config,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
    ):
        super().__init__()
        self.causal = config.use_causal_conv
        self.pad_mode = config.pad_mode
        self.norm_type = config.norm_type

        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size, stride, dilation=dilation
        )
        if self.norm_type == "time_group_norm":
            self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)

        self.stride = stride

        # Effective kernel size with dilations.
        self.kernel_size = (kernel_size - 1) * dilation + 1

        self.padding_total = kernel_size - stride

    def _get_extra_padding_for_conv1d(
        self,
        hidden_states: mx.array,
    ) -> mx.array:
        length = hidden_states.shape[1]
        n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
        n_frames = int(math.ceil(n_frames)) - 1
        ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
        return ideal_length - length

    def _pad1d(
        self,
        hidden_states: mx.array,
        paddings: Tuple[int, int],
        mode: str = "zero",
        value: float = 0.0,
    ):
        if mode != "reflect":
            return mx.pad(
                hidden_states, paddings, mode="constant", constant_values=value
            )

        length = hidden_states.shape[1]
        prefix = hidden_states[:, 1 : paddings[0] + 1][:, ::-1]
        suffix = hidden_states[:, max(length - (paddings[1] + 1), 0) : -1][:, ::-1]
        return mx.concatenate([prefix, hidden_states, suffix], axis=1)

    def __call__(self, hidden_states):
        extra_padding = self._get_extra_padding_for_conv1d(hidden_states)

        if self.causal:
            # Left padding for causal
            hidden_states = self._pad1d(
                hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode
            )
        else:
            # Asymmetric padding required for odd strides
            padding_right = self.padding_total // 2
            padding_left = self.padding_total - padding_right
            hidden_states = self._pad1d(
                hidden_states,
                (padding_left, padding_right + extra_padding),
                mode=self.pad_mode,
            )

        hidden_states = self.conv(hidden_states)

        if self.norm_type == "time_group_norm":
            hidden_states = self.norm(hidden_states)

        return hidden_states
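The `_get_extra_padding_for_conv1d` arithmetic pads the input just enough that
the final frame is complete. A worked example (illustrative numbers, not from
the source):

```python
import math

# Assume kernel_size=4, stride=2, so padding_total = 4 - 2 = 2.
length, kernel_size, stride, padding_total = 101, 4, 2, 2
n_frames = (length - kernel_size + padding_total) / stride + 1   # 50.5
n_frames = int(math.ceil(n_frames)) - 1                          # 50
ideal_length = n_frames * stride + kernel_size - padding_total   # 102
print(ideal_length - length)  # 1: one extra sample makes the last frame whole
```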

class EncodecConvTranspose1d(nn.Module):
    """ConvTranspose1d with asymmetric or causal padding and normalization."""

    def __init__(
        self,
        config,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
    ):
        super().__init__()
        self.causal = config.use_causal_conv
        self.trim_right_ratio = config.trim_right_ratio
        self.norm_type = config.norm_type
        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
        if config.norm_type == "time_group_norm":
            self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
        self.padding_total = kernel_size - stride

    def __call__(self, hidden_states):
        hidden_states = self.conv(hidden_states)

        if self.norm_type == "time_group_norm":
            hidden_states = self.norm(hidden_states)

        if self.causal:
            padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
        else:
            padding_right = self.padding_total // 2

        padding_left = self.padding_total - padding_right

        end = hidden_states.shape[1] - padding_right
        hidden_states = hidden_states[:, padding_left:end, :]
        return hidden_states


class EncodecLSTM(nn.Module):
    def __init__(self, config, dimension):
        super().__init__()
        self.lstm = [LSTM(dimension, dimension) for _ in range(config.num_lstm_layers)]

    def __call__(self, hidden_states):
        h = hidden_states
        for lstm in self.lstm:
            h = lstm(h)
        return h + hidden_states


class EncodecResnetBlock(nn.Module):
    """
    Residual block from SEANet model as used by EnCodec.
    """

    def __init__(self, config, dim: int, dilations: List[int]):
        super().__init__()
        kernel_sizes = (config.residual_kernel_size, 1)
        if len(kernel_sizes) != len(dilations):
            raise ValueError("Number of kernel sizes should match number of dilations")

        hidden = dim // config.compress
        block = []
        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            in_chs = dim if i == 0 else hidden
            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
            block += [nn.ELU()]
            block += [
                EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)
            ]
        self.block = block

        if getattr(config, "use_conv_shortcut", True):
            self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

    def __call__(self, hidden_states):
        residual = hidden_states
        for layer in self.block:
            hidden_states = layer(hidden_states)

        return self.shortcut(residual) + hidden_states


class EncodecEncoder(nn.Module):
    """SEANet encoder as used by EnCodec."""

    def __init__(self, config):
        super().__init__()
        model = [
            EncodecConv1d(
                config, config.audio_channels, config.num_filters, config.kernel_size
            )
        ]
        scaling = 1

        for ratio in reversed(config.upsampling_ratios):
            current_scale = scaling * config.num_filters
            for j in range(config.num_residual_layers):
                model += [
                    EncodecResnetBlock(
                        config, current_scale, [config.dilation_growth_rate**j, 1]
                    )
                ]
            model += [nn.ELU()]
            model += [
                EncodecConv1d(
                    config,
                    current_scale,
                    current_scale * 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                )
            ]
            scaling *= 2

        model += [EncodecLSTM(config, scaling * config.num_filters)]
        model += [nn.ELU()]
        model += [
            EncodecConv1d(
                config,
                scaling * config.num_filters,
                config.hidden_size,
                config.last_kernel_size,
            )
        ]

        self.layers = model

    def __call__(self, hidden_states):
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states
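Because each strided convolution in the encoder downsamples by its `ratio`,
the total hop length is the product of `upsampling_ratios`. For instance, with
the ratios `[8, 5, 4, 2]` used by the published EnCodec configs (an
assumption worth checking against the model's `config.json`):

```python
import numpy as np

ratios = [8, 5, 4, 2]          # assumed upsampling_ratios
hop_length = np.prod(ratios)   # 320 input samples per latent frame
print(48000 / hop_length)      # 150.0 latent frames per second at 48 kHz
```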

class EncodecDecoder(nn.Module):
    """SEANet decoder as used by EnCodec."""

    def __init__(self, config):
        super().__init__()
        scaling = int(2 ** len(config.upsampling_ratios))
        model = [
            EncodecConv1d(
                config,
                config.hidden_size,
                scaling * config.num_filters,
                config.kernel_size,
            )
        ]

        model += [EncodecLSTM(config, scaling * config.num_filters)]

        for ratio in config.upsampling_ratios:
            current_scale = scaling * config.num_filters
            model += [nn.ELU()]
            model += [
                EncodecConvTranspose1d(
                    config,
                    current_scale,
                    current_scale // 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                )
            ]
            for j in range(config.num_residual_layers):
                model += [
                    EncodecResnetBlock(
                        config, current_scale // 2, (config.dilation_growth_rate**j, 1)
                    )
                ]
            scaling //= 2

        model += [nn.ELU()]
        model += [
            EncodecConv1d(
                config,
                config.num_filters,
                config.audio_channels,
                config.last_kernel_size,
            )
        ]
        self.layers = model

    def __call__(self, hidden_states):
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states


class EncodecEuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance."""

    def __init__(self, config):
        super().__init__()
        self.embed = mx.zeros((config.codebook_size, config.codebook_dim))

    def quantize(self, hidden_states):
        embed = self.embed.T
        scaled_states = hidden_states.square().sum(axis=1, keepdims=True)
        dist = -(
            scaled_states
            - 2 * hidden_states @ embed
            + embed.square().sum(axis=0, keepdims=True)
        )
        embed_ind = dist.argmax(axis=-1)
        return embed_ind

    def encode(self, hidden_states):
        shape = hidden_states.shape
        hidden_states = hidden_states.reshape((-1, shape[-1]))
        embed_ind = self.quantize(hidden_states)
        embed_ind = embed_ind.reshape(*shape[:-1])
        return embed_ind

    def decode(self, embed_ind):
        return self.embed[embed_ind]
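`quantize` avoids materializing pairwise differences by using the expansion
||h - e||^2 = ||h||^2 - 2 h.e + ||e||^2, then takes the argmax of the negated
distance to pick the nearest codebook entry. A small equivalence check (a
sketch with arbitrary sizes):

```python
import mlx.core as mx

h = mx.random.normal((5, 8))    # 5 query vectors
e = mx.random.normal((16, 8))   # a 16-entry codebook

expanded = (
    h.square().sum(axis=1, keepdims=True)
    - 2 * h @ e.T
    + e.square().sum(axis=1, keepdims=True).T
)
direct = ((h[:, None, :] - e[None, :, :]) ** 2).sum(axis=-1)
print(mx.allclose(expanded, direct, atol=1e-4))  # True
```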

class EncodecVectorQuantization(nn.Module):
    """
    Vector quantization implementation. Currently supports only euclidean distance.
    """

    def __init__(self, config):
        super().__init__()
        self.codebook = EncodecEuclideanCodebook(config)

    def encode(self, hidden_states):
        return self.codebook.encode(hidden_states)

    def decode(self, embed_ind):
        return self.codebook.decode(embed_ind)


class EncodecResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer."""

    def __init__(self, config):
        super().__init__()
        self.codebook_size = config.codebook_size

        hop_length = np.prod(config.upsampling_ratios)
        self.frame_rate = math.ceil(config.sampling_rate / hop_length)
        self.num_quantizers = int(
            1000 * config.target_bandwidths[-1] // (self.frame_rate * 10)
        )
        self.layers = [
            EncodecVectorQuantization(config) for _ in range(self.num_quantizers)
        ]

    def get_num_quantizers_for_bandwidth(
        self, bandwidth: Optional[float] = None
    ) -> int:
        """Return num_quantizers based on specified target bandwidth."""
        bw_per_q = math.log2(self.codebook_size) * self.frame_rate
        num_quantizers = self.num_quantizers
        if bandwidth is not None and bandwidth > 0.0:
            num_quantizers = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
        return num_quantizers

    def encode(
        self, embeddings: mx.array, bandwidth: Optional[float] = None
    ) -> mx.array:
        """
        Encode a given input array with the specified frame rate at the given
        bandwidth. The RVQ encode method sets the appropriate number of
        quantizers to use and returns indices for each quantizer.
        """
        num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth)
        residual = embeddings
        all_indices = []
        for layer in self.layers[:num_quantizers]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = mx.stack(all_indices, axis=1)
        return out_indices

    def decode(self, codes: mx.array) -> mx.array:
        """Decode the given codes to the quantized representation."""
        quantized_out = None
        for i, indices in enumerate(codes.split(codes.shape[1], axis=1)):
            layer = self.layers[i]
            quantized = layer.decode(indices.squeeze(1))
            if quantized_out is None:
                quantized_out = quantized
            else:
                quantized_out = quantized + quantized_out
        return quantized_out
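Each codebook contributes log2(codebook_size) bits per latent frame, so the
bandwidth budget maps directly to a quantizer count. Working the numbers for
the 48 kHz model (assuming a codebook size of 1024 and a 150 Hz frame rate,
per the hop-length note above):

```python
import math

codebook_size, frame_rate = 1024, 150             # assumed 48 kHz config
bw_per_q = math.log2(codebook_size) * frame_rate  # 1500 bits/s per codebook
for bandwidth in (3.0, 6.0, 12.0):                # kbps
    n = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
    print(bandwidth, "kbps ->", n, "quantizers")  # 2, 4, 8
```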

class EncodecModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.encoder = EncodecEncoder(config)
        self.decoder = EncodecDecoder(config)
        self.quantizer = EncodecResidualVectorQuantizer(config)

    def _encode_frame(
        self, input_values: mx.array, bandwidth: float, padding_mask: mx.array
    ) -> Tuple[mx.array, Optional[mx.array]]:
        """
        Encodes the given input using the underlying VQVAE.
        """
        length = input_values.shape[1]
        duration = length / self.config.sampling_rate

        if (
            self.config.chunk_length_s is not None
            and duration > 1e-5 + self.config.chunk_length_s
        ):
            raise RuntimeError(
                f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}"
            )

        scale = None
        if self.config.normalize:
            # if the padding is non zero
            input_values = input_values * padding_mask[..., None]
            mono = mx.sum(input_values, axis=2, keepdims=True) / input_values.shape[2]
            scale = mono.square().mean(axis=1, keepdims=True).sqrt() + 1e-8
            input_values = input_values / scale

        embeddings = self.encoder(input_values)
        codes = self.quantizer.encode(embeddings, bandwidth)
        return codes, scale

    def encode(
        self,
        input_values: mx.array,
        padding_mask: mx.array = None,
        bandwidth: Optional[float] = None,
    ) -> Tuple[mx.array, Optional[mx.array]]:
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (mx.array): The input audio waveform with shape
                ``(batch_size, sequence_length, channels)``.
            padding_mask (mx.array): Padding mask used to pad the ``input_values``.
            bandwidth (float, optional): The target bandwidth in kbps. Must be
                one of ``config.target_bandwidths``. If ``None``, uses the
                smallest possible bandwidth. For example, a 6 kbps (6000 bit/s)
                bandwidth is given as ``bandwidth == 6.0``.

        Returns:
            A list of frames containing the discrete encoded codes for the
            input audio waveform, along with rescaling factors for each chunk
            when ``config.normalize==True``. Each frame is a tuple ``(codebook,
            scale)``, with ``codebook`` of shape ``(batch_size, num_codebooks,
            frames)``.
        """

        if bandwidth is None:
            bandwidth = self.config.target_bandwidths[0]
        if bandwidth not in self.config.target_bandwidths:
            raise ValueError(
                f"This model doesn't support the bandwidth {bandwidth}. "
                f"Select one of {self.config.target_bandwidths}."
            )

        _, input_length, channels = input_values.shape

        if channels < 1 or channels > 2:
            raise ValueError(
                f"Number of audio channels must be 1 or 2, but got {channels}"
            )

        chunk_length = self.chunk_length
        if chunk_length is None:
            chunk_length = input_length
            stride = input_length
        else:
            stride = self.chunk_stride

        if padding_mask is None:
            padding_mask = mx.ones(input_values.shape[:2], dtype=mx.bool_)
        encoded_frames = []
        scales = []

        step = chunk_length - stride
        if (input_length % stride) != step:
            raise ValueError(
                "The input length is not properly padded for batched chunked "
                "encoding. Make sure to pad the input correctly."
            )

        for offset in range(0, input_length - step, stride):
            mask = padding_mask[:, offset : offset + chunk_length].astype(mx.bool_)
            frame = input_values[:, offset : offset + chunk_length]
            encoded_frame, scale = self._encode_frame(frame, bandwidth, mask)
            encoded_frames.append(encoded_frame)
            scales.append(scale)

        encoded_frames = mx.stack(encoded_frames)

        return (encoded_frames, scales)
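To make the chunking concrete: the 48 kHz config uses one-second chunks with
1% overlap (`chunk_length_s = 1.0`, `overlap = 0.01`, an assumption based on
the upstream model config), so the loop above advances in strides of 47,520
samples. A sketch using the same padded length as `test.py` below:

```python
chunk_length = int(1.0 * 48000)                    # 48000 samples per chunk
stride = max(1, int((1.0 - 0.01) * chunk_length))  # 47520
step = chunk_length - stride                       # 480
input_length = 190560                              # satisfies input_length % stride == step
print(list(range(0, input_length - step, stride)))  # [0, 47520, 95040, 142560]
```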
    @staticmethod
    def _linear_overlap_add(frames: List[mx.array], stride: int):
        if len(frames) == 0:
            raise ValueError("`frames` cannot be an empty list.")

        dtype = frames[0].dtype
        N, frame_length, C = frames[0].shape
        total_size = stride * (len(frames) - 1) + frames[-1].shape[1]

        time_vec = mx.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
        weight = 0.5 - (time_vec - 0.5).abs()

        weight = weight[:, None]
        sum_weight = mx.zeros((total_size, 1), dtype=dtype)
        out = mx.zeros((N, total_size, C), dtype=dtype)
        offset = 0

        for frame in frames:
            frame_length = frame.shape[1]
            out[:, offset : offset + frame_length] += weight[:frame_length] * frame
            sum_weight[offset : offset + frame_length] += weight[:frame_length]
            offset += stride

        return out / sum_weight
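The overlap-add weight is a triangle that peaks at mid-frame, so consecutive
chunks cross-fade linearly, and the division by `sum_weight` renormalizes
wherever frames overlap. A quick look at its shape (illustrative only):

```python
import mlx.core as mx

frame_length = 6
t = mx.linspace(0, 1, frame_length + 2)[1:-1]
print(0.5 - mx.abs(t - 0.5))  # ramps up to ~0.5 at mid-frame, then back down
```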
    def _decode_frame(
        self, codes: mx.array, scale: Optional[mx.array] = None
    ) -> mx.array:
        embeddings = self.quantizer.decode(codes)
        outputs = self.decoder(embeddings)
        if scale is not None:
            outputs = outputs * scale
        return outputs

    @property
    def channels(self):
        return self.config.audio_channels

    @property
    def sampling_rate(self):
        return self.config.sampling_rate

    @property
    def chunk_length(self):
        if self.config.chunk_length_s is None:
            return None
        else:
            return int(self.config.chunk_length_s * self.config.sampling_rate)

    @property
    def chunk_stride(self):
        if self.config.chunk_length_s is None or self.config.overlap is None:
            return None
        else:
            return max(1, int((1.0 - self.config.overlap) * self.chunk_length))

    def decode(
        self,
        audio_codes: mx.array,
        audio_scales: Union[mx.array, List[mx.array]],
        padding_mask: Optional[mx.array] = None,
    ) -> Tuple[mx.array, mx.array]:
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that
        case, any extra steps at the end should be trimmed.

        Args:
            audio_codes (mx.array): Discrete code embeddings of shape
                ``(batch_size, nb_chunks, chunk_length)``.
            audio_scales (mx.array): Scaling factor for each input.
            padding_mask (mx.array): Padding mask.
        """
        chunk_length = self.chunk_length
        if chunk_length is None:
            if audio_codes.shape[1] != 1:
                raise ValueError(f"Expected one frame, got {audio_codes.shape[1]}")
            audio_values = self._decode_frame(audio_codes[:, 0], audio_scales[0])
        else:
            decoded_frames = []

            for frame, scale in zip(audio_codes, audio_scales):
                frames = self._decode_frame(frame, scale)
                decoded_frames.append(frames)

            audio_values = self._linear_overlap_add(
                decoded_frames, self.chunk_stride or 1
            )

        # truncate based on padding mask
        if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
            audio_values = audio_values[:, : padding_mask.shape[1]]
        return audio_values
37 encodec/example.py Normal file
@@ -0,0 +1,37 @@
# Copyright © 2024 Apple Inc.

import mlx.core as mx
from utils import load, load_audio, save_audio

# Load the 48 KHz model and preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")

# Load an audio file
audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)

# Preprocess the audio (this can also be a list of arrays for batched
# processing).
feats, mask = processor(audio)


# Encode at the given bandwidth. A lower bandwidth results in more
# compression but lower reconstruction quality.
@mx.compile
def encode(feats, mask):
    return model.encode(feats, mask, bandwidth=3)


# Decode to reconstruct the audio
@mx.compile
def decode(codes, scales, mask):
    return model.decode(codes, scales, mask)


codes, scales = encode(feats, mask)
reconstructed = decode(codes, scales, mask)

# Trim any padding:
reconstructed = reconstructed[0, : len(audio)]

# Save the audio as a wave file
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
3 encodec/requirements.txt Normal file
@@ -0,0 +1,3 @@
mlx>=0.18
numpy
huggingface_hub
66 encodec/test.py Normal file
@@ -0,0 +1,66 @@
# Copyright © 2024 Apple Inc.

import mlx.core as mx
import numpy as np
import torch
from datasets import Audio, load_dataset
from transformers import AutoProcessor, EncodecModel
from utils import load, load_audio, preprocess_audio


def compare_processors():
    np.random.seed(0)
    audio_length = 95500
    audio = np.random.uniform(size=(2, audio_length)).astype(np.float32)

    processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")

    pt_inputs = processor(
        raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
    )
    mx_inputs = preprocess_audio(
        mx.array(audio).T,
        processor.sampling_rate,
        processor.chunk_length,
        processor.chunk_stride,
    )

    assert np.array_equal(pt_inputs["input_values"], mx_inputs[0].moveaxis(2, 1))
    assert np.array_equal(pt_inputs["padding_mask"], mx_inputs[1])


def compare_models():
    pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz")
    mx_model, _ = load("mlx-community/encodec-48khz-float32")

    np.random.seed(0)
    audio_length = 190560
    audio = np.random.uniform(size=(1, audio_length, 2)).astype(np.float32)
    mask = np.ones((1, audio_length), dtype=np.int32)
    pt_encoded = pt_model.encode(
        torch.tensor(audio).moveaxis(2, 1), torch.tensor(mask)[None]
    )
    mx_encoded = mx_model.encode(mx.array(audio), mx.array(mask))
    pt_codes = pt_encoded.audio_codes.numpy()
    mx_codes = mx_encoded[0]
    assert np.array_equal(pt_codes, mx_codes), "Encoding codes mismatch"

    for mx_scale, pt_scale in zip(mx_encoded[1], pt_encoded.audio_scales):
        if mx_scale is not None:
            pt_scale = pt_scale.numpy()
            assert np.allclose(pt_scale, mx_scale, atol=1e-3, rtol=1e-4)

    pt_audio = pt_model.decode(
        pt_encoded.audio_codes, pt_encoded.audio_scales, torch.tensor(mask)[None]
    )
    pt_audio = pt_audio[0].squeeze().T.detach().numpy()
    mx_audio = mx_model.decode(*mx_encoded, mx.array(mask))
    mx_audio = mx_audio.squeeze()
    assert np.allclose(
        pt_audio, mx_audio, atol=1e-4, rtol=1e-4
    ), "Decoding audio mismatch"


if __name__ == "__main__":
    compare_processors()
    compare_models()
129 encodec/utils.py Normal file
@@ -0,0 +1,129 @@
# Copyright © 2024 Apple Inc.

import functools
import json
from pathlib import Path
from types import SimpleNamespace
from typing import List, Optional, Union

import mlx.core as mx
import numpy as np
from huggingface_hub import snapshot_download

import encodec


def save_audio(file: str, audio: mx.array, sampling_rate: int):
    """
    Save audio to a wave (.wav) file.
    """
    from scipy.io.wavfile import write

    audio = (audio * 32767).astype(mx.int16)
    write(file, sampling_rate, np.array(audio))


def load_audio(file: str, sampling_rate: int, channels: int):
    """
    Read audio into an mx.array, resampling if necessary.

    Args:
        file (str): The audio file to open.
        sampling_rate (int): The sample rate to resample the audio at if needed.
        channels (int): The number of audio channels.

    Returns:
        An mx.array containing the audio waveform in float32.
    """
    from subprocess import CalledProcessError, run

    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", str(channels),
        "-acodec", "pcm_s16le",
        "-ar", str(sampling_rate),
        "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    out = mx.array(np.frombuffer(out, np.int16))
    return out.reshape(-1, channels).astype(mx.float32) / 32767.0

def preprocess_audio(
    raw_audio: Union[mx.array, List[mx.array]],
    sampling_rate: int = 24000,
    chunk_length: Optional[int] = None,
    chunk_stride: Optional[int] = None,
):
    r"""
    Prepare inputs for the EnCodec model.

    Args:
        raw_audio (mx.array or List[mx.array]): The sequence or batch of
            sequences to be processed.
        sampling_rate (int): The sampling rate at which the audio waveform
            should be digitized.
        chunk_length (int, optional): The model's chunk length.
        chunk_stride (int, optional): The model's chunk stride.
    """
    if not isinstance(raw_audio, list):
        raw_audio = [raw_audio]

    raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]

    max_length = max(array.shape[0] for array in raw_audio)
    if chunk_length is not None:
        max_length += chunk_length - (max_length % chunk_stride)

    inputs = []
    masks = []
    for x in raw_audio:
        length = x.shape[0]
        mask = mx.ones((length,), dtype=mx.bool_)
        difference = max_length - length
        if difference > 0:
            mask = mx.pad(mask, (0, difference))
            x = mx.pad(x, ((0, difference), (0, 0)))
        inputs.append(x)
        masks.append(mask)
    return mx.stack(inputs), mx.stack(masks)
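For example (a sketch with made-up lengths and chunk parameters): two clips of
100 and 250 samples, with `chunk_length=128` and `chunk_stride=64`, are padded
to a shared chunk-aligned length of 320 (250 + 128 - 250 % 64) and returned
with boolean masks marking the real samples:

```python
import mlx.core as mx

a, b = mx.zeros((100, 1)), mx.zeros((250, 1))
feats, mask = preprocess_audio([a, b], chunk_length=128, chunk_stride=64)
print(feats.shape, mask.shape)  # (2, 320, 1) (2, 320)
```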

def load(path_or_repo):
    """
    Load the model and audio preprocessor.
    """
    path = Path(path_or_repo)
    if not path.exists():
        path = Path(
            snapshot_download(
                repo_id=path_or_repo,
                allow_patterns=["*.json", "*.safetensors", "*.model"],
            )
        )

    with open(path / "config.json", "r") as f:
        config = SimpleNamespace(**json.load(f))

    model = encodec.EncodecModel(config)
    model.load_weights(str(path / "model.safetensors"))
    processor = functools.partial(
        preprocess_audio,
        sampling_rate=config.sampling_rate,
        chunk_length=model.chunk_length,
        chunk_stride=model.chunk_stride,
    )
    mx.eval(model)
    return model, processor