mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-10-23 14:08:07 +08:00 
			
		
		
		
	Add llms subdir + update README (#145)
* add llms subdir + update README * nits * use same pre-commit as mlx * update readmes a bit * format
This commit is contained in:
		| @@ -1,5 +1,11 @@ | ||||
| repos: | ||||
| -   repo: https://github.com/psf/black | ||||
| -   repo: https://github.com/psf/black-pre-commit-mirror | ||||
|     rev: 22.10.0 | ||||
|     hooks: | ||||
|     -   id: black | ||||
| -   repo: https://github.com/pycqa/isort | ||||
|     rev: 5.12.0 | ||||
|     hooks: | ||||
|     -   id: isort | ||||
|         args: | ||||
|             - --profile=black | ||||
|   | ||||
							
								
								
									
										29
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										29
									
								
								README.md
									
									
									
									
									
								
							| @@ -5,18 +5,37 @@ framework](https://github.com/ml-explore/mlx). | ||||
|  | ||||
| The [MNIST](mnist) example is a good starting point to learn how to use MLX. | ||||
|  | ||||
| Some more useful examples include: | ||||
| Some more useful examples are listed below. | ||||
|  | ||||
| ### Text Models  | ||||
|  | ||||
| - [Transformer language model](transformer_lm) training. | ||||
| - Large scale text generation with [LLaMA](llama), [Mistral](mistral) or [Phi](phi2).   | ||||
| - Mixture-of-experts (MoE) language model with [Mixtral 8x7B](mixtral) | ||||
| - Large scale text generation with [LLaMA](llms/llama), | ||||
|   [Mistral](llms/mistral), [Phi-2](llms/phi2), and more in the [LLMs](llms) | ||||
|   directory. | ||||
| - A mixture-of-experts (MoE) language model with [Mixtral 8x7B](llms/mixtral). | ||||
| - Parameter efficient fine-tuning with [LoRA](lora). | ||||
| - Text-to-text multi-task Transformers with [T5](t5). | ||||
| - Bidirectional language understanding with [BERT](bert). | ||||
|  | ||||
| ### Image Models  | ||||
|  | ||||
| - Generating images with [Stable Diffusion](stable_diffusion). | ||||
|  | ||||
| ### Audio Models | ||||
|  | ||||
| - Speech recognition with [OpenAI's Whisper](whisper). | ||||
| - Bidirectional language understanding with [BERT](bert) | ||||
|  | ||||
| ### Other Models  | ||||
|  | ||||
| - Semi-supervised learning on graph-structured data with [GCN](gcn). | ||||
|  | ||||
| Note: You can now directly download a few converted checkpoints from the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face. | ||||
| ### Hugging Face | ||||
|  | ||||
| Note: You can now directly download a few converted checkpoints from the [MLX | ||||
| Community](https://huggingface.co/mlx-community) organization on Hugging Face. | ||||
| We encourage you to join the community and [contribute new | ||||
| models](https://github.com/ml-explore/mlx-examples/issues/155). | ||||
|  | ||||
| ## Contributing  | ||||
|  | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| from transformers import BertModel | ||||
|  | ||||
| import argparse | ||||
|  | ||||
| import numpy | ||||
| from transformers import BertModel | ||||
|  | ||||
|  | ||||
| def replace_key(key: str) -> str: | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| from transformers import AutoModel, AutoTokenizer | ||||
|  | ||||
| import argparse | ||||
|  | ||||
| from transformers import AutoModel, AutoTokenizer | ||||
|  | ||||
|  | ||||
| def run(bert_model: str): | ||||
|     batch = [ | ||||
|   | ||||
| @@ -1,13 +1,13 @@ | ||||
| import numpy as np | ||||
| from typing import Optional | ||||
| import argparse | ||||
| from dataclasses import dataclass | ||||
| from transformers import BertTokenizer | ||||
| from mlx.utils import tree_unflatten | ||||
| from typing import Optional | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| import argparse | ||||
| import numpy | ||||
| import numpy as np | ||||
| from mlx.utils import tree_unflatten | ||||
| from transformers import BertTokenizer | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| import math | ||||
|  | ||||
| import mlx.core as mx | ||||
| from mlx.data.datasets import load_cifar10 | ||||
| import math | ||||
|  | ||||
|  | ||||
| def get_cifar10(batch_size, root=None): | ||||
|   | ||||
| @@ -1,11 +1,11 @@ | ||||
| import argparse | ||||
| import time | ||||
| import resnet | ||||
| import mlx.nn as nn | ||||
| import mlx.core as mx | ||||
| import mlx.optimizers as optim | ||||
| from dataset import get_cifar10 | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| import mlx.optimizers as optim | ||||
| import resnet | ||||
| from dataset import get_cifar10 | ||||
|  | ||||
| parser = argparse.ArgumentParser(add_help=True) | ||||
| parser.add_argument( | ||||
|   | ||||
| @@ -6,11 +6,11 @@ There's no BatchNorm is mlx==0.0.4, using LayerNorm instead. | ||||
| """ | ||||
|  | ||||
| from typing import Any | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| from mlx.utils import tree_flatten | ||||
|  | ||||
|  | ||||
| __all__ = [ | ||||
|     "ResNet", | ||||
|     "resnet20", | ||||
|   | ||||
| @@ -1,9 +1,9 @@ | ||||
| import os | ||||
| import requests | ||||
| import tarfile | ||||
|  | ||||
| import mlx.core as mx | ||||
| import numpy as np | ||||
| import requests | ||||
| import scipy.sparse as sparse | ||||
|  | ||||
| """ | ||||
|   | ||||
| @@ -3,10 +3,10 @@ from argparse import ArgumentParser | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| import mlx.optimizers as optim | ||||
| from datasets import download_cora, load_data, train_val_test_mask | ||||
| from mlx.nn.losses import cross_entropy | ||||
| from mlx.utils import tree_flatten | ||||
|  | ||||
| from datasets import download_cora, load_data, train_val_test_mask | ||||
| from gcn import GCN | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -20,8 +20,9 @@ weights you will need to request access from Meta: | ||||
| - [Request Llama v1](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) | ||||
| - [Request Llama v2](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) | ||||
| 
 | ||||
| > [!TIP] | ||||
| > Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step. | ||||
| > [!TIP] Alternatively, you can also download a few converted checkpoints from | ||||
| > the [MLX Community](https://huggingface.co/mlx-community) organization on | ||||
| > Hugging Face and skip the conversion step. | ||||
| 
 | ||||
| You can download the TinyLlama models directly from [Hugging | ||||
| Face](https://huggingface.co/TinyLlama). | ||||
| @@ -4,8 +4,9 @@ import argparse | ||||
| import collections | ||||
| import glob | ||||
| import json | ||||
| import numpy as np | ||||
| from pathlib import Path | ||||
| 
 | ||||
| import numpy as np | ||||
| import torch | ||||
| 
 | ||||
| 
 | ||||
| @@ -1,16 +1,16 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
| 
 | ||||
| import argparse | ||||
| from dataclasses import dataclass | ||||
| import json | ||||
| from pathlib import Path | ||||
| from typing import Optional, Tuple, List | ||||
| from sentencepiece import SentencePieceProcessor | ||||
| import time | ||||
| from dataclasses import dataclass | ||||
| from pathlib import Path | ||||
| from typing import List, Optional, Tuple | ||||
| 
 | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| from mlx.utils import tree_unflatten | ||||
| from sentencepiece import SentencePieceProcessor | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
| @@ -29,7 +29,10 @@ python convert.py | ||||
| The conversion script will save the converted weights in the same location. | ||||
| 
 | ||||
| > [!TIP] | ||||
| > Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step. | ||||
| > Alternatively, you can also download a few converted checkpoints from the | ||||
| > [MLX Community](https://huggingface.co/mlx-community) organization on Hugging | ||||
| > Face and skip the conversion step. | ||||
| 
 | ||||
| 
 | ||||
| ### Run | ||||
| 
 | ||||
| @@ -2,10 +2,10 @@ | ||||
| 
 | ||||
| import argparse | ||||
| import json | ||||
| import numpy as np | ||||
| from pathlib import Path | ||||
| import torch | ||||
| 
 | ||||
| import numpy as np | ||||
| import torch | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.") | ||||
| @@ -1,15 +1,15 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
| 
 | ||||
| import argparse | ||||
| from dataclasses import dataclass | ||||
| import json | ||||
| from dataclasses import dataclass | ||||
| from pathlib import Path | ||||
| from typing import Optional, Tuple, List | ||||
| from sentencepiece import SentencePieceProcessor | ||||
| from typing import List, Optional, Tuple | ||||
| 
 | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| from mlx.utils import tree_map, tree_unflatten | ||||
| from sentencepiece import SentencePieceProcessor | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
| @@ -2,11 +2,10 @@ | ||||
| 
 | ||||
| import unittest | ||||
| 
 | ||||
| import mistral | ||||
| import mlx.core as mx | ||||
| from mlx.utils import tree_map | ||||
| 
 | ||||
| import mistral | ||||
| 
 | ||||
| 
 | ||||
| class TestMistral(unittest.TestCase): | ||||
|     def test_model(self): | ||||
| @@ -3,8 +3,9 @@ | ||||
| import argparse | ||||
| import glob | ||||
| import json | ||||
| import numpy as np | ||||
| from pathlib import Path | ||||
| 
 | ||||
| import numpy as np | ||||
| import torch | ||||
| 
 | ||||
| 
 | ||||
| @@ -1,17 +1,17 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
| 
 | ||||
| import argparse | ||||
| from dataclasses import dataclass | ||||
| import glob | ||||
| import json | ||||
| import numpy as np | ||||
| from dataclasses import dataclass | ||||
| from pathlib import Path | ||||
| from typing import Optional, Tuple, List | ||||
| from sentencepiece import SentencePieceProcessor | ||||
| from typing import List, Optional, Tuple | ||||
| 
 | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| import numpy as np | ||||
| from mlx.utils import tree_map, tree_unflatten | ||||
| from sentencepiece import SentencePieceProcessor | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
							
								
								
									
										0
									
								
								phi2/.gitignore → llms/phi2/.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										0
									
								
								phi2/.gitignore → llms/phi2/.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -17,8 +17,10 @@ python convert.py | ||||
| 
 | ||||
| This will make the `weights.npz` file which MLX can read. | ||||
| 
 | ||||
| > [!TIP] | ||||
| > Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step. | ||||
| > [!TIP] Alternatively, you can also download a few converted checkpoints from | ||||
| > the [MLX Community](https://huggingface.co/mlx-community) organization on | ||||
| > Hugging Face and skip the conversion step. | ||||
| 
 | ||||
| 
 | ||||
| ## Generate  | ||||
| 
 | ||||
| @@ -1,5 +1,5 @@ | ||||
| from transformers import AutoModelForCausalLM | ||||
| import numpy as np | ||||
| from transformers import AutoModelForCausalLM | ||||
| 
 | ||||
| 
 | ||||
| def replace_key(key: str) -> str: | ||||
| @@ -1,13 +1,13 @@ | ||||
| import argparse | ||||
| import math | ||||
| from dataclasses import dataclass | ||||
| from pathlib import Path | ||||
| from typing import Optional | ||||
| from dataclasses import dataclass | ||||
| from mlx.utils import tree_unflatten | ||||
| from transformers import AutoTokenizer | ||||
| 
 | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| import math | ||||
| from mlx.utils import tree_unflatten | ||||
| from transformers import AutoTokenizer | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
| @@ -167,7 +167,7 @@ def load_model(model_path: str): | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser(description="Phi-2 inference script") | ||||
|     parser.add_argument( | ||||
|         "--model_path", | ||||
|         "--model-path", | ||||
|         type=str, | ||||
|         default="phi-2", | ||||
|         help="The path to the model weights", | ||||
| @@ -178,7 +178,7 @@ if __name__ == "__main__": | ||||
|         default="Write a detailed analogy between mathematics and a lighthouse.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--max_tokens", | ||||
|         "--max-tokens", | ||||
|         "-m", | ||||
|         type=int, | ||||
|         default=100, | ||||
| @@ -1,8 +1,9 @@ | ||||
| import argparse | ||||
| from transformers import AutoModelForCausalLM | ||||
| import json | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
| import json | ||||
| from transformers import AutoModelForCausalLM | ||||
|  | ||||
|  | ||||
| def replace_key(key: str) -> str: | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| import argparse | ||||
| from dataclasses import dataclass | ||||
| import json | ||||
| from dataclasses import dataclass | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| from mlx.utils import tree_unflatten | ||||
|   | ||||
| @@ -2,12 +2,12 @@ | ||||
|  | ||||
| import argparse | ||||
| import json | ||||
| import numpy as np | ||||
| from pathlib import Path | ||||
| import shutil | ||||
| import os | ||||
| import torch | ||||
| import shutil | ||||
| from pathlib import Path | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|   | ||||
| @@ -32,8 +32,8 @@ class WikiSQL: | ||||
|     def _maybe_download(self, data_dir): | ||||
|         if not os.path.exists(data_dir): | ||||
|             import io | ||||
|             from urllib import request | ||||
|             import tarfile | ||||
|             from urllib import request | ||||
|  | ||||
|             url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2" | ||||
|             r = request.urlopen(url) | ||||
|   | ||||
							
								
								
									
										14
									
								
								lora/lora.py
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								lora/lora.py
									
									
									
									
									
								
							| @@ -3,19 +3,17 @@ | ||||
| import argparse | ||||
| import json | ||||
| import math | ||||
| import numpy as np | ||||
| from pathlib import Path | ||||
| from sentencepiece import SentencePieceProcessor | ||||
| import time | ||||
| from typing import Optional, Tuple, List | ||||
| from pathlib import Path | ||||
| from typing import List, Optional, Tuple | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| import mlx.optimizers as optim | ||||
| from mlx.utils import tree_map, tree_flatten, tree_unflatten | ||||
|  | ||||
|  | ||||
| from models import ModelArgs, Model, LoRALinear | ||||
| import numpy as np | ||||
| from mlx.utils import tree_flatten, tree_map, tree_unflatten | ||||
| from models import LoRALinear, Model, ModelArgs | ||||
| from sentencepiece import SentencePieceProcessor | ||||
|  | ||||
|  | ||||
| def build_parser(): | ||||
|   | ||||
| @@ -1,8 +1,8 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| from dataclasses import dataclass | ||||
| import math | ||||
| from typing import Optional, Tuple, List | ||||
| from dataclasses import dataclass | ||||
| from typing import List, Optional, Tuple | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
|   | ||||
| @@ -1,9 +1,10 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| import functools | ||||
| import time | ||||
|  | ||||
| import jax | ||||
| import jax.numpy as jnp | ||||
| import functools | ||||
| import time | ||||
|  | ||||
| import mnist | ||||
|  | ||||
|   | ||||
| @@ -3,11 +3,10 @@ | ||||
| import argparse | ||||
| import time | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| import mlx.optimizers as optim | ||||
| import numpy as np | ||||
|  | ||||
| import mnist | ||||
|  | ||||
|   | ||||
| @@ -1,11 +1,12 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| import gzip | ||||
| import numpy as np | ||||
| import os | ||||
| import pickle | ||||
| from urllib import request | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| def mnist(save_dir="/tmp"): | ||||
|     """ | ||||
|   | ||||
| @@ -1,9 +1,10 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| import argparse | ||||
| import torch | ||||
| import time | ||||
|  | ||||
| import torch | ||||
|  | ||||
| import mnist | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -1,9 +1,9 @@ | ||||
| from typing import Any | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| from mlx.utils import tree_flatten | ||||
|  | ||||
|  | ||||
| __all__ = ["KWT", "kwt1", "kwt2", "kwt3"] | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -1,13 +1,13 @@ | ||||
| import argparse | ||||
| import time | ||||
| import kwt | ||||
| import mlx.nn as nn | ||||
| import mlx.data as dx | ||||
| import mlx.core as mx | ||||
| import mlx.optimizers as optim | ||||
| from mlx.data.features import mfsc | ||||
| from mlx.data.datasets import load_speechcommands | ||||
|  | ||||
| import kwt | ||||
| import mlx.core as mx | ||||
| import mlx.data as dx | ||||
| import mlx.nn as nn | ||||
| import mlx.optimizers as optim | ||||
| from mlx.data.datasets import load_speechcommands | ||||
| from mlx.data.features import mfsc | ||||
|  | ||||
| parser = argparse.ArgumentParser(add_help=True) | ||||
| parser.add_argument( | ||||
|   | ||||
| @@ -6,12 +6,12 @@ from typing import Tuple | ||||
| import mlx.core as mx | ||||
|  | ||||
| from .model_io import ( | ||||
|     load_unet, | ||||
|     load_text_encoder, | ||||
|     _DEFAULT_MODEL, | ||||
|     load_autoencoder, | ||||
|     load_diffusion_config, | ||||
|     load_text_encoder, | ||||
|     load_tokenizer, | ||||
|     _DEFAULT_MODEL, | ||||
|     load_unet, | ||||
| ) | ||||
| from .sampler import SimpleEulerSampler | ||||
|  | ||||
|   | ||||
| @@ -3,20 +3,18 @@ | ||||
| import json | ||||
| from functools import partial | ||||
|  | ||||
| import mlx.core as mx | ||||
| import numpy as np | ||||
| from huggingface_hub import hf_hub_download | ||||
| from mlx.utils import tree_unflatten | ||||
| from safetensors import safe_open as safetensor_open | ||||
|  | ||||
| import mlx.core as mx | ||||
| from mlx.utils import tree_unflatten | ||||
|  | ||||
| from .clip import CLIPTextModel | ||||
| from .config import UNetConfig, CLIPTextModelConfig, AutoencoderConfig, DiffusionConfig | ||||
| from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig | ||||
| from .tokenizer import Tokenizer | ||||
| from .unet import UNetModel | ||||
| from .vae import Autoencoder | ||||
|  | ||||
|  | ||||
| _DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base" | ||||
| _MODELS = { | ||||
|     # See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license | ||||
|   | ||||
| @@ -1,9 +1,9 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| from .config import DiffusionConfig | ||||
|  | ||||
| import mlx.core as mx | ||||
|  | ||||
| from .config import DiffusionConfig | ||||
|  | ||||
|  | ||||
| def _linspace(a, b, num): | ||||
|     x = mx.arange(0, num) / (num - 1) | ||||
|   | ||||
| @@ -1,14 +1,13 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| import argparse | ||||
|  | ||||
| import mlx.core as mx | ||||
| from PIL import Image | ||||
| from tqdm import tqdm | ||||
|  | ||||
| import mlx.core as mx | ||||
|  | ||||
| from stable_diffusion import StableDiffusion | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Generate images from a textual prompt using stable diffusion" | ||||
|   | ||||
| @@ -1,6 +1,5 @@ | ||||
| from transformers import T5ForConditionalGeneration | ||||
| import numpy as np | ||||
|  | ||||
| from transformers import T5ForConditionalGeneration | ||||
|  | ||||
| SHARED_REPLACEMENT_PATTERNS = [ | ||||
|     (".block.", ".layers."), | ||||
| @@ -48,8 +47,7 @@ def convert(model_name, dtype): | ||||
|     dtype = getattr(np, dtype) | ||||
|     model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto") | ||||
|     weights = { | ||||
|         replace_key(k): v.numpy().astype(dtype) | ||||
|         for k, v in model.state_dict().items() | ||||
|         replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items() | ||||
|     } | ||||
|     file_name = model_name.replace("/", "-") | ||||
|     print(f"Saving weights to {file_name}.npz") | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| from transformers import T5ForConditionalGeneration, T5EncoderModel, AutoTokenizer | ||||
|  | ||||
| import argparse | ||||
|  | ||||
| from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration | ||||
|  | ||||
|  | ||||
| def embed(t5_model: str): | ||||
|     batch = [ | ||||
| @@ -51,4 +51,3 @@ if __name__ == "__main__": | ||||
|         embed(args.model) | ||||
|     else: | ||||
|         generate(args.model) | ||||
|  | ||||
|   | ||||
							
								
								
									
										8
									
								
								t5/t5.py
									
									
									
									
									
								
							
							
						
						
									
										8
									
								
								t5/t5.py
									
									
									
									
									
								
							| @@ -1,11 +1,11 @@ | ||||
| import argparse | ||||
| from typing import Optional, Tuple, List | ||||
| from time import perf_counter_ns | ||||
| from typing import List, Optional, Tuple | ||||
|  | ||||
| import numpy as np | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| from mlx.utils import tree_unflatten, tree_map | ||||
| import numpy as np | ||||
| from mlx.utils import tree_map, tree_unflatten | ||||
| from transformers import T5Config, T5Tokenizer | ||||
|  | ||||
|  | ||||
| @@ -337,7 +337,7 @@ class Tokenizer: | ||||
|         self._tokenizer = T5Tokenizer.from_pretrained( | ||||
|             args.model, | ||||
|             legacy=False, | ||||
|             model_max_length=getattr(config, 'n_positions', 512) | ||||
|             model_max_length=getattr(config, "n_positions", 512), | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|   | ||||
| @@ -2,10 +2,11 @@ | ||||
|  | ||||
| import io | ||||
| import itertools | ||||
| import numpy as np | ||||
| import os | ||||
| from urllib import request | ||||
| import zipfile | ||||
| from urllib import request | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| def load_dataset(dataname): | ||||
|   | ||||
| @@ -1,14 +1,14 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| import functools | ||||
| import jax | ||||
| import jax.numpy as jnp | ||||
| import math | ||||
| import numpy as np | ||||
| import time | ||||
| from collections import namedtuple | ||||
|  | ||||
| import datasets | ||||
| import jax | ||||
| import jax.numpy as jnp | ||||
| import numpy as np | ||||
| from tree_utils import tree_flatten | ||||
|  | ||||
| """ | ||||
|   | ||||
| @@ -3,15 +3,13 @@ | ||||
| import math | ||||
| import time | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| import datasets | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| import mlx.optimizers as optim | ||||
| import numpy as np | ||||
| from mlx.utils import tree_flatten | ||||
|  | ||||
| import datasets | ||||
|  | ||||
|  | ||||
| class TransformerLM(nn.Module): | ||||
|     def __init__(self, vocab_size: int, num_layers: int, dims: int, num_heads: int): | ||||
|   | ||||
| @@ -3,11 +3,10 @@ | ||||
| import math | ||||
| import time | ||||
|  | ||||
| import datasets | ||||
| import numpy as np | ||||
| import tensorflow as tf | ||||
|  | ||||
| import datasets | ||||
|  | ||||
|  | ||||
| def to_samples(context_size, dataset): | ||||
|     tokens = dataset.size | ||||
|   | ||||
| @@ -3,11 +3,10 @@ | ||||
| import math | ||||
| import time | ||||
|  | ||||
| import datasets | ||||
| import numpy as np | ||||
| import torch | ||||
|  | ||||
| import datasets | ||||
|  | ||||
|  | ||||
| def to_samples(context_size, dataset): | ||||
|     tokens = dataset.size | ||||
|   | ||||
| @@ -5,10 +5,7 @@ import time | ||||
|  | ||||
| import mlx.core as mx | ||||
|  | ||||
| from whisper import load_models | ||||
| from whisper import audio | ||||
| from whisper import decoding | ||||
| from whisper import transcribe | ||||
| from whisper import audio, decoding, load_models, transcribe | ||||
|  | ||||
| audio_file = "whisper/assets/ls_test.flac" | ||||
|  | ||||
|   | ||||
| @@ -1,19 +1,18 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| import os | ||||
| import subprocess | ||||
| import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
| import numpy as np | ||||
| import os | ||||
| import subprocess | ||||
| import torch | ||||
|  | ||||
| import whisper | ||||
| import whisper.audio as audio | ||||
| import whisper.decoding as decoding | ||||
| import whisper.load_models as load_models | ||||
| import whisper.torch_whisper as torch_whisper | ||||
| import whisper.decoding as decoding | ||||
|  | ||||
|  | ||||
| TEST_AUDIO = "whisper/assets/ls_test.flac" | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,4 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| from . import load_models | ||||
| from . import audio | ||||
| from . import decoding | ||||
| from . import audio, decoding, load_models | ||||
| from .transcribe import transcribe | ||||
|   | ||||
| @@ -1,12 +1,13 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| import zlib | ||||
| from dataclasses import dataclass, field, replace | ||||
| from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union | ||||
|  | ||||
| import mlx.core as mx | ||||
| from mlx.utils import tree_map | ||||
| import mlx.nn as nn | ||||
| import numpy as np | ||||
| import zlib | ||||
| from mlx.utils import tree_map | ||||
|  | ||||
| from .audio import CHUNK_LENGTH | ||||
| from .tokenizer import Tokenizer, get_tokenizer | ||||
|   | ||||
| @@ -7,13 +7,11 @@ import warnings | ||||
| from typing import List | ||||
|  | ||||
| import mlx.core as mx | ||||
| from mlx.utils import tree_map | ||||
|  | ||||
| import torch | ||||
| from mlx.utils import tree_map | ||||
| from tqdm import tqdm | ||||
|  | ||||
| from . import whisper | ||||
| from . import torch_whisper | ||||
| from . import torch_whisper, whisper | ||||
|  | ||||
| _MODELS = { | ||||
|     "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", | ||||
|   | ||||
| @@ -1,12 +1,12 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| import sys | ||||
| from typing import Optional, Tuple, Union | ||||
|  | ||||
| import mlx.core as mx | ||||
| import numpy as np | ||||
| import sys | ||||
| from typing import Optional, Tuple, Union | ||||
| import tqdm | ||||
|  | ||||
|  | ||||
| from .audio import ( | ||||
|     FRAMES_PER_SECOND, | ||||
|     HOP_LENGTH, | ||||
|   | ||||
| @@ -2,13 +2,12 @@ | ||||
|  | ||||
| import base64 | ||||
| import gzip | ||||
| from dataclasses import dataclass | ||||
| import math | ||||
| from dataclasses import dataclass | ||||
| from typing import Dict, Iterable, Optional | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from .decoding import decode as decode_function | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun