mlx-examples/encodec/benchmarks/bench_pt.py
Awni Hannun 9bb2dd62f3
Encodec (#991)
* initial encodec

* works

* nits

* use fast group norm

* fix for rnn layer

* fix mlx version

* use custom LSTM kernel

* audio encodec

* fix example, support batched inference

* nits
2024-09-23 11:39:25 -07:00

35 lines
862 B
Python

# Copyright © 2024 Apple Inc.
import time
import numpy as np
import torch
from transformers import AutoProcessor, EncodecModel
# Single source of truth for the checkpoint, so the processor and the model
# cannot accidentally be loaded from different checkpoints.
MODEL_ID = "facebook/encodec_48khz"

processor = AutoProcessor.from_pretrained(MODEL_ID)

# Synthetic float32 audio with values in [0, 1); the values are arbitrary and
# only the input size is meant to drive the benchmark. Presumably laid out as
# (channels, samples) -- TODO confirm how the processor interprets a 2-D array.
audio = np.random.uniform(size=(2, 288000)).astype(np.float32)

# Model weights and the preprocessed inputs are both placed on the MPS device.
pt_model = EncodecModel.from_pretrained(MODEL_ID).to("mps")
pt_inputs = processor(
    raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
).to("mps")
def fun():
    """Run one encode -> decode round trip of the PyTorch Encodec model.

    Reads the module-level ``pt_model`` and ``pt_inputs``. The decoded audio
    is discarded; the function exists only to be timed by the caller.
    """
    # Pure inference: disable autograd. nn.Module parameters require grad by
    # default, so without this every op would record autograd graph state,
    # adding time and memory that should not be part of the measurement.
    with torch.no_grad():
        pt_encoded = pt_model.encode(
            pt_inputs["input_values"], pt_inputs["padding_mask"]
        )
        pt_model.decode(
            pt_encoded.audio_codes, pt_encoded.audio_scales, pt_inputs["padding_mask"]
        )
    # MPS work is queued asynchronously; block until it completes so that
    # wall-clock timing around this call covers the actual device work.
    torch.mps.synchronize()
# Untimed warm-up calls (so one-time startup costs are excluded from the
# measurement) and the number of timed calls that are averaged.
WARMUP_ITERS = 5
TIMED_ITERS = 10

for _ in range(WARMUP_ITERS):
    fun()

# perf_counter is monotonic and high-resolution; time.time follows the system
# clock and can jump if the clock is adjusted during the run.
tic = time.perf_counter()
for _ in range(TIMED_ITERS):
    fun()
toc = time.perf_counter()

# Average milliseconds per encode/decode iteration; dividing by the same
# constant used for the loop keeps the two from drifting apart.
ms = 1000 * (toc - tic) / TIMED_ITERS
print(f"Time per it: {ms:.3f}")