mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-06 08:54:33 +08:00
Add llms subdir + update README (#145)
* add llms subdir + update README * nits * use same pre-commit as mlx * update readmes a bit * format
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
from transformers import T5ForConditionalGeneration
|
||||
import numpy as np
|
||||
|
||||
from transformers import T5ForConditionalGeneration
|
||||
|
||||
SHARED_REPLACEMENT_PATTERNS = [
|
||||
(".block.", ".layers."),
|
||||
@@ -48,8 +47,7 @@ def convert(model_name, dtype):
|
||||
dtype = getattr(np, dtype)
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
|
||||
weights = {
|
||||
replace_key(k): v.numpy().astype(dtype)
|
||||
for k, v in model.state_dict().items()
|
||||
replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items()
|
||||
}
|
||||
file_name = model_name.replace("/", "-")
|
||||
print(f"Saving weights to {file_name}.npz")
|
||||
|
Reference in New Issue
Block a user