Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 01:17:28 +08:00.
Add llms subdir + update README (#145)
* add llms subdir + update README
* nits
* use same pre-commit as mlx
* update readmes a bit
* format
This commit is contained in:
parent aed14618ca
commit 27c0a8c002
.pre-commit-config.yaml

@@ -1,5 +1,11 @@
 repos:
-  - repo: https://github.com/psf/black
+  - repo: https://github.com/psf/black-pre-commit-mirror
     rev: 22.10.0
     hooks:
       - id: black
+  - repo: https://github.com/pycqa/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+        args:
+          - --profile=black
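Nearly every Python hunk below is mechanical fallout from the isort hook added here. As a rough sketch of what the `black` profile enforces (assuming isort 5.x is installed as a library; the `messy` snippet is a made-up input), imports are grouped stdlib / third-party / local and alphabetized within each group:

```python
# Illustration only: what the new isort hook does to a file's imports,
# assuming isort 5.x. isort.code() returns the source with imports
# regrouped (standard library, then third-party, then local) and
# alphabetized within each group, with a blank line between groups.
import isort

messy = (
    "from transformers import BertModel\n"
    "\n"
    "import argparse\n"
    "import numpy\n"
)
print(isort.code(messy, profile="black"))
# import argparse
#
# import numpy
# from transformers import BertModel
```

With the hook installed, `pre-commit run --all-files` reapplies black and isort across the repository, which is what produced the import reshuffles in the rest of this diff.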
README.md (29 changed lines)
@@ -5,18 +5,37 @@ framework](https://github.com/ml-explore/mlx).

 The [MNIST](mnist) example is a good starting point to learn how to use MLX.

-Some more useful examples include:
+Some more useful examples are listed below.
+
+### Text Models

 - [Transformer language model](transformer_lm) training.
-- Large scale text generation with [LLaMA](llama), [Mistral](mistral) or [Phi](phi2).
-- Mixture-of-experts (MoE) language model with [Mixtral 8x7B](mixtral)
+- Large scale text generation with [LLaMA](llms/llama),
+  [Mistral](llms/mistral), [Phi-2](llms/phi2), and more in the [LLMs](llms)
+  directory.
+- A mixture-of-experts (MoE) language model with [Mixtral 8x7B](llms/mixtral).
 - Parameter efficient fine-tuning with [LoRA](lora).
 - Text-to-text multi-task Transformers with [T5](t5).
-- Generating images with [Stable Diffusion](stable_diffusion).
-- Speech recognition with [OpenAI's Whisper](whisper).
-- Bidirectional language understanding with [BERT](bert)
+- Bidirectional language understanding with [BERT](bert).
+
+### Image Models
+
+- Generating images with [Stable Diffusion](stable_diffusion).
+
+### Audio Models
+
+- Speech recognition with [OpenAI's Whisper](whisper).
+
+### Other Models
+
 - Semi-supervised learning on graph-structured data with [GCN](gcn).

-Note: You can now directly download a few converted checkpoints from the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face.
+### Hugging Face
+
+Note: You can now directly download a few converted checkpoints from the [MLX
+Community](https://huggingface.co/mlx-community) organization on Hugging Face.
+We encourage you to join the community and [contribute new
+models](https://github.com/ml-explore/mlx-examples/issues/155).

 ## Contributing
@@ -1,7 +1,7 @@
-from transformers import BertModel
-
 import argparse
+
 import numpy
+from transformers import BertModel


 def replace_key(key: str) -> str:
@@ -1,7 +1,7 @@
-from transformers import AutoModel, AutoTokenizer
-
 import argparse

+from transformers import AutoModel, AutoTokenizer
+

 def run(bert_model: str):
     batch = [
@@ -1,13 +1,13 @@
-import numpy as np
-from typing import Optional
 import argparse
 from dataclasses import dataclass
-from transformers import BertTokenizer
-from mlx.utils import tree_unflatten
+from typing import Optional
+
 import mlx.core as mx
 import mlx.nn as nn
-import argparse
 import numpy
+import numpy as np
+from mlx.utils import tree_unflatten
+from transformers import BertTokenizer


 @dataclass
@@ -1,6 +1,7 @@
+import math
+
 import mlx.core as mx
 from mlx.data.datasets import load_cifar10
-import math


 def get_cifar10(batch_size, root=None):
@@ -1,11 +1,11 @@
 import argparse
 import time
-import resnet
-import mlx.nn as nn
+
 import mlx.core as mx
+import mlx.nn as nn
 import mlx.optimizers as optim
+import resnet
 from dataset import get_cifar10

 parser = argparse.ArgumentParser(add_help=True)
 parser.add_argument(
@@ -6,11 +6,11 @@ There's no BatchNorm is mlx==0.0.4, using LayerNorm instead.
 """

 from typing import Any
+
 import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_flatten
-

 __all__ = [
     "ResNet",
     "resnet20",
@@ -1,9 +1,9 @@
 import os
-import requests
 import tarfile

 import mlx.core as mx
 import numpy as np
+import requests
 import scipy.sparse as sparse

 """
@@ -3,10 +3,10 @@ from argparse import ArgumentParser
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-from datasets import download_cora, load_data, train_val_test_mask
 from mlx.nn.losses import cross_entropy
 from mlx.utils import tree_flatten

+from datasets import download_cora, load_data, train_val_test_mask
 from gcn import GCN

@@ -20,8 +20,9 @@ weights you will need to request access from Meta:
 - [Request Llama v1](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
 - [Request Llama v2](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)

 > [!TIP]
-> Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step.
+> Alternatively, you can also download a few converted checkpoints from
+> the [MLX Community](https://huggingface.co/mlx-community) organization on
+> Hugging Face and skip the conversion step.

 You can download the TinyLlama models directly from [Hugging
 Face](https://huggingface.co/TinyLlama).
@@ -4,8 +4,9 @@ import argparse
 import collections
 import glob
 import json
-import numpy as np
 from pathlib import Path

+import numpy as np
 import torch

@@ -1,16 +1,16 @@
 # Copyright © 2023 Apple Inc.

 import argparse
-from dataclasses import dataclass
 import json
-from pathlib import Path
-from typing import Optional, Tuple, List
-from sentencepiece import SentencePieceProcessor
 import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_unflatten
+from sentencepiece import SentencePieceProcessor


 @dataclass
@@ -29,7 +29,10 @@ python convert.py
 The conversion script will save the converted weights in the same location.

 > [!TIP]
-> Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step.
+> Alternatively, you can also download a few converted checkpoints from the
+> [MLX Community](https://huggingface.co/mlx-community) organization on Hugging
+> Face and skip the conversion step.


 ### Run
@@ -2,10 +2,10 @@

 import argparse
 import json
-import numpy as np
 from pathlib import Path
-import torch
+
+import numpy as np
+import torch

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
@@ -1,15 +1,15 @@
 # Copyright © 2023 Apple Inc.

 import argparse
-from dataclasses import dataclass
 import json
+from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Tuple, List
-from sentencepiece import SentencePieceProcessor
+from typing import List, Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_map, tree_unflatten
+from sentencepiece import SentencePieceProcessor


 @dataclass
@@ -2,11 +2,10 @@

 import unittest

+import mistral
 import mlx.core as mx
 from mlx.utils import tree_map

-import mistral
-

 class TestMistral(unittest.TestCase):
     def test_model(self):
@@ -3,8 +3,9 @@
 import argparse
 import glob
 import json
-import numpy as np
 from pathlib import Path

+import numpy as np
 import torch

@@ -45,7 +46,7 @@ if __name__ == "__main__":
         args = json.load(fid)
     args["model_type"] = "mixtral"
     with open(model_path / "config.json", "w") as f:
-      json.dump(args, f, indent=4)
+        json.dump(args, f, indent=4)

     torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
     torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))
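For context on the hunk above (the flattened diff makes the exact before-indent hard to read, so the re-indent shown is a best guess): the dump must run inside the `with` block, while the handle is open, and `indent=4` is what keeps the written `config.json` human-readable. A minimal sketch; the `vocab_size` entry is made up for illustration:

```python
import json

args = {"model_type": "mixtral", "vocab_size": 32000}  # vocab_size: hypothetical

# Write inside the with-block, while f is still open; indent=4
# pretty-prints the JSON instead of emitting one long line.
with open("config.json", "w") as f:
    json.dump(args, f, indent=4)
```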
@@ -1,17 +1,17 @@
 # Copyright © 2023 Apple Inc.

 import argparse
-from dataclasses import dataclass
 import glob
 import json
-import numpy as np
+from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Tuple, List
-from sentencepiece import SentencePieceProcessor
+from typing import List, Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn
+import numpy as np
 from mlx.utils import tree_map, tree_unflatten
+from sentencepiece import SentencePieceProcessor


 @dataclass
phi2/.gitignore → llms/phi2/.gitignore (vendored, 0 changed lines)
@@ -17,8 +17,10 @@ python convert.py

 This will make the `weights.npz` file which MLX can read.

 > [!TIP]
-> Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step.
+> Alternatively, you can also download a few converted checkpoints from
+> the [MLX Community](https://huggingface.co/mlx-community) organization on
+> Hugging Face and skip the conversion step.


 ## Generate
@@ -1,5 +1,5 @@
-from transformers import AutoModelForCausalLM
 import numpy as np
+from transformers import AutoModelForCausalLM


 def replace_key(key: str) -> str:
@@ -1,13 +1,13 @@
 import argparse
+import math
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional
-from mlx.utils import tree_unflatten
-from transformers import AutoTokenizer

 import mlx.core as mx
 import mlx.nn as nn
-import math
+from mlx.utils import tree_unflatten
+from transformers import AutoTokenizer


 @dataclass
@@ -167,7 +167,7 @@ def load_model(model_path: str):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Phi-2 inference script")
     parser.add_argument(
-        "--model_path",
+        "--model-path",
         type=str,
         default="phi-2",
         help="The path to the model weights",
@@ -178,7 +178,7 @@ if __name__ == "__main__":
         default="Write a detailed analogy between mathematics and a lighthouse.",
     )
     parser.add_argument(
-        "--max_tokens",
+        "--max-tokens",
         "-m",
         type=int,
         default=100,
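The `--model_path` → `--model-path` and `--max_tokens` → `--max-tokens` renames above change only the command-line spelling: argparse derives the attribute name by replacing hyphens with underscores, so the rest of the script still reads `args.model_path` and `args.max_tokens` unchanged. A minimal sketch:

```python
import argparse

parser = argparse.ArgumentParser(description="flag-spelling demo")
# Hyphenated long option; argparse stores the parsed value under a
# dest with underscores, so no other code has to change.
parser.add_argument("--max-tokens", "-m", type=int, default=100)

args = parser.parse_args(["--max-tokens", "50"])
print(args.max_tokens)  # 50
```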
@@ -1,8 +1,9 @@
 import argparse
-from transformers import AutoModelForCausalLM
+import json

 import numpy as np
 import torch
-import json
+from transformers import AutoModelForCausalLM


 def replace_key(key: str) -> str:
@@ -1,6 +1,7 @@
 import argparse
-from dataclasses import dataclass
 import json
+from dataclasses import dataclass
+
 import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_unflatten
@@ -2,12 +2,12 @@

 import argparse
 import json
-import numpy as np
-from pathlib import Path
-import shutil
 import os
-import torch
+import shutil
+from pathlib import Path
+
+import numpy as np
+import torch

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
@@ -32,8 +32,8 @@ class WikiSQL:
     def _maybe_download(self, data_dir):
         if not os.path.exists(data_dir):
             import io
-            from urllib import request
             import tarfile
+            from urllib import request

             url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2"
             r = request.urlopen(url)
lora/lora.py (14 changed lines)
@@ -3,19 +3,17 @@
 import argparse
 import json
 import math
-import numpy as np
-from pathlib import Path
-from sentencepiece import SentencePieceProcessor
 import time
-from typing import Optional, Tuple, List
+from pathlib import Path
+from typing import List, Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-from mlx.utils import tree_map, tree_flatten, tree_unflatten
-
-
-from models import ModelArgs, Model, LoRALinear
+import numpy as np
+from mlx.utils import tree_flatten, tree_map, tree_unflatten
+from models import LoRALinear, Model, ModelArgs
+from sentencepiece import SentencePieceProcessor


 def build_parser():
@@ -1,8 +1,8 @@
 # Copyright © 2023 Apple Inc.

-from dataclasses import dataclass
 import math
-from typing import Optional, Tuple, List
+from dataclasses import dataclass
+from typing import List, Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn
@@ -1,9 +1,10 @@
 # Copyright © 2023 Apple Inc.

+import functools
+import time
+
 import jax
 import jax.numpy as jnp
-import functools
-import time

 import mnist

@@ -3,11 +3,10 @@
 import argparse
 import time

-import numpy as np
-
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
+import numpy as np

 import mnist

@@ -1,11 +1,12 @@
 # Copyright © 2023 Apple Inc.

 import gzip
-import numpy as np
 import os
 import pickle
 from urllib import request

+import numpy as np
+

 def mnist(save_dir="/tmp"):
     """
@@ -1,9 +1,10 @@
 # Copyright © 2023 Apple Inc.

 import argparse
-import torch
 import time

+import torch
+
 import mnist

@@ -1,9 +1,9 @@
 from typing import Any
+
 import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_flatten
-

 __all__ = ["KWT", "kwt1", "kwt2", "kwt3"]

@@ -1,13 +1,13 @@
 import argparse
 import time
-import kwt
-import mlx.nn as nn
-import mlx.data as dx
-import mlx.core as mx
+
+import kwt
+import mlx.core as mx
+import mlx.data as dx
+import mlx.nn as nn
 import mlx.optimizers as optim
-from mlx.data.features import mfsc
 from mlx.data.datasets import load_speechcommands
+from mlx.data.features import mfsc

 parser = argparse.ArgumentParser(add_help=True)
 parser.add_argument(
@@ -6,12 +6,12 @@ from typing import Tuple
 import mlx.core as mx

 from .model_io import (
-    load_unet,
-    load_text_encoder,
+    _DEFAULT_MODEL,
     load_autoencoder,
     load_diffusion_config,
+    load_text_encoder,
     load_tokenizer,
-    _DEFAULT_MODEL,
+    load_unet,
 )
 from .sampler import SimpleEulerSampler
@@ -3,20 +3,18 @@
 import json
 from functools import partial

+import mlx.core as mx
 import numpy as np
 from huggingface_hub import hf_hub_download
+from mlx.utils import tree_unflatten
 from safetensors import safe_open as safetensor_open

-import mlx.core as mx
-from mlx.utils import tree_unflatten
-
 from .clip import CLIPTextModel
-from .config import UNetConfig, CLIPTextModelConfig, AutoencoderConfig, DiffusionConfig
+from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
 from .tokenizer import Tokenizer
 from .unet import UNetModel
 from .vae import Autoencoder


 _DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"
 _MODELS = {
     # See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license
@@ -1,9 +1,9 @@
 # Copyright © 2023 Apple Inc.

-from .config import DiffusionConfig
-
 import mlx.core as mx

+from .config import DiffusionConfig
+

 def _linspace(a, b, num):
     x = mx.arange(0, num) / (num - 1)
@@ -1,14 +1,13 @@
 # Copyright © 2023 Apple Inc.

 import argparse

+import mlx.core as mx
 from PIL import Image
 from tqdm import tqdm

-import mlx.core as mx
-
 from stable_diffusion import StableDiffusion


 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Generate images from a textual prompt using stable diffusion"
@@ -1,6 +1,5 @@
-from transformers import T5ForConditionalGeneration
 import numpy as np
-
+from transformers import T5ForConditionalGeneration

 SHARED_REPLACEMENT_PATTERNS = [
     (".block.", ".layers."),
@@ -48,8 +47,7 @@ def convert(model_name, dtype):
     dtype = getattr(np, dtype)
     model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
     weights = {
-        replace_key(k): v.numpy().astype(dtype)
-        for k, v in model.state_dict().items()
+        replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items()
     }
     file_name = model_name.replace("/", "-")
     print(f"Saving weights to {file_name}.npz")
@@ -1,7 +1,7 @@
-from transformers import T5ForConditionalGeneration, T5EncoderModel, AutoTokenizer
-
 import argparse

+from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration
+

 def embed(t5_model: str):
     batch = [
@@ -51,4 +51,3 @@ if __name__ == "__main__":
         embed(args.model)
     else:
         generate(args.model)
-
t5/t5.py (10 changed lines)
@@ -1,11 +1,11 @@
 import argparse
-from typing import Optional, Tuple, List
 from time import perf_counter_ns
+from typing import List, Optional, Tuple

-import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
-from mlx.utils import tree_unflatten, tree_map
+import numpy as np
+from mlx.utils import tree_map, tree_unflatten
 from transformers import T5Config, T5Tokenizer

@@ -166,7 +166,7 @@ class DenseActivation(nn.Module):
             self.act = nn.relu
         elif activation == "gelu":
             self.act = nn.gelu
-        elif activation == 'silu':
+        elif activation == "silu":
             self.act = nn.silu
         else:
             raise ValueError(f"Unknown activation: {activation}")
@@ -337,7 +337,7 @@ class Tokenizer:
         self._tokenizer = T5Tokenizer.from_pretrained(
             args.model,
             legacy=False,
-            model_max_length=getattr(config, 'n_positions', 512)
+            model_max_length=getattr(config, "n_positions", 512),
         )

     @property
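The two t5/t5.py hunks above appear to be black's standard normalization: double quotes for strings and a trailing comma on the last argument of a multi-line call. A sketch of the quote normalization, assuming black is installed as a library:

```python
# Illustration of black's quote normalization (pip install black).
import black

src = "x = getattr(config, 'n_positions', 512)\n"
print(black.format_str(src, mode=black.Mode()), end="")
# x = getattr(config, "n_positions", 512)
```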
@@ -2,10 +2,11 @@

 import io
 import itertools
-import numpy as np
 import os
-from urllib import request
 import zipfile
+from urllib import request
+
+import numpy as np


 def load_dataset(dataname):
@@ -1,14 +1,14 @@
 # Copyright © 2023 Apple Inc.

 import functools
-import jax
-import jax.numpy as jnp
 import math
-import numpy as np
 import time
 from collections import namedtuple

 import datasets
+import jax
+import jax.numpy as jnp
+import numpy as np
 from tree_utils import tree_flatten

 """
@@ -3,15 +3,13 @@
 import math
 import time

-import numpy as np
-
+import datasets
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
+import numpy as np
 from mlx.utils import tree_flatten

-import datasets
-

 class TransformerLM(nn.Module):
     def __init__(self, vocab_size: int, num_layers: int, dims: int, num_heads: int):
@@ -3,11 +3,10 @@
 import math
 import time

+import datasets
 import numpy as np
 import tensorflow as tf

-import datasets
-

 def to_samples(context_size, dataset):
     tokens = dataset.size
@@ -3,11 +3,10 @@
 import math
 import time

+import datasets
 import numpy as np
 import torch

-import datasets
-

 def to_samples(context_size, dataset):
     tokens = dataset.size
@@ -5,10 +5,7 @@ import time

 import mlx.core as mx

-from whisper import load_models
-from whisper import audio
-from whisper import decoding
-from whisper import transcribe
+from whisper import audio, decoding, load_models, transcribe

 audio_file = "whisper/assets/ls_test.flac"

@@ -1,19 +1,18 @@
 # Copyright © 2023 Apple Inc.

+import os
+import subprocess
 import unittest

 import mlx.core as mx
 import numpy as np
-import os
-import subprocess
 import torch

 import whisper
 import whisper.audio as audio
+import whisper.decoding as decoding
 import whisper.load_models as load_models
 import whisper.torch_whisper as torch_whisper
-import whisper.decoding as decoding


 TEST_AUDIO = "whisper/assets/ls_test.flac"

@@ -1,6 +1,4 @@
 # Copyright © 2023 Apple Inc.

-from . import load_models
-from . import audio
-from . import decoding
+from . import audio, decoding, load_models
 from .transcribe import transcribe
@@ -1,12 +1,13 @@
 # Copyright © 2023 Apple Inc.

+import zlib
 from dataclasses import dataclass, field, replace
 from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union

 import mlx.core as mx
-from mlx.utils import tree_map
 import mlx.nn as nn
 import numpy as np
-import zlib
+from mlx.utils import tree_map

 from .audio import CHUNK_LENGTH
 from .tokenizer import Tokenizer, get_tokenizer
@@ -7,13 +7,11 @@ import warnings
 from typing import List

 import mlx.core as mx
-from mlx.utils import tree_map
-
 import torch
+from mlx.utils import tree_map
 from tqdm import tqdm

-from . import whisper
-from . import torch_whisper
+from . import torch_whisper, whisper

 _MODELS = {
     "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
@@ -1,12 +1,12 @@
 # Copyright © 2023 Apple Inc.

+import sys
+from typing import Optional, Tuple, Union
+
 import mlx.core as mx
 import numpy as np
-import sys
-from typing import Optional, Tuple, Union
 import tqdm

-
 from .audio import (
     FRAMES_PER_SECOND,
     HOP_LENGTH,
@@ -2,13 +2,12 @@

 import base64
 import gzip
-from dataclasses import dataclass
 import math
+from dataclasses import dataclass
 from typing import Dict, Iterable, Optional

 import mlx.core as mx
 import mlx.nn as nn
-
 import numpy as np

 from .decoding import decode as decode_function