Add llms subdir + update README (#145)

* add llms subdir + update README

* nits

* use same pre-commit as mlx

* update readmes a bit

* format
Awni Hannun
2023-12-20 10:22:25 -08:00
committed by GitHub
parent aed14618ca
commit 27c0a8c002
62 changed files with 164 additions and 146 deletions
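
The import sorting, joined comprehension, and quote changes in the hunks below are the kind of output the repo's pre-commit hooks produce; a minimal sketch of re-applying them locally, assuming the `pre-commit` package is installed and a `.pre-commit-config.yaml` is present at the repo root:

# Sketch only: re-run every configured pre-commit hook over the whole tree.
# Assumes `pre-commit` is installed; the specific hooks mirrored from mlx are
# defined in .pre-commit-config.yaml and are not reproduced here.
import subprocess

subprocess.run(["pre-commit", "run", "--all-files"], check=True)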

View File

@@ -1,6 +1,5 @@
-from transformers import T5ForConditionalGeneration
 import numpy as np
+from transformers import T5ForConditionalGeneration
 SHARED_REPLACEMENT_PATTERNS = [
     (".block.", ".layers."),
@@ -48,8 +47,7 @@ def convert(model_name, dtype):
     dtype = getattr(np, dtype)
     model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
     weights = {
-        replace_key(k): v.numpy().astype(dtype)
-        for k, v in model.state_dict().items()
+        replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items()
     }
     file_name = model_name.replace("/", "-")
     print(f"Saving weights to {file_name}.npz")
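
The hunk above only joins the dict comprehension onto one line; the script still writes a flat name-to-array mapping to an .npz file. A self-contained sketch of the round trip that flat layout enables on the MLX side (the toy module and file name are assumptions for illustration, not the actual T5 definition):

# Save a flat "parameter name -> array" dict to .npz (as convert.py does),
# then load it with mlx and unflatten it into a module's parameter tree.
# The Toy module and "toy.npz" name are illustrative assumptions.
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)


flat = {"proj.weight": mx.zeros((4, 4)), "proj.bias": mx.zeros((4,))}
mx.savez("toy.npz", **flat)
weights = mx.load("toy.npz")  # dict of parameter name -> mx.array
model = Toy()
model.update(tree_unflatten(list(weights.items())))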

View File

@@ -1,7 +1,7 @@
-from transformers import T5ForConditionalGeneration, T5EncoderModel, AutoTokenizer
 import argparse
+from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration
 def embed(t5_model: str):
     batch = [
@@ -51,4 +51,3 @@ if __name__ == "__main__":
         embed(args.model)
     else:
         generate(args.model)

View File

@@ -1,11 +1,11 @@
 import argparse
-from typing import Optional, Tuple, List
 from time import perf_counter_ns
+from typing import List, Optional, Tuple
-import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
-from mlx.utils import tree_unflatten, tree_map
+import numpy as np
+from mlx.utils import tree_map, tree_unflatten
 from transformers import T5Config, T5Tokenizer
@@ -166,7 +166,7 @@ class DenseActivation(nn.Module):
             self.act = nn.relu
         elif activation == "gelu":
             self.act = nn.gelu
-        elif activation == "silu":
+        elif activation == "silu":
             self.act = nn.silu
         else:
             raise ValueError(f"Unknown activation: {activation}")
@@ -337,7 +337,7 @@ class Tokenizer:
         self._tokenizer = T5Tokenizer.from_pretrained(
             args.model,
             legacy=False,
-            model_max_length=getattr(config, 'n_positions', 512)
+            model_max_length=getattr(config, "n_positions", 512),
         )
     @property
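
The reformatted call above keeps the getattr fallback for the maximum sequence length; a tiny self-contained illustration of that behavior (toy namespace objects, not a real T5Config):

# getattr with a default: use config.n_positions when present, else 512.
from types import SimpleNamespace

has_field = SimpleNamespace(n_positions=1024)
missing_field = SimpleNamespace()

print(getattr(has_field, "n_positions", 512))      # 1024
print(getattr(missing_field, "n_positions", 512))  # 512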