mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 20:04:38 +08:00
format
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
@@ -2,11 +2,10 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import mistral
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_map
|
||||
|
||||
import mistral
|
||||
|
||||
|
||||
class TestMistral(unittest.TestCase):
|
||||
def test_model(self):
|
||||
|
@@ -46,7 +46,7 @@ if __name__ == "__main__":
|
||||
args = json.load(fid)
|
||||
args["model_type"] = "mixtral"
|
||||
with open(model_path / "config.json", "w") as f:
|
||||
json.dump(args, f, indent=4)
|
||||
json.dump(args, f, indent=4)
|
||||
|
||||
torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
|
||||
torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
@@ -1,8 +1,9 @@
|
||||
import argparse
|
||||
from transformers import AutoModelForCausalLM
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import json
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
def replace_key(key: str) -> str:
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_unflatten
|
||||
|
@@ -1,9 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
|
||||
__all__ = ["KWT", "kwt1", "kwt2", "kwt3"]
|
||||
|
||||
|
||||
|
@@ -1,13 +1,13 @@
|
||||
import argparse
|
||||
import time
|
||||
import kwt
|
||||
import mlx.nn as nn
|
||||
import mlx.data as dx
|
||||
import mlx.core as mx
|
||||
import mlx.optimizers as optim
|
||||
from mlx.data.features import mfsc
|
||||
from mlx.data.datasets import load_speechcommands
|
||||
|
||||
import kwt
|
||||
import mlx.core as mx
|
||||
import mlx.data as dx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
from mlx.data.datasets import load_speechcommands
|
||||
from mlx.data.features import mfsc
|
||||
|
||||
parser = argparse.ArgumentParser(add_help=True)
|
||||
parser.add_argument(
|
||||
|
@@ -47,8 +47,7 @@ def convert(model_name, dtype):
|
||||
dtype = getattr(np, dtype)
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
|
||||
weights = {
|
||||
replace_key(k): v.numpy().astype(dtype)
|
||||
for k, v in model.state_dict().items()
|
||||
replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items()
|
||||
}
|
||||
file_name = model_name.replace("/", "-")
|
||||
print(f"Saving weights to {file_name}.npz")
|
||||
|
Reference in New Issue
Block a user