This commit is contained in:
Awni Hannun
2023-12-20 10:14:01 -08:00
parent 7b5787e62c
commit 03824e3477
10 changed files with 18 additions and 19 deletions

View File

@@ -2,7 +2,6 @@
import argparse
import json
import numpy as np
from pathlib import Path
import numpy as np

View File

@@ -2,11 +2,10 @@
import unittest
import mistral
import mlx.core as mx
from mlx.utils import tree_map
import mistral
class TestMistral(unittest.TestCase):
def test_model(self):

View File

@@ -46,7 +46,7 @@ if __name__ == "__main__":
args = json.load(fid)
args["model_type"] = "mixtral"
with open(model_path / "config.json", "w") as f:
json.dump(args, f, indent=4)
json.dump(args, f, indent=4)
torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))

View File

@@ -1,7 +1,7 @@
import argparse
from pathlib import Path
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import mlx.core as mx

View File

@@ -1,8 +1,9 @@
import argparse
from transformers import AutoModelForCausalLM
import json
import numpy as np
import torch
import json
from transformers import AutoModelForCausalLM
def replace_key(key: str) -> str:

View File

@@ -1,6 +1,7 @@
import argparse
from dataclasses import dataclass
import json
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten

View File

@@ -1,9 +1,9 @@
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten
__all__ = ["KWT", "kwt1", "kwt2", "kwt3"]

View File

@@ -1,13 +1,13 @@
import argparse
import time
import kwt
import mlx.nn as nn
import mlx.data as dx
import mlx.core as mx
import mlx.optimizers as optim
from mlx.data.features import mfsc
from mlx.data.datasets import load_speechcommands
import kwt
import mlx.core as mx
import mlx.data as dx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.data.datasets import load_speechcommands
from mlx.data.features import mfsc
parser = argparse.ArgumentParser(add_help=True)
parser.add_argument(

View File

@@ -47,8 +47,7 @@ def convert(model_name, dtype):
dtype = getattr(np, dtype)
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
weights = {
replace_key(k): v.numpy().astype(dtype)
for k, v in model.state_dict().items()
replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items()
}
file_name = model_name.replace("/", "-")
print(f"Saving weights to {file_name}.npz")

View File

@@ -337,7 +337,7 @@ class Tokenizer:
self._tokenizer = T5Tokenizer.from_pretrained(
args.model,
legacy=False,
model_max_length=getattr(config, 'n_positions', 512)
model_max_length=getattr(config, "n_positions", 512),
)
@property