mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 11:54:37 +08:00
change to from_pretrained
This commit is contained in:
parent
9432f1a643
commit
4d2ee67402
@ -33,11 +33,11 @@ An example using the model:
|
||||
|
||||
```python
|
||||
import mlx.core as mx
|
||||
from encodec import load
|
||||
from encodec import EncodecModel
|
||||
from utils import load_audio, save_audio
|
||||
|
||||
# Load the 48 KHz model and preprocessor.
|
||||
model, processor = load("mlx-community/encodec-48khz-float32")
|
||||
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||
|
||||
# Load an audio file
|
||||
audio = load_audio("path/to/audio", model.sampling_rate, model.channels)
|
||||
|
@ -3,9 +3,10 @@
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
from utils import load
|
||||
|
||||
model, processor = load("mlx-community/encodec-48khz-float32")
|
||||
from encodec import EncodecModel
|
||||
|
||||
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||
|
||||
audio = mx.random.uniform(shape=(288000, 2))
|
||||
feats, mask = processor(audio)
|
||||
|
@ -673,6 +673,33 @@ class EncodecModel(nn.Module):
|
||||
audio_values = audio_values[:, : padding_mask.shape[1]]
|
||||
return audio_values
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, path_or_repo: str):
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
path = Path(path_or_repo)
|
||||
if not path.exists():
|
||||
path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_repo,
|
||||
allow_patterns=["*.json", "*.safetensors", "*.model"],
|
||||
)
|
||||
)
|
||||
|
||||
with open(path / "config.json", "r") as f:
|
||||
config = SimpleNamespace(**json.load(f))
|
||||
|
||||
model = EncodecModel(config)
|
||||
model.load_weights(str(path / "model.safetensors"))
|
||||
processor = functools.partial(
|
||||
preprocess_audio,
|
||||
sampling_rate=config.sampling_rate,
|
||||
chunk_length=model.chunk_length,
|
||||
chunk_stride=model.chunk_stride,
|
||||
)
|
||||
mx.eval(model)
|
||||
return model, processor
|
||||
|
||||
|
||||
def preprocess_audio(
|
||||
raw_audio: Union[mx.array, List[mx.array]],
|
||||
@ -712,33 +739,3 @@ def preprocess_audio(
|
||||
inputs.append(x)
|
||||
masks.append(mask)
|
||||
return mx.stack(inputs), mx.stack(masks)
|
||||
|
||||
|
||||
def load(path_or_repo):
|
||||
"""
|
||||
Load the model and audo preprocessor.
|
||||
"""
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
path = Path(path_or_repo)
|
||||
if not path.exists():
|
||||
path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_repo,
|
||||
allow_patterns=["*.json", "*.safetensors", "*.model"],
|
||||
)
|
||||
)
|
||||
|
||||
with open(path / "config.json", "r") as f:
|
||||
config = SimpleNamespace(**json.load(f))
|
||||
|
||||
model = EncodecModel(config)
|
||||
model.load_weights(str(path / "model.safetensors"))
|
||||
processor = functools.partial(
|
||||
preprocess_audio,
|
||||
sampling_rate=config.sampling_rate,
|
||||
chunk_length=model.chunk_length,
|
||||
chunk_stride=model.chunk_stride,
|
||||
)
|
||||
mx.eval(model)
|
||||
return model, processor
|
||||
|
@ -3,10 +3,10 @@
|
||||
import mlx.core as mx
|
||||
from utils import load_audio, save_audio
|
||||
|
||||
from encodec import load
|
||||
from encodec import EncodecModel
|
||||
|
||||
# Load the 48 KHz model and preprocessor.
|
||||
model, processor = load("mlx-community/encodec-48khz-float32")
|
||||
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||
|
||||
# Load an audio file
|
||||
audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)
|
||||
|
@ -3,9 +3,10 @@
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoProcessor, EncodecModel
|
||||
from transformers import AutoProcessor
|
||||
from transformers import EncodecModel as PTEncodecModel
|
||||
|
||||
from encodec import load, preprocess_audio
|
||||
from encodec import EncodecModel, preprocess_audio
|
||||
|
||||
|
||||
def compare_processors():
|
||||
@ -30,8 +31,8 @@ def compare_processors():
|
||||
|
||||
|
||||
def compare_models():
|
||||
pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz")
|
||||
mx_model, _ = load("mlx-community/encodec-48khz-float32")
|
||||
pt_model = PTEncodecModel.from_pretrained("facebook/encodec_48khz")
|
||||
mx_model, _ = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||
|
||||
np.random.seed(0)
|
||||
audio_length = 190560
|
||||
|
@ -1,12 +1,18 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
from utils import load
|
||||
|
||||
cur_path = Path(__file__).parents[1].resolve()
|
||||
sys.path.append(str(cur_path))
|
||||
|
||||
from musicgen import MusicGen
|
||||
|
||||
text = "folk ballad"
|
||||
model = load("facebook/musicgen-medium")
|
||||
model = MusicGen.from_pretrained("facebook/musicgen-medium")
|
||||
|
||||
max_steps = 100
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
@ -2,11 +2,13 @@
|
||||
|
||||
import argparse
|
||||
|
||||
from utils import load, save_audio
|
||||
from utils import save_audio
|
||||
|
||||
from musicgen import MusicGen
|
||||
|
||||
|
||||
def main(text: str, output_path: str, model_name: str, max_steps: int):
|
||||
model = load(model_name)
|
||||
model = MusicGen.from_pretrained(model_name)
|
||||
audio = model.generate(text, max_steps=max_steps)
|
||||
save_audio(output_path, audio, model.sampling_rate)
|
||||
|
||||
|
@ -1,8 +1,10 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import json
|
||||
import sys
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
@ -13,14 +15,14 @@ cur_path = Path(__file__).parents[1].resolve()
|
||||
sys.path.append(str(cur_path / "encodec"))
|
||||
sys.path.append(str(cur_path / "t5"))
|
||||
|
||||
from encodec import load as load_encodec
|
||||
from t5 import load_model as load_t5
|
||||
from encodec import EncodecModel
|
||||
from t5 import T5
|
||||
|
||||
|
||||
class TextConditioner(nn.Module):
|
||||
def __init__(self, t5_name, input_dim, output_dim):
|
||||
super().__init__()
|
||||
self._t5, self.tokenizer = load_t5(t5_name)
|
||||
self._t5, self.tokenizer = T5.from_pretrained(t5_name)
|
||||
self.output_proj = nn.Linear(input_dim, output_dim)
|
||||
|
||||
def __call__(self, text):
|
||||
@ -222,7 +224,9 @@ class MusicGen(nn.Module):
|
||||
]
|
||||
encodec_name = config.audio_encoder._name_or_path.split("/")[-1]
|
||||
encodec_name = encodec_name.replace("_", "-")
|
||||
self._audio_decoder, _ = load_encodec(f"mlx-community/{encodec_name}-float32")
|
||||
self._audio_decoder, _ = EncodecModel.from_pretrained(
|
||||
f"mlx-community/{encodec_name}-float32"
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -304,3 +308,57 @@ class MusicGen(nn.Module):
|
||||
audio_seq = mx.swapaxes(audio_seq, -1, -2)[:, mx.newaxis]
|
||||
audio = self._audio_decoder.decode(audio_seq, audio_scales=[None])
|
||||
return audio[0]
|
||||
|
||||
@classmethod
|
||||
def sanitize(cls, weights):
|
||||
out_weights = {}
|
||||
for k, arr in weights.items():
|
||||
if k.startswith("transformer."):
|
||||
k = k[len("transformer.") :]
|
||||
|
||||
if "cross_attention" in k:
|
||||
k = k.replace("cross_attention", "cross_attn")
|
||||
|
||||
if "condition_provider" in k:
|
||||
k = k.replace(
|
||||
"condition_provider.conditioners.description", "text_conditioner"
|
||||
)
|
||||
|
||||
if "in_proj_weight" in k:
|
||||
dim = arr.shape[0] // 3
|
||||
name = "in_proj_weight"
|
||||
out_weights[k.replace(name, "q_proj.weight")] = arr[:dim]
|
||||
out_weights[k.replace(name, "k_proj.weight")] = arr[dim : dim * 2]
|
||||
out_weights[k.replace(name, "v_proj.weight")] = arr[dim * 2 :]
|
||||
continue
|
||||
|
||||
out_weights[k] = arr
|
||||
return out_weights
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, path_or_repo: str):
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
path = Path(path_or_repo)
|
||||
if not path.exists():
|
||||
path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_repo,
|
||||
allow_patterns=["*.json", "state_dict.bin"],
|
||||
)
|
||||
)
|
||||
|
||||
with open(path / "config.json", "r") as f:
|
||||
config = SimpleNamespace(**json.load(f))
|
||||
config.text_encoder = SimpleNamespace(**config.text_encoder)
|
||||
config.audio_encoder = SimpleNamespace(**config.audio_encoder)
|
||||
config.decoder = SimpleNamespace(**config.decoder)
|
||||
|
||||
weights = torch.load(path / "state_dict.bin", weights_only=True)["best_state"]
|
||||
weights = {k: mx.array(v.numpy()) for k, v in weights.items()}
|
||||
weights = cls.sanitize(weights)
|
||||
|
||||
model = MusicGen(config)
|
||||
model.load_weights(list(weights.items()))
|
||||
return model
|
||||
|
@ -1,14 +1,7 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
import musicgen
|
||||
|
||||
|
||||
def save_audio(file: str, audio: mx.array, sampling_rate: int):
|
||||
@ -20,56 +13,3 @@ def save_audio(file: str, audio: mx.array, sampling_rate: int):
|
||||
audio = mx.clip(audio, -1, 1)
|
||||
audio = (audio * 32767).astype(mx.int16)
|
||||
write(file, sampling_rate, np.array(audio))
|
||||
|
||||
|
||||
def load(path_or_repo):
|
||||
"""
|
||||
Load the model and audio preprocessor.
|
||||
"""
|
||||
import torch
|
||||
|
||||
path = Path(path_or_repo)
|
||||
if not path.exists():
|
||||
path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_repo,
|
||||
allow_patterns=["*.json", "state_dict.bin"],
|
||||
)
|
||||
)
|
||||
|
||||
with open(path / "config.json", "r") as f:
|
||||
config = SimpleNamespace(**json.load(f))
|
||||
config.text_encoder = SimpleNamespace(**config.text_encoder)
|
||||
config.audio_encoder = SimpleNamespace(**config.audio_encoder)
|
||||
config.decoder = SimpleNamespace(**config.decoder)
|
||||
|
||||
weights = torch.load(path / "state_dict.bin", weights_only=True)["best_state"]
|
||||
weights = {k: mx.array(v.numpy()) for k, v in weights.items()}
|
||||
|
||||
decoder_weights = {}
|
||||
for k, arr in weights.items():
|
||||
if k.startswith("transformer."):
|
||||
k = k[len("transformer.") :]
|
||||
|
||||
if "cross_attention" in k:
|
||||
k = k.replace("cross_attention", "cross_attn")
|
||||
|
||||
if "condition_provider" in k:
|
||||
k = k.replace(
|
||||
"condition_provider.conditioners.description", "text_conditioner"
|
||||
)
|
||||
|
||||
if "in_proj_weight" in k:
|
||||
dim = arr.shape[0] // 3
|
||||
name = "in_proj_weight"
|
||||
decoder_weights[k.replace(name, "q_proj.weight")] = arr[:dim]
|
||||
decoder_weights[k.replace(name, "k_proj.weight")] = arr[dim : dim * 2]
|
||||
decoder_weights[k.replace(name, "v_proj.weight")] = arr[dim * 2 :]
|
||||
continue
|
||||
|
||||
decoder_weights[k] = arr
|
||||
|
||||
model = musicgen.MusicGen(config)
|
||||
model.load_weights(list(decoder_weights.items()))
|
||||
mx.eval(model)
|
||||
return model
|
||||
|
40
t5/README.md
40
t5/README.md
@ -7,31 +7,6 @@ tasks by prepending task-specific prefixes to the input, e.g.:
|
||||
|
||||
This example also supports the FLAN-T5 models variants.[^2]
|
||||
|
||||
## Setup
|
||||
|
||||
Download and convert the model:
|
||||
|
||||
```sh
|
||||
python convert.py --model <model>
|
||||
```
|
||||
|
||||
This will make the `<model>.npz` file which MLX can read.
|
||||
|
||||
The `<model>` can be any of the following:
|
||||
|
||||
| Model Name | Model Size |
|
||||
| ---------- | ----------
|
||||
| t5-small | 60 million |
|
||||
| t5-base | 220 million |
|
||||
| t5-large | 770 million |
|
||||
| t5-3b | 3 billion |
|
||||
| t5-11b | 11 billion |
|
||||
|
||||
The FLAN variants can be specified with `google/flan-t5-small`,
|
||||
`google/flan-t5-base`, etc. See the [Hugging Face
|
||||
page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a
|
||||
complete list of models.
|
||||
|
||||
## Generate
|
||||
|
||||
Generate text with:
|
||||
@ -48,6 +23,21 @@ To see a list of options run:
|
||||
python t5.py --help
|
||||
```
|
||||
|
||||
The `<model>` can be any of the following:
|
||||
|
||||
| Model Name | Model Size |
|
||||
| ---------- | ----------
|
||||
| t5-small | 60 million |
|
||||
| t5-base | 220 million |
|
||||
| t5-large | 770 million |
|
||||
| t5-3b | 3 billion |
|
||||
| t5-11b | 11 billion |
|
||||
|
||||
The FLAN variants can be specified with `google/flan-t5-small`,
|
||||
`google/flan-t5-base`, etc. See the [Hugging Face
|
||||
page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a
|
||||
complete list of models.
|
||||
|
||||
[^1]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683)
|
||||
or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5).
|
||||
[^2]: For more information on FLAN-T5 see the [original paper](https://arxiv.org/abs/2210.11416).
|
||||
|
195
t5/t5.py
195
t5/t5.py
@ -10,36 +10,36 @@ import mlx.nn as nn
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
SHARED_REPLACEMENT_PATTERNS = [
|
||||
(".block.", ".layers."),
|
||||
(".k.", ".key_proj."),
|
||||
(".o.", ".out_proj."),
|
||||
(".q.", ".query_proj."),
|
||||
(".v.", ".value_proj."),
|
||||
("shared.", "wte."),
|
||||
("lm_head.", "lm_head.linear."),
|
||||
(".layer.0.layer_norm.", ".ln1."),
|
||||
(".layer.1.layer_norm.", ".ln2."),
|
||||
(".layer.2.layer_norm.", ".ln3."),
|
||||
(".final_layer_norm.", ".ln."),
|
||||
(
|
||||
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
|
||||
"relative_attention_bias.embeddings.",
|
||||
),
|
||||
]
|
||||
|
||||
ENCODER_REPLACEMENT_PATTERNS = [
|
||||
(".layer.0.SelfAttention.", ".attention."),
|
||||
(".layer.1.DenseReluDense.", ".dense."),
|
||||
]
|
||||
class Tokenizer:
|
||||
def __init__(self, config, model_name):
|
||||
self._decoder_start_id = config.decoder_start_token_id
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name,
|
||||
legacy=False,
|
||||
model_max_length=getattr(config, "n_positions", 512),
|
||||
)
|
||||
|
||||
DECODER_REPLACEMENT_PATTERNS = [
|
||||
(".layer.0.SelfAttention.", ".self_attention."),
|
||||
(".layer.1.EncDecAttention.", ".cross_attention."),
|
||||
(".layer.2.DenseReluDense.", ".dense."),
|
||||
]
|
||||
@property
|
||||
def eos_id(self) -> int:
|
||||
return self._tokenizer.eos_token_id
|
||||
|
||||
IGNORED_KEYS = ["decoder.layers.0.cross_attention.relative_attention_bias.weight"]
|
||||
@property
|
||||
def decoder_start_id(self) -> int:
|
||||
return self._decoder_start_id
|
||||
|
||||
def encode(self, s: str) -> mx.array:
|
||||
return mx.array(
|
||||
self._tokenizer(
|
||||
s,
|
||||
return_tensors="np",
|
||||
return_attention_mask=False,
|
||||
)["input_ids"]
|
||||
)
|
||||
|
||||
def decode(self, t: List[int], with_sep: bool = True) -> str:
|
||||
tokens = self._tokenizer.convert_ids_to_tokens(t)
|
||||
return "".join(t.replace("▁", " " if with_sep else "") for t in tokens)
|
||||
|
||||
|
||||
def _relative_position_bucket(
|
||||
@ -350,36 +350,83 @@ class T5(nn.Module):
|
||||
):
|
||||
return self.decode(decoder_inputs, self.encode(inputs))[0]
|
||||
|
||||
@classmethod
|
||||
def sanitize(cls, weights):
|
||||
shared_replacement_patterns = [
|
||||
(".block.", ".layers."),
|
||||
(".k.", ".key_proj."),
|
||||
(".o.", ".out_proj."),
|
||||
(".q.", ".query_proj."),
|
||||
(".v.", ".value_proj."),
|
||||
("shared.", "wte."),
|
||||
("lm_head.", "lm_head.linear."),
|
||||
(".layer.0.layer_norm.", ".ln1."),
|
||||
(".layer.1.layer_norm.", ".ln2."),
|
||||
(".layer.2.layer_norm.", ".ln3."),
|
||||
(".final_layer_norm.", ".ln."),
|
||||
(
|
||||
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
|
||||
"relative_attention_bias.embeddings.",
|
||||
),
|
||||
]
|
||||
|
||||
class Tokenizer:
|
||||
def __init__(self, config, model_name):
|
||||
self._decoder_start_id = config.decoder_start_token_id
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name,
|
||||
legacy=False,
|
||||
model_max_length=getattr(config, "n_positions", 512),
|
||||
)
|
||||
encoder_replacement_patterns = [
|
||||
(".layer.0.SelfAttention.", ".attention."),
|
||||
(".layer.1.DenseReluDense.", ".dense."),
|
||||
]
|
||||
|
||||
@property
|
||||
def eos_id(self) -> int:
|
||||
return self._tokenizer.eos_token_id
|
||||
decoder_replacement_patterns = [
|
||||
(".layer.0.SelfAttention.", ".self_attention."),
|
||||
(".layer.1.EncDecAttention.", ".cross_attention."),
|
||||
(".layer.2.DenseReluDense.", ".dense."),
|
||||
]
|
||||
|
||||
@property
|
||||
def decoder_start_id(self) -> int:
|
||||
return self._decoder_start_id
|
||||
ignored_keys = [
|
||||
"decoder.layers.0.cross_attention.relative_attention_bias.weight"
|
||||
]
|
||||
|
||||
def encode(self, s: str) -> mx.array:
|
||||
return mx.array(
|
||||
self._tokenizer(
|
||||
s,
|
||||
return_tensors="np",
|
||||
return_attention_mask=False,
|
||||
)["input_ids"]
|
||||
)
|
||||
def replace_key(key: str) -> str:
|
||||
for old, new in shared_replacement_patterns:
|
||||
key = key.replace(old, new)
|
||||
if key.startswith("encoder."):
|
||||
for old, new in encoder_replacement_patterns:
|
||||
key = key.replace(old, new)
|
||||
elif key.startswith("decoder."):
|
||||
for old, new in decoder_replacement_patterns:
|
||||
key = key.replace(old, new)
|
||||
return key
|
||||
|
||||
def decode(self, t: List[int], with_sep: bool = True) -> str:
|
||||
tokens = self._tokenizer.convert_ids_to_tokens(t)
|
||||
return "".join(t.replace("▁", " " if with_sep else "") for t in tokens)
|
||||
weights = {replace_key(k): v for k, v in weights.items()}
|
||||
for key in ignored_keys:
|
||||
if key in weights:
|
||||
del weights[key]
|
||||
return weights
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, path_or_repo: str, dtype: mx.Dtype = mx.bfloat16
|
||||
) -> tuple["T5", Tokenizer]:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
path = Path(path_or_repo)
|
||||
if not path.exists():
|
||||
path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_repo,
|
||||
allow_patterns=["*.json", "*.safetensors", "*.model"],
|
||||
)
|
||||
)
|
||||
print(path)
|
||||
|
||||
with open(path / "config.json", "r") as f:
|
||||
config = SimpleNamespace(**json.load(f))
|
||||
|
||||
model = T5(config)
|
||||
weights = mx.load(str(path / "model.safetensors"))
|
||||
weights = cls.sanitize(weights)
|
||||
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
||||
model.load_weights(list(weights.items()))
|
||||
return model, Tokenizer(config, "t5-base")
|
||||
|
||||
|
||||
def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0):
|
||||
@ -400,47 +447,6 @@ def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float]
|
||||
yield y.squeeze()
|
||||
|
||||
|
||||
def replace_key(key: str) -> str:
|
||||
for old, new in SHARED_REPLACEMENT_PATTERNS:
|
||||
key = key.replace(old, new)
|
||||
if key.startswith("encoder."):
|
||||
for old, new in ENCODER_REPLACEMENT_PATTERNS:
|
||||
key = key.replace(old, new)
|
||||
elif key.startswith("decoder."):
|
||||
for old, new in DECODER_REPLACEMENT_PATTERNS:
|
||||
key = key.replace(old, new)
|
||||
return key
|
||||
|
||||
|
||||
def load_model(path_or_repo: str, dtype: str = "bfloat16"):
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
path = Path(path_or_repo)
|
||||
if not path.exists():
|
||||
path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_repo,
|
||||
allow_patterns=["*.json", "*.safetensors", "*.model"],
|
||||
)
|
||||
)
|
||||
print(path)
|
||||
|
||||
with open(path / "config.json", "r") as f:
|
||||
config = SimpleNamespace(**json.load(f))
|
||||
|
||||
dtype = getattr(mx, dtype)
|
||||
|
||||
model = T5(config)
|
||||
weights = mx.load(str(path / "model.safetensors"))
|
||||
weights = {replace_key(k): v.astype(dtype) for k, v in weights.items()}
|
||||
for key in IGNORED_KEYS:
|
||||
del weights[key]
|
||||
model.load_weights(list(weights.items()))
|
||||
|
||||
mx.eval(model.parameters())
|
||||
return model, Tokenizer(config, "t5-base")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="T5 Inference script")
|
||||
parser.add_argument(
|
||||
@ -486,7 +492,8 @@ if __name__ == "__main__":
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
model, tokenizer = load_model(args.model, args.dtype)
|
||||
dtype = getattr(mx, args.dtype)
|
||||
model, tokenizer = T5.from_pretrained(args.model, dtype)
|
||||
|
||||
if args.encode_only:
|
||||
print("[INFO] Encoding with T5...", flush=True)
|
||||
|
Loading…
Reference in New Issue
Block a user