* initial encodec

* works

* nits

* use fast group norm

* fix for rnn layer

* fix mlx version

* use custom LSTM kernel

* audio encodec

* fix example, support batched inference

* nits
This commit is contained in:
Awni Hannun 2024-09-23 11:39:25 -07:00 committed by GitHub
parent 796d5e40e4
commit 9bb2dd62f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 1267 additions and 0 deletions

View File

@ -27,6 +27,7 @@ Some more useful examples are listed below.
### Audio Models
- Speech recognition with [OpenAI's Whisper](whisper).
- Audio compression and generation with [Meta's EnCodec](encodec).
### Multimodal models

83
encodec/README.md Normal file
View File

@ -0,0 +1,83 @@
# EnCodec
An example of Meta's EnCodec model in MLX.[^1] EnCodec is used to compress and
generate audio.
### Setup
Install the requirements:
```
pip install -r requirements.txt
```
Optionally install FFmpeg and SciPy for loading and saving audio files,
respectively.
Install [FFmpeg](https://ffmpeg.org/):
```
# on macOS using Homebrew (https://brew.sh/)
brew install ffmpeg
```
Install SciPy:
```
pip install scipy
```
### Example
An example using the model:
```python
import mlx.core as mx
from utils import load, load_audio, save_audio
# Load the 48 KHz model and preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")
# Load an audio file
audio = load_audio("path/to/aduio", model.sampling_rate, model.channels)
# Preprocess the audio (this can also be a list of arrays for batched
# processing).
feats, mask = processor(audio)
# Encode at the given bandwidth. A lower bandwidth results in more
# compression but lower reconstruction quality.
@mx.compile
def encode(feats, mask):
return model.encode(feats, mask, bandwidth=3)
# Decode to reconstruct the audio
@mx.compile
def decode(codes, scales, mask):
return model.decode(codes, scales, mask)
codes, scales = encode(feats, mask)
reconstructed = decode(codes, scales, mask)
# Trim any padding:
reconstructed = reconstructed[0, : len(audio)]
# Save the audio as a wave file
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
```
The 24 KHz, 32 KHz, and 48 KHz MLX formatted models are available in the
[Hugging Face MLX Community](https://huggingface.co/collections/mlx-community/encodec-66e62334038300b07a43b164)
in several data types.
### Optional
To convert models, use the `convert.py` script. To see the options, run:
```bash
python convert.py -h
```
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2210.13438) and
[code](https://github.com/facebookresearch/encodec) for more details.

View File

@ -0,0 +1,30 @@
# Copyright © 2024 Apple Inc.
import time
import mlx.core as mx
from utils import load
model, processor = load("mlx-community/encodec-48khz-float32")
audio = mx.random.uniform(shape=(288000, 2))
feats, mask = processor(audio)
mx.eval(model, feats, mask)
@mx.compile
def fun():
codes, scales = model.encode(feats, mask, bandwidth=3)
reconstructed = model.decode(codes, scales, mask)
return reconstructed
for _ in range(5):
mx.eval(fun())
tic = time.time()
for _ in range(10):
mx.eval(fun())
toc = time.time()
ms = 1000 * (toc - tic) / 10
print(f"Time per it: {ms:.3f}")

View File

@ -0,0 +1,34 @@
# Copyright © 2024 Apple Inc.
import time
import numpy as np
import torch
from transformers import AutoProcessor, EncodecModel
processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
audio = np.random.uniform(size=(2, 288000)).astype(np.float32)
pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz").to("mps")
pt_inputs = processor(
raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
).to("mps")
def fun():
pt_encoded = pt_model.encode(pt_inputs["input_values"], pt_inputs["padding_mask"])
pt_audio = pt_model.decode(
pt_encoded.audio_codes, pt_encoded.audio_scales, pt_inputs["padding_mask"]
)
torch.mps.synchronize()
for _ in range(5):
fun()
tic = time.time()
for _ in range(10):
fun()
toc = time.time()
ms = 1000 * (toc - tic) / 10
print(f"Time per it: {ms:.3f}")

213
encodec/convert.py Normal file
View File

@ -0,0 +1,213 @@
# Copyright © 2024 Apple Inc.
import argparse
import json
from pathlib import Path
from textwrap import dedent
from types import SimpleNamespace
from typing import Any, Dict, Union
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
import encodec
def fetch_from_hub(hf_repo: str) -> Path:
model_path = Path(
snapshot_download(
repo_id=hf_repo,
allow_patterns=["*.json", "*.safetensors"],
)
)
return model_path
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
"""
Uploads the model to Hugging Face hub.
Args:
path (str): Local path to the model.
upload_repo (str): Name of the HF repo to upload to.
hf_path (str): Path to the original Hugging Face model.
"""
import os
from huggingface_hub import HfApi, ModelCard, logging
content = dedent(
f"""
---
language: en
license: other
library: mlx
tags:
- mlx
---
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
converted to MLX format from
[{hf_path}](https://huggingface.co/{hf_path}).
This model is intended to be used with the [EnCodec MLX
example](https://github.com/ml-explore/mlx-examples/tree/main/encodec).
"""
)
card = ModelCard(content)
card.save(os.path.join(path, "README.md"))
logging.set_verbosity_info()
api = HfApi()
api.create_repo(repo_id=upload_repo, exist_ok=True)
api.upload_folder(
folder_path=path,
repo_id=upload_repo,
repo_type="model",
multi_commits=True,
multi_commits_verbose=True,
)
print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
if isinstance(save_path, str):
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)
total_size = sum(v.nbytes for v in weights.values())
index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
mx.save_safetensors(
str(save_path / "model.safetensors"), weights, metadata={"format": "mlx"}
)
for weight_name in weights.keys():
index_data["weight_map"][weight_name] = "model.safetensors"
index_data["weight_map"] = {
k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
}
with open(save_path / "model.safetensors.index.json", "w") as f:
json.dump(index_data, f, indent=4)
def save_config(
config: dict,
config_path: Union[str, Path],
) -> None:
"""Save the model configuration to the ``config_path``.
The final configuration will be sorted before saving for better readability.
Args:
config (dict): The model configuration.
config_path (Union[str, Path]): Model configuration file path.
"""
# Clean unused keys
config.pop("_name_or_path", None)
# sort the config for better readability
config = dict(sorted(config.items()))
# write the updated config to the config_path (if provided)
with open(config_path, "w") as fid:
json.dump(config, fid, indent=4)
def convert(
upload: bool,
model: str,
dtype: str = None,
):
hf_repo = f"facebook/encodec_{model}"
mlx_repo = f"mlx-community/encodec-{model}-{dtype}"
path = fetch_from_hub(hf_repo)
save_path = Path("mlx_models")
weights = mx.load(str(Path(path) / "model.safetensors"))
with open(path / "config.json", "r") as fid:
config = SimpleNamespace(**json.load(fid))
model = encodec.EncodecModel(config)
new_weights = {}
for k, v in weights.items():
basename, pname = k.rsplit(".", 1)
if pname == "weight_v":
g = weights[basename + ".weight_g"]
v = g * (v / mx.linalg.norm(v, axis=(1, 2), keepdims=True))
k = basename + ".weight"
elif pname in ["weight_g", "embed_avg", "cluster_size", "inited"]:
continue
elif "lstm" in basename:
w_or_b, ih_or_hh, ln = pname.split("_")
if w_or_b == "weight":
new_pname = "Wx" if ih_or_hh == "ih" else "Wh"
elif w_or_b == "bias" and ih_or_hh == "ih":
continue
else:
v = v + weights[k.replace("_hh_", "_ih_")]
new_pname = "bias"
k = basename + "." + ln[1:] + "." + new_pname
if "conv.weight" in k:
# Possibly a transposed conv which has a different order
if "decoder" in k:
ln = int(k.split(".")[2])
if "conv" in model.decoder.layers[ln] and isinstance(
model.decoder.layers[ln].conv, nn.ConvTranspose1d
):
v = mx.moveaxis(v, 0, 2)
else:
v = mx.moveaxis(v, 1, 2)
else:
v = mx.moveaxis(v, 1, 2)
new_weights[k] = v
weights = new_weights
model.load_weights(list(weights.items()))
if dtype is not None:
t = getattr(mx, dtype)
weights = {k: v.astype(t) for k, v in weights.items()}
if isinstance(save_path, str):
save_path = Path(save_path)
save_weights(save_path, weights)
save_config(vars(config), config_path=save_path / "config.json")
if upload:
upload_to_hub(save_path, mlx_repo, hf_repo)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert EnCodec weights to MLX.")
parser.add_argument(
"--model",
type=str,
default="48khz",
help="",
choices=["24khz", "32khz", "48khz"],
)
parser.add_argument(
"--upload",
action="store_true",
help="Upload the weights to Hugging Face.",
)
parser.add_argument(
"--dtype",
type=str,
help="Data type to convert the model to.",
default="float32",
choices=["float32", "bfloat16", "float16"],
)
args = parser.parse_args()
convert(upload=args.upload, model=args.model, dtype=args.dtype)

671
encodec/encodec.py Normal file
View File

@ -0,0 +1,671 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
_lstm_kernel = mx.fast.metal_kernel(
name="lstm",
input_names=["x", "h_in", "cell", "hidden_size", "time_step", "num_time_steps"],
output_names=["hidden_state", "cell_state"],
header="""
template <typename T>
T sigmoid(T x) {
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
return (x < 0) ? 1 - y : y;
}
""",
source="""
uint b = thread_position_in_grid.x;
uint d = hidden_size * 4;
uint elem = b * d + thread_position_in_grid.y;
uint index = elem;
uint x_index = b * num_time_steps * d + time_step * d + index;
auto i = sigmoid(h_in[index] + x[x_index]);
index += hidden_size;
x_index += hidden_size;
auto f = sigmoid(h_in[index] + x[x_index]);
index += hidden_size;
x_index += hidden_size;
auto g = metal::precise::tanh(h_in[index] + x[x_index]);
index += hidden_size;
x_index += hidden_size;
auto o = sigmoid(h_in[index] + x[x_index]);
cell_state[elem] = f * cell[elem] + i * g;
hidden_state[elem] = o * metal::precise::tanh(cell_state[elem]);
""",
)
def lstm_custom(x, h_in, cell, time_step):
assert x.ndim == 3, "Input to LSTM must have 3 dimensions."
out_shape = cell.shape
return _lstm_kernel(
inputs=[x, h_in, cell, out_shape[-1], time_step, x.shape[-2]],
output_shapes=[out_shape, out_shape],
output_dtypes=[h_in.dtype, h_in.dtype],
grid=(x.shape[0], h_in.size // 4, 1),
threadgroup=(256, 1, 1),
)
class LSTM(nn.Module):
def __init__(
self,
input_size: int,
hidden_size: int,
bias: bool = True,
):
super().__init__()
self.hidden_size = hidden_size
self.Wx = mx.zeros((4 * hidden_size, input_size))
self.Wh = mx.zeros((4 * hidden_size, hidden_size))
self.bias = mx.zeros((4 * hidden_size,)) if bias else None
def __call__(self, x, hidden=None, cell=None):
if self.bias is not None:
x = mx.addmm(self.bias, x, self.Wx.T)
else:
x = x @ self.Wx.T
all_hidden = []
B = x.shape[0]
cell = cell or mx.zeros((B, self.hidden_size), x.dtype)
for t in range(x.shape[-2]):
if hidden is None:
hidden = mx.zeros((B, self.hidden_size * 4), x.dtype)
else:
hidden = hidden @ self.Wh.T
hidden, cell = lstm_custom(x, hidden, cell, t)
all_hidden.append(hidden)
return mx.stack(all_hidden, axis=-2)
class EncodecConv1d(nn.Module):
"""Conv1d with asymmetric or causal padding and normalization."""
def __init__(
self,
config,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
):
super().__init__()
self.causal = config.use_causal_conv
self.pad_mode = config.pad_mode
self.norm_type = config.norm_type
self.conv = nn.Conv1d(
in_channels, out_channels, kernel_size, stride, dilation=dilation
)
if self.norm_type == "time_group_norm":
self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
self.stride = stride
# Effective kernel size with dilations.
self.kernel_size = (kernel_size - 1) * dilation + 1
self.padding_total = kernel_size - stride
def _get_extra_padding_for_conv1d(
self,
hidden_states: mx.array,
) -> mx.array:
length = hidden_states.shape[1]
n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
n_frames = int(math.ceil(n_frames)) - 1
ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
return ideal_length - length
def _pad1d(
self,
hidden_states: mx.array,
paddings: Tuple[int, int],
mode: str = "zero",
value: float = 0.0,
):
if mode != "reflect":
return mx.pad(
hidden_states, paddings, mode="constant", constant_values=value
)
length = hidden_states.shape[1]
prefix = hidden_states[:, 1 : paddings[0] + 1][:, ::-1]
suffix = hidden_states[:, max(length - (paddings[1] + 1), 0) : -1][:, ::-1]
return mx.concatenate([prefix, hidden_states, suffix], axis=1)
def __call__(self, hidden_states):
extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
if self.causal:
# Left padding for causal
hidden_states = self._pad1d(
hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode
)
else:
# Asymmetric padding required for odd strides
padding_right = self.padding_total // 2
padding_left = self.padding_total - padding_right
hidden_states = self._pad1d(
hidden_states,
(padding_left, padding_right + extra_padding),
mode=self.pad_mode,
)
hidden_states = self.conv(hidden_states)
if self.norm_type == "time_group_norm":
hidden_states = self.norm(hidden_states)
return hidden_states
class EncodecConvTranspose1d(nn.Module):
"""ConvTranspose1d with asymmetric or causal padding and normalization."""
def __init__(
self,
config,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
):
super().__init__()
self.causal = config.use_causal_conv
self.trim_right_ratio = config.trim_right_ratio
self.norm_type = config.norm_type
self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
if config.norm_type == "time_group_norm":
self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
self.padding_total = kernel_size - stride
def __call__(self, hidden_states):
hidden_states = self.conv(hidden_states)
if self.norm_type == "time_group_norm":
hidden_states = self.norm(hidden_states)
if self.causal:
padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
else:
padding_right = self.padding_total // 2
padding_left = self.padding_total - padding_right
end = hidden_states.shape[1] - padding_right
hidden_states = hidden_states[:, padding_left:end, :]
return hidden_states
class EncodecLSTM(nn.Module):
def __init__(self, config, dimension):
super().__init__()
self.lstm = [LSTM(dimension, dimension) for _ in range(config.num_lstm_layers)]
def __call__(self, hidden_states):
h = hidden_states
for lstm in self.lstm:
h = lstm(h)
return h + hidden_states
class EncodecResnetBlock(nn.Module):
"""
Residual block from SEANet model as used by EnCodec.
"""
def __init__(self, config, dim: int, dilations: List[int]):
super().__init__()
kernel_sizes = (config.residual_kernel_size, 1)
if len(kernel_sizes) != len(dilations):
raise ValueError("Number of kernel sizes should match number of dilations")
hidden = dim // config.compress
block = []
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
in_chs = dim if i == 0 else hidden
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
block += [nn.ELU()]
block += [
EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)
]
self.block = block
if getattr(config, "use_conv_shortcut", True):
self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
else:
self.shortcut = nn.Identity()
def __call__(self, hidden_states):
residual = hidden_states
for layer in self.block:
hidden_states = layer(hidden_states)
return self.shortcut(residual) + hidden_states
class EncodecEncoder(nn.Module):
"""SEANet encoder as used by EnCodec."""
def __init__(self, config):
super().__init__()
model = [
EncodecConv1d(
config, config.audio_channels, config.num_filters, config.kernel_size
)
]
scaling = 1
for ratio in reversed(config.upsampling_ratios):
current_scale = scaling * config.num_filters
for j in range(config.num_residual_layers):
model += [
EncodecResnetBlock(
config, current_scale, [config.dilation_growth_rate**j, 1]
)
]
model += [nn.ELU()]
model += [
EncodecConv1d(
config,
current_scale,
current_scale * 2,
kernel_size=ratio * 2,
stride=ratio,
)
]
scaling *= 2
model += [EncodecLSTM(config, scaling * config.num_filters)]
model += [nn.ELU()]
model += [
EncodecConv1d(
config,
scaling * config.num_filters,
config.hidden_size,
config.last_kernel_size,
)
]
self.layers = model
def __call__(self, hidden_states):
for layer in self.layers:
hidden_states = layer(hidden_states)
return hidden_states
class EncodecDecoder(nn.Module):
"""SEANet decoder as used by EnCodec."""
def __init__(self, config):
super().__init__()
scaling = int(2 ** len(config.upsampling_ratios))
model = [
EncodecConv1d(
config,
config.hidden_size,
scaling * config.num_filters,
config.kernel_size,
)
]
model += [EncodecLSTM(config, scaling * config.num_filters)]
for ratio in config.upsampling_ratios:
current_scale = scaling * config.num_filters
model += [nn.ELU()]
model += [
EncodecConvTranspose1d(
config,
current_scale,
current_scale // 2,
kernel_size=ratio * 2,
stride=ratio,
)
]
for j in range(config.num_residual_layers):
model += [
EncodecResnetBlock(
config, current_scale // 2, (config.dilation_growth_rate**j, 1)
)
]
scaling //= 2
model += [nn.ELU()]
model += [
EncodecConv1d(
config,
config.num_filters,
config.audio_channels,
config.last_kernel_size,
)
]
self.layers = model
def __call__(self, hidden_states):
for layer in self.layers:
hidden_states = layer(hidden_states)
return hidden_states
class EncodecEuclideanCodebook(nn.Module):
"""Codebook with Euclidean distance."""
def __init__(self, config):
super().__init__()
self.embed = mx.zeros((config.codebook_size, config.codebook_dim))
def quantize(self, hidden_states):
embed = self.embed.T
scaled_states = hidden_states.square().sum(axis=1, keepdims=True)
dist = -(
scaled_states
- 2 * hidden_states @ embed
+ embed.square().sum(axis=0, keepdims=True)
)
embed_ind = dist.argmax(axis=-1)
return embed_ind
def encode(self, hidden_states):
shape = hidden_states.shape
hidden_states = hidden_states.reshape((-1, shape[-1]))
embed_ind = self.quantize(hidden_states)
embed_ind = embed_ind.reshape(*shape[:-1])
return embed_ind
def decode(self, embed_ind):
return self.embed[embed_ind]
class EncodecVectorQuantization(nn.Module):
"""
Vector quantization implementation. Currently supports only euclidean distance.
"""
def __init__(self, config):
super().__init__()
self.codebook = EncodecEuclideanCodebook(config)
def encode(self, hidden_states):
return self.codebook.encode(hidden_states)
def decode(self, embed_ind):
return self.codebook.decode(embed_ind)
class EncodecResidualVectorQuantizer(nn.Module):
"""Residual Vector Quantizer."""
def __init__(self, config):
super().__init__()
self.codebook_size = config.codebook_size
hop_length = np.prod(config.upsampling_ratios)
self.frame_rate = math.ceil(config.sampling_rate / hop_length)
self.num_quantizers = int(
1000 * config.target_bandwidths[-1] // (self.frame_rate * 10)
)
self.layers = [
EncodecVectorQuantization(config) for _ in range(self.num_quantizers)
]
def get_num_quantizers_for_bandwidth(
self, bandwidth: Optional[float] = None
) -> int:
"""Return num_quantizers based on specified target bandwidth."""
bw_per_q = math.log2(self.codebook_size) * self.frame_rate
num_quantizers = self.num_quantizers
if bandwidth is not None and bandwidth > 0.0:
num_quantizers = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
return num_quantizers
def encode(
self, embeddings: mx.array, bandwidth: Optional[float] = None
) -> mx.array:
"""
Encode a given input array with the specified frame rate at the given
bandwidth. The RVQ encode method sets the appropriate number of
quantizers to use and returns indices for each quantizer.
"""
num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth)
residual = embeddings
all_indices = []
for layer in self.layers[:num_quantizers]:
indices = layer.encode(residual)
quantized = layer.decode(indices)
residual = residual - quantized
all_indices.append(indices)
out_indices = mx.stack(all_indices, axis=1)
return out_indices
def decode(self, codes: mx.array) -> mx.array:
"""Decode the given codes to the quantized representation."""
quantized_out = None
for i, indices in enumerate(codes.split(codes.shape[1], axis=1)):
layer = self.layers[i]
quantized = layer.decode(indices.squeeze(1))
if quantized_out is None:
quantized_out = quantized
else:
quantized_out = quantized + quantized_out
return quantized_out
class EncodecModel(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.encoder = EncodecEncoder(config)
self.decoder = EncodecDecoder(config)
self.quantizer = EncodecResidualVectorQuantizer(config)
def _encode_frame(
self, input_values: mx.array, bandwidth: float, padding_mask: mx.array
) -> Tuple[mx.array, Optional[mx.array]]:
"""
Encodes the given input using the underlying VQVAE.
"""
length = input_values.shape[1]
duration = length / self.config.sampling_rate
if (
self.config.chunk_length_s is not None
and duration > 1e-5 + self.config.chunk_length_s
):
raise RuntimeError(
f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}"
)
scale = None
if self.config.normalize:
# if the padding is non zero
input_values = input_values * padding_mask[..., None]
mono = mx.sum(input_values, axis=2, keepdims=True) / input_values.shape[2]
scale = mono.square().mean(axis=1, keepdims=True).sqrt() + 1e-8
input_values = input_values / scale
embeddings = self.encoder(input_values)
codes = self.quantizer.encode(embeddings, bandwidth)
return codes, scale
def encode(
self,
input_values: mx.array,
padding_mask: mx.array = None,
bandwidth: Optional[float] = None,
) -> Tuple[mx.array, Optional[mx.array]]:
"""
Encodes the input audio waveform into discrete codes.
Args:
input_values (mx.array): The input audio waveform with shape
``(batch_size, channels, sequence_length)``.
padding_mask (mx.array): Padding mask used to pad the ``input_values``.
bandwidth (float, optional): The target bandwidth. Must be one of
``config.target_bandwidths``. If ``None``, uses the smallest
possible bandwidth. bandwidth is represented as a thousandth of
what it is, e.g. 6kbps bandwidth is represented as bandwidth == 6.0
Returns:
A list of frames containing the discrete encoded codes for the
input audio waveform, along with rescaling factors for each chunk
when ``config.normalize==True``. Each frame is a tuple ``(codebook,
scale)``, with ``codebook`` of shape ``(batch_size, num_codebooks,
frames)``.
"""
if bandwidth is None:
bandwidth = self.config.target_bandwidths[0]
if bandwidth not in self.config.target_bandwidths:
raise ValueError(
f"This model doesn't support the bandwidth {bandwidth}. "
f"Select one of {self.config.target_bandwidths}."
)
_, input_length, channels = input_values.shape
if channels < 1 or channels > 2:
raise ValueError(
f"Number of audio channels must be 1 or 2, but got {channels}"
)
chunk_length = self.chunk_length
if chunk_length is None:
chunk_length = input_length
stride = input_length
else:
stride = self.chunk_stride
if padding_mask is None:
padding_mask = mx.ones(input_values.shape[:2], dtype=mx.bool_)
encoded_frames = []
scales = []
step = chunk_length - stride
if (input_length % stride) != step:
raise ValueError(
"The input length is not properly padded for batched chunked "
"encoding. Make sure to pad the input correctly."
)
for offset in range(0, input_length - step, stride):
mask = padding_mask[:, offset : offset + chunk_length].astype(mx.bool_)
frame = input_values[:, offset : offset + chunk_length]
encoded_frame, scale = self._encode_frame(frame, bandwidth, mask)
encoded_frames.append(encoded_frame)
scales.append(scale)
encoded_frames = mx.stack(encoded_frames)
return (encoded_frames, scales)
@staticmethod
def _linear_overlap_add(frames: List[mx.array], stride: int):
if len(frames) == 0:
raise ValueError("`frames` cannot be an empty list.")
dtype = frames[0].dtype
N, frame_length, C = frames[0].shape
total_size = stride * (len(frames) - 1) + frames[-1].shape[1]
time_vec = mx.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
weight = 0.5 - (time_vec - 0.5).abs()
weight = weight[:, None]
sum_weight = mx.zeros((total_size, 1), dtype=dtype)
out = mx.zeros((N, total_size, C), dtype=dtype)
offset = 0
for frame in frames:
frame_length = frame.shape[1]
out[:, offset : offset + frame_length] += weight[:frame_length] * frame
sum_weight[offset : offset + frame_length] += weight[:frame_length]
offset += stride
return out / sum_weight
def _decode_frame(
self, codes: mx.array, scale: Optional[mx.array] = None
) -> mx.array:
embeddings = self.quantizer.decode(codes)
outputs = self.decoder(embeddings)
if scale is not None:
outputs = outputs * scale
return outputs
@property
def channels(self):
return self.config.audio_channels
@property
def sampling_rate(self):
return self.config.sampling_rate
@property
def chunk_length(self):
if self.config.chunk_length_s is None:
return None
else:
return int(self.config.chunk_length_s * self.config.sampling_rate)
@property
def chunk_stride(self):
if self.config.chunk_length_s is None or self.config.overlap is None:
return None
else:
return max(1, int((1.0 - self.config.overlap) * self.chunk_length))
def decode(
self,
audio_codes: mx.array,
audio_scales: Union[mx.array, List[mx.array]],
padding_mask: Optional[mx.array] = None,
) -> Tuple[mx.array, mx.array]:
"""
Decodes the given frames into an output audio waveform.
Note that the output might be a bit bigger than the input. In that
case, any extra steps at the end should be trimmed.
Args:
audio_codes (mx.array): Discret code embeddings of shape
``(batch_size, nb_chunks, chunk_length)``.
audio_scales (mx.array): Scaling factor for each input.
padding_mask (mx.array): Padding mask.
"""
chunk_length = self.chunk_length
if chunk_length is None:
if audio_codes.shape[1] != 1:
raise ValueError(f"Expected one frame, got {len(audio_codes)}")
audio_values = self._decode_frame(audio_codes[:, 0], audio_scales[0])
else:
decoded_frames = []
for frame, scale in zip(audio_codes, audio_scales):
frames = self._decode_frame(frame, scale)
decoded_frames.append(frames)
audio_values = self._linear_overlap_add(
decoded_frames, self.chunk_stride or 1
)
# truncate based on padding mask
if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
audio_values = audio_values[:, : padding_mask.shape[1]]
return audio_values

37
encodec/example.py Normal file
View File

@ -0,0 +1,37 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
from utils import load, load_audio, save_audio
# Load the 48 KHz model and preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")
# Load an audio file
audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)
# Preprocess the audio (this can also be a list of arrays for batched
# processing).
feats, mask = processor(audio)
# Encode at the given bandwidth. A lower bandwidth results in more
# compression but lower reconstruction quality.
@mx.compile
def encode(feats, mask):
return model.encode(feats, mask, bandwidth=3)
# Decode to reconstruct the audio
@mx.compile
def decode(codes, scales, mask):
return model.decode(codes, scales, mask)
codes, scales = encode(feats, mask)
reconstructed = decode(codes, scales, mask)
# Trim any padding:
reconstructed = reconstructed[0, : len(audio)]
# Save the audio as a wave file
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)

3
encodec/requirements.txt Normal file
View File

@ -0,0 +1,3 @@
mlx>=0.18
numpy
huggingface_hub

66
encodec/test.py Normal file
View File

@ -0,0 +1,66 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import numpy as np
import torch
from datasets import Audio, load_dataset
from transformers import AutoProcessor, EncodecModel
from utils import load, load_audio, preprocess_audio
def compare_processors():
np.random.seed(0)
audio_length = 95500
audio = np.random.uniform(size=(2, audio_length)).astype(np.float32)
processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
pt_inputs = processor(
raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
)
mx_inputs = preprocess_audio(
mx.array(audio).T,
processor.sampling_rate,
processor.chunk_length,
processor.chunk_stride,
)
assert np.array_equal(pt_inputs["input_values"], mx_inputs[0].moveaxis(2, 1))
assert np.array_equal(pt_inputs["padding_mask"], mx_inputs[1])
def compare_models():
pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz")
mx_model, _ = load("mlx-community/encodec-48khz-float32")
np.random.seed(0)
audio_length = 190560
audio = np.random.uniform(size=(1, audio_length, 2)).astype(np.float32)
mask = np.ones((1, audio_length), dtype=np.int32)
pt_encoded = pt_model.encode(
torch.tensor(audio).moveaxis(2, 1), torch.tensor(mask)[None]
)
mx_encoded = mx_model.encode(mx.array(audio), mx.array(mask))
pt_codes = pt_encoded.audio_codes.numpy()
mx_codes = mx_encoded[0]
assert np.array_equal(pt_codes, mx_codes), "Encoding codes mismatch"
for mx_scale, pt_scale in zip(mx_encoded[1], pt_encoded.audio_scales):
if mx_scale is not None:
pt_scale = pt_scale.numpy()
assert np.allclose(pt_scale, mx_scale, atol=1e-3, rtol=1e-4)
pt_audio = pt_model.decode(
pt_encoded.audio_codes, pt_encoded.audio_scales, torch.tensor(mask)[None]
)
pt_audio = pt_audio[0].squeeze().T.detach().numpy()
mx_audio = mx_model.decode(*mx_encoded, mx.array(mask))
mx_audio = mx_audio.squeeze()
assert np.allclose(
pt_audio, mx_audio, atol=1e-4, rtol=1e-4
), "Decoding audio mismatch"
if __name__ == "__main__":
compare_processors()
compare_models()

129
encodec/utils.py Normal file
View File

@ -0,0 +1,129 @@
# Copyright © 2024 Apple Inc.
import functools
import json
from pathlib import Path
from types import SimpleNamespace
from typing import List, Optional, Union
import mlx.core as mx
import numpy as np
from huggingface_hub import snapshot_download
import encodec
def save_audio(file: str, audio: mx.array, sampling_rate: int):
"""
Save audio to a wave (.wav) file.
"""
from scipy.io.wavfile import write
audio = (audio * 32767).astype(mx.int16)
write(file, sampling_rate, np.array(audio))
def load_audio(file: str, sampling_rate: int, channels: int):
"""
Read audio into an mx.array, resampling if necessary.
Args:
file (str): The audio file to open.
sampling_rate (int): The sample rate to resample the audio at if needed.
channels (int): The number of audio channels.
Returns:
An mx.array containing the audio waveform in float32.
"""
from subprocess import CalledProcessError, run
# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
# fmt: off
cmd = [
"ffmpeg",
"-nostdin",
"-threads", "0",
"-i", file,
"-f", "s16le",
"-ac", str(channels),
"-acodec", "pcm_s16le",
"-ar", str(sampling_rate),
"-"
]
# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
out = mx.array(np.frombuffer(out, np.int16))
return out.reshape(-1, channels).astype(mx.float32) / 32767.0
def preprocess_audio(
raw_audio: Union[mx.array, List[mx.array]],
sampling_rate: int = 24000,
chunk_length: Optional[int] = None,
chunk_stride: Optional[int] = None,
):
r"""
Prepare inputs for the EnCodec model.
Args:
raw_audio (mx.array or List[mx.array]): The sequence or batch of
sequences to be processed.
sampling_rate (int): The sampling rate at which the audio waveform
should be digitalized.
chunk_length (int, optional): The model's chunk length.
chunk_stride (int, optional): The model's chunk stride.
"""
if not isinstance(raw_audio, list):
raw_audio = [raw_audio]
raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
max_length = max(array.shape[0] for array in raw_audio)
if chunk_length is not None:
max_length += chunk_length - (max_length % chunk_stride)
inputs = []
masks = []
for x in raw_audio:
length = x.shape[0]
mask = mx.ones((length,), dtype=mx.bool_)
difference = max_length - length
if difference > 0:
mask = mx.pad(mask, (0, difference))
x = mx.pad(x, ((0, difference), (0, 0)))
inputs.append(x)
masks.append(mask)
return mx.stack(inputs), mx.stack(masks)
def load(path_or_repo):
"""
Load the model and audo preprocessor.
"""
path = Path(path_or_repo)
if not path.exists():
path = Path(
snapshot_download(
repo_id=path_or_repo,
allow_patterns=["*.json", "*.safetensors", "*.model"],
)
)
with open(path / "config.json", "r") as f:
config = SimpleNamespace(**json.load(f))
model = encodec.EncodecModel(config)
model.load_weights(str(path / "model.safetensors"))
processor = functools.partial(
preprocess_audio,
sampling_rate=config.sampling_rate,
chunk_length=model.chunk_length,
chunk_stride=model.chunk_stride,
)
mx.eval(model)
return model, processor