diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 87f41bce..5240881b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,11 @@ repos: -- repo: https://github.com/psf/black +- repo: https://github.com/psf/black-pre-commit-mirror rev: 22.10.0 hooks: - id: black +- repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + args: + - --profile=black diff --git a/README.md b/README.md index da2aa90c..0c56324b 100644 --- a/README.md +++ b/README.md @@ -5,18 +5,37 @@ framework](https://github.com/ml-explore/mlx). The [MNIST](mnist) example is a good starting point to learn how to use MLX. -Some more useful examples include: +Some more useful examples are listed below. + +### Text Models - [Transformer language model](transformer_lm) training. -- Large scale text generation with [LLaMA](llama), [Mistral](mistral) or [Phi](phi2). -- Mixture-of-experts (MoE) language model with [Mixtral 8x7B](mixtral) +- Large scale text generation with [LLaMA](llms/llama), + [Mistral](llms/mistral), [Phi-2](llms/phi2), and more in the [LLMs](llms) + directory. +- A mixture-of-experts (MoE) language model with [Mixtral 8x7B](llms/mixtral). - Parameter efficient fine-tuning with [LoRA](lora). +- Text-to-text multi-task Transformers with [T5](t5). +- Bidirectional language understanding with [BERT](bert). + +### Image Models + - Generating images with [Stable Diffusion](stable_diffusion). + +### Audio Models + - Speech recognition with [OpenAI's Whisper](whisper). -- Bidirectional language understanding with [BERT](bert) + +### Other Models + - Semi-supervised learning on graph-structured data with [GCN](gcn). -Note: You can now directly download a few converted checkpoints from the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face. +### Hugging Face + +Note: You can now directly download a few converted checkpoints from the [MLX +Community](https://huggingface.co/mlx-community) organization on Hugging Face. +We encourage you to join the community and [contribute new +models](https://github.com/ml-explore/mlx-examples/issues/155). ## Contributing diff --git a/bert/convert.py b/bert/convert.py index 82ee8b82..4cad3cc6 100644 --- a/bert/convert.py +++ b/bert/convert.py @@ -1,7 +1,7 @@ -from transformers import BertModel - import argparse + import numpy +from transformers import BertModel def replace_key(key: str) -> str: diff --git a/bert/hf_model.py b/bert/hf_model.py index e63b904b..7b8ad722 100644 --- a/bert/hf_model.py +++ b/bert/hf_model.py @@ -1,7 +1,7 @@ -from transformers import AutoModel, AutoTokenizer - import argparse +from transformers import AutoModel, AutoTokenizer + def run(bert_model: str): batch = [ diff --git a/bert/model.py b/bert/model.py index ff73dea2..f34b3272 100644 --- a/bert/model.py +++ b/bert/model.py @@ -1,13 +1,13 @@ -import numpy as np -from typing import Optional +import argparse from dataclasses import dataclass -from transformers import BertTokenizer -from mlx.utils import tree_unflatten +from typing import Optional import mlx.core as mx import mlx.nn as nn -import argparse import numpy +import numpy as np +from mlx.utils import tree_unflatten +from transformers import BertTokenizer @dataclass diff --git a/cifar/dataset.py b/cifar/dataset.py index 89b10136..32918a4b 100644 --- a/cifar/dataset.py +++ b/cifar/dataset.py @@ -1,6 +1,7 @@ +import math + import mlx.core as mx from mlx.data.datasets import load_cifar10 -import math def get_cifar10(batch_size, root=None): diff --git a/cifar/main.py b/cifar/main.py index 829417b1..42ed57d3 100644 --- a/cifar/main.py +++ b/cifar/main.py @@ -1,11 +1,11 @@ import argparse import time -import resnet -import mlx.nn as nn -import mlx.core as mx -import mlx.optimizers as optim -from dataset import get_cifar10 +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +import resnet +from dataset import get_cifar10 parser = argparse.ArgumentParser(add_help=True) parser.add_argument( diff --git a/cifar/resnet.py b/cifar/resnet.py index 758ee3de..e707574f 100644 --- a/cifar/resnet.py +++ b/cifar/resnet.py @@ -6,11 +6,11 @@ There's no BatchNorm is mlx==0.0.4, using LayerNorm instead. """ from typing import Any + import mlx.core as mx import mlx.nn as nn from mlx.utils import tree_flatten - __all__ = [ "ResNet", "resnet20", diff --git a/gcn/datasets.py b/gcn/datasets.py index d5ab59ad..736b7403 100644 --- a/gcn/datasets.py +++ b/gcn/datasets.py @@ -1,9 +1,9 @@ import os -import requests import tarfile import mlx.core as mx import numpy as np +import requests import scipy.sparse as sparse """ diff --git a/gcn/main.py b/gcn/main.py index 66c29550..c3c8795a 100644 --- a/gcn/main.py +++ b/gcn/main.py @@ -3,10 +3,10 @@ from argparse import ArgumentParser import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim +from datasets import download_cora, load_data, train_val_test_mask from mlx.nn.losses import cross_entropy from mlx.utils import tree_flatten -from datasets import download_cora, load_data, train_val_test_mask from gcn import GCN diff --git a/llama/README.md b/llms/llama/README.md similarity index 89% rename from llama/README.md rename to llms/llama/README.md index 0bd6d5af..ffa8105d 100644 --- a/llama/README.md +++ b/llms/llama/README.md @@ -20,8 +20,9 @@ weights you will need to request access from Meta: - [Request Llama v1](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) - [Request Llama v2](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) -> [!TIP] -> Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step. +> [!TIP] Alternatively, you can also download a few converted checkpoints from +> the [MLX Community](https://huggingface.co/mlx-community) organization on +> Hugging Face and skip the conversion step. You can download the TinyLlama models directly from [Hugging Face](https://huggingface.co/TinyLlama). diff --git a/llama/convert.py b/llms/llama/convert.py similarity index 99% rename from llama/convert.py rename to llms/llama/convert.py index 957f3d22..ecd75c6f 100644 --- a/llama/convert.py +++ b/llms/llama/convert.py @@ -4,8 +4,9 @@ import argparse import collections import glob import json -import numpy as np from pathlib import Path + +import numpy as np import torch diff --git a/llama/llama.py b/llms/llama/llama.py similarity index 99% rename from llama/llama.py rename to llms/llama/llama.py index 2c1f4d16..3d13350b 100644 --- a/llama/llama.py +++ b/llms/llama/llama.py @@ -1,16 +1,16 @@ # Copyright © 2023 Apple Inc. import argparse -from dataclasses import dataclass import json -from pathlib import Path -from typing import Optional, Tuple, List -from sentencepiece import SentencePieceProcessor import time +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional, Tuple import mlx.core as mx import mlx.nn as nn from mlx.utils import tree_unflatten +from sentencepiece import SentencePieceProcessor @dataclass diff --git a/llama/requirements.txt b/llms/llama/requirements.txt similarity index 100% rename from llama/requirements.txt rename to llms/llama/requirements.txt diff --git a/llama/sample_prompt.txt b/llms/llama/sample_prompt.txt similarity index 100% rename from llama/sample_prompt.txt rename to llms/llama/sample_prompt.txt diff --git a/mistral/.gitignore b/llms/mistral/.gitignore similarity index 100% rename from mistral/.gitignore rename to llms/mistral/.gitignore diff --git a/mistral/README.md b/llms/mistral/README.md similarity index 88% rename from mistral/README.md rename to llms/mistral/README.md index 5b9bcbae..5da5ace0 100644 --- a/mistral/README.md +++ b/llms/mistral/README.md @@ -29,7 +29,10 @@ python convert.py The conversion script will save the converted weights in the same location. > [!TIP] -> Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step. +> Alternatively, you can also download a few converted checkpoints from the +> [MLX Community](https://huggingface.co/mlx-community) organization on Hugging +> Face and skip the conversion step. + ### Run diff --git a/mistral/convert.py b/llms/mistral/convert.py similarity index 100% rename from mistral/convert.py rename to llms/mistral/convert.py index f7bcfd86..2d2c7f2e 100644 --- a/mistral/convert.py +++ b/llms/mistral/convert.py @@ -2,10 +2,10 @@ import argparse import json -import numpy as np from pathlib import Path -import torch +import numpy as np +import torch if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.") diff --git a/mistral/mistral.py b/llms/mistral/mistral.py similarity index 99% rename from mistral/mistral.py rename to llms/mistral/mistral.py index 11dcface..f35f4b90 100644 --- a/mistral/mistral.py +++ b/llms/mistral/mistral.py @@ -1,15 +1,15 @@ # Copyright © 2023 Apple Inc. import argparse -from dataclasses import dataclass import json +from dataclasses import dataclass from pathlib import Path -from typing import Optional, Tuple, List -from sentencepiece import SentencePieceProcessor +from typing import List, Optional, Tuple import mlx.core as mx import mlx.nn as nn from mlx.utils import tree_map, tree_unflatten +from sentencepiece import SentencePieceProcessor @dataclass diff --git a/mistral/requirements.txt b/llms/mistral/requirements.txt similarity index 100% rename from mistral/requirements.txt rename to llms/mistral/requirements.txt diff --git a/mistral/test.py b/llms/mistral/test.py similarity index 99% rename from mistral/test.py rename to llms/mistral/test.py index bc52d49f..25385626 100644 --- a/mistral/test.py +++ b/llms/mistral/test.py @@ -2,11 +2,10 @@ import unittest +import mistral import mlx.core as mx from mlx.utils import tree_map -import mistral - class TestMistral(unittest.TestCase): def test_model(self): diff --git a/mixtral/README.md b/llms/mixtral/README.md similarity index 100% rename from mixtral/README.md rename to llms/mixtral/README.md diff --git a/mixtral/convert.py b/llms/mixtral/convert.py similarity index 97% rename from mixtral/convert.py rename to llms/mixtral/convert.py index e89433a5..6ea40040 100644 --- a/mixtral/convert.py +++ b/llms/mixtral/convert.py @@ -3,8 +3,9 @@ import argparse import glob import json -import numpy as np from pathlib import Path + +import numpy as np import torch @@ -45,7 +46,7 @@ if __name__ == "__main__": args = json.load(fid) args["model_type"] = "mixtral" with open(model_path / "config.json", "w") as f: - json.dump(args, f, indent=4) + json.dump(args, f, indent=4) torch_files = glob.glob(str(model_path / "consolidated.*.pt")) torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2])) diff --git a/mixtral/mixtral.py b/llms/mixtral/mixtral.py similarity index 99% rename from mixtral/mixtral.py rename to llms/mixtral/mixtral.py index 2e82d305..57b7c5d2 100644 --- a/mixtral/mixtral.py +++ b/llms/mixtral/mixtral.py @@ -1,17 +1,17 @@ # Copyright © 2023 Apple Inc. import argparse -from dataclasses import dataclass import glob import json -import numpy as np +from dataclasses import dataclass from pathlib import Path -from typing import Optional, Tuple, List -from sentencepiece import SentencePieceProcessor +from typing import List, Optional, Tuple import mlx.core as mx import mlx.nn as nn +import numpy as np from mlx.utils import tree_map, tree_unflatten +from sentencepiece import SentencePieceProcessor @dataclass diff --git a/mixtral/params.json b/llms/mixtral/params.json similarity index 100% rename from mixtral/params.json rename to llms/mixtral/params.json diff --git a/mixtral/requirements.txt b/llms/mixtral/requirements.txt similarity index 100% rename from mixtral/requirements.txt rename to llms/mixtral/requirements.txt diff --git a/phi2/.gitignore b/llms/phi2/.gitignore similarity index 100% rename from phi2/.gitignore rename to llms/phi2/.gitignore diff --git a/phi2/README.md b/llms/phi2/README.md similarity index 87% rename from phi2/README.md rename to llms/phi2/README.md index 4729257a..76be6260 100644 --- a/phi2/README.md +++ b/llms/phi2/README.md @@ -17,8 +17,10 @@ python convert.py This will make the `weights.npz` file which MLX can read. -> [!TIP] -> Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step. +> [!TIP] Alternatively, you can also download a few converted checkpoints from +> the [MLX Community](https://huggingface.co/mlx-community) organization on +> Hugging Face and skip the conversion step. + ## Generate diff --git a/phi2/convert.py b/llms/phi2/convert.py similarity index 100% rename from phi2/convert.py rename to llms/phi2/convert.py index 5aa07dce..819e0363 100644 --- a/phi2/convert.py +++ b/llms/phi2/convert.py @@ -1,5 +1,5 @@ -from transformers import AutoModelForCausalLM import numpy as np +from transformers import AutoModelForCausalLM def replace_key(key: str) -> str: diff --git a/phi2/phi2.py b/llms/phi2/phi2.py similarity index 99% rename from phi2/phi2.py rename to llms/phi2/phi2.py index 1e57d157..6e9dd7a2 100644 --- a/phi2/phi2.py +++ b/llms/phi2/phi2.py @@ -1,13 +1,13 @@ import argparse +import math +from dataclasses import dataclass from pathlib import Path from typing import Optional -from dataclasses import dataclass -from mlx.utils import tree_unflatten -from transformers import AutoTokenizer import mlx.core as mx import mlx.nn as nn -import math +from mlx.utils import tree_unflatten +from transformers import AutoTokenizer @dataclass @@ -167,7 +167,7 @@ def load_model(model_path: str): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Phi-2 inference script") parser.add_argument( - "--model_path", + "--model-path", type=str, default="phi-2", help="The path to the model weights", @@ -178,7 +178,7 @@ if __name__ == "__main__": default="Write a detailed analogy between mathematics and a lighthouse.", ) parser.add_argument( - "--max_tokens", + "--max-tokens", "-m", type=int, default=100, diff --git a/phi2/requirements.txt b/llms/phi2/requirements.txt similarity index 100% rename from phi2/requirements.txt rename to llms/phi2/requirements.txt diff --git a/llms/qwen/convert.py b/llms/qwen/convert.py index 50a8d7a8..c5fec060 100644 --- a/llms/qwen/convert.py +++ b/llms/qwen/convert.py @@ -1,8 +1,9 @@ import argparse -from transformers import AutoModelForCausalLM +import json + import numpy as np import torch -import json +from transformers import AutoModelForCausalLM def replace_key(key: str) -> str: diff --git a/llms/qwen/qwen.py b/llms/qwen/qwen.py index c490d650..3b153eb9 100644 --- a/llms/qwen/qwen.py +++ b/llms/qwen/qwen.py @@ -1,6 +1,7 @@ import argparse -from dataclasses import dataclass import json +from dataclasses import dataclass + import mlx.core as mx import mlx.nn as nn from mlx.utils import tree_unflatten diff --git a/lora/convert.py b/lora/convert.py index 3fdb5d42..6f71d772 100644 --- a/lora/convert.py +++ b/lora/convert.py @@ -2,12 +2,12 @@ import argparse import json -import numpy as np -from pathlib import Path -import shutil import os -import torch +import shutil +from pathlib import Path +import numpy as np +import torch if __name__ == "__main__": parser = argparse.ArgumentParser( diff --git a/lora/data/wikisql.py b/lora/data/wikisql.py index 3ac11339..c1303824 100644 --- a/lora/data/wikisql.py +++ b/lora/data/wikisql.py @@ -32,8 +32,8 @@ class WikiSQL: def _maybe_download(self, data_dir): if not os.path.exists(data_dir): import io - from urllib import request import tarfile + from urllib import request url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2" r = request.urlopen(url) diff --git a/lora/lora.py b/lora/lora.py index e1412da3..875b51ef 100644 --- a/lora/lora.py +++ b/lora/lora.py @@ -3,19 +3,17 @@ import argparse import json import math -import numpy as np -from pathlib import Path -from sentencepiece import SentencePieceProcessor import time -from typing import Optional, Tuple, List +from pathlib import Path +from typing import List, Optional, Tuple import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim -from mlx.utils import tree_map, tree_flatten, tree_unflatten - - -from models import ModelArgs, Model, LoRALinear +import numpy as np +from mlx.utils import tree_flatten, tree_map, tree_unflatten +from models import LoRALinear, Model, ModelArgs +from sentencepiece import SentencePieceProcessor def build_parser(): diff --git a/lora/models.py b/lora/models.py index e0bfa9f9..31eef654 100644 --- a/lora/models.py +++ b/lora/models.py @@ -1,8 +1,8 @@ # Copyright © 2023 Apple Inc. -from dataclasses import dataclass import math -from typing import Optional, Tuple, List +from dataclasses import dataclass +from typing import List, Optional, Tuple import mlx.core as mx import mlx.nn as nn diff --git a/mnist/jax_main.py b/mnist/jax_main.py index 22ccb96c..088a4993 100644 --- a/mnist/jax_main.py +++ b/mnist/jax_main.py @@ -1,9 +1,10 @@ # Copyright © 2023 Apple Inc. +import functools +import time + import jax import jax.numpy as jnp -import functools -import time import mnist diff --git a/mnist/main.py b/mnist/main.py index dc8d2d7e..b921261f 100644 --- a/mnist/main.py +++ b/mnist/main.py @@ -3,11 +3,10 @@ import argparse import time -import numpy as np - import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim +import numpy as np import mnist diff --git a/mnist/mnist.py b/mnist/mnist.py index 966ac912..b2e5c510 100644 --- a/mnist/mnist.py +++ b/mnist/mnist.py @@ -1,11 +1,12 @@ # Copyright © 2023 Apple Inc. import gzip -import numpy as np import os import pickle from urllib import request +import numpy as np + def mnist(save_dir="/tmp"): """ diff --git a/mnist/torch_main.py b/mnist/torch_main.py index 0aa8519b..a6c0380a 100644 --- a/mnist/torch_main.py +++ b/mnist/torch_main.py @@ -1,9 +1,10 @@ # Copyright © 2023 Apple Inc. import argparse -import torch import time +import torch + import mnist diff --git a/speechcommands/kwt.py b/speechcommands/kwt.py index 63d4e074..bf0b23ce 100644 --- a/speechcommands/kwt.py +++ b/speechcommands/kwt.py @@ -1,9 +1,9 @@ from typing import Any + import mlx.core as mx import mlx.nn as nn from mlx.utils import tree_flatten - __all__ = ["KWT", "kwt1", "kwt2", "kwt3"] diff --git a/speechcommands/main.py b/speechcommands/main.py index 492b8159..3f81e40b 100644 --- a/speechcommands/main.py +++ b/speechcommands/main.py @@ -1,13 +1,13 @@ import argparse import time -import kwt -import mlx.nn as nn -import mlx.data as dx -import mlx.core as mx -import mlx.optimizers as optim -from mlx.data.features import mfsc -from mlx.data.datasets import load_speechcommands +import kwt +import mlx.core as mx +import mlx.data as dx +import mlx.nn as nn +import mlx.optimizers as optim +from mlx.data.datasets import load_speechcommands +from mlx.data.features import mfsc parser = argparse.ArgumentParser(add_help=True) parser.add_argument( diff --git a/stable_diffusion/stable_diffusion/__init__.py b/stable_diffusion/stable_diffusion/__init__.py index 778eff39..899336e8 100644 --- a/stable_diffusion/stable_diffusion/__init__.py +++ b/stable_diffusion/stable_diffusion/__init__.py @@ -6,12 +6,12 @@ from typing import Tuple import mlx.core as mx from .model_io import ( - load_unet, - load_text_encoder, + _DEFAULT_MODEL, load_autoencoder, load_diffusion_config, + load_text_encoder, load_tokenizer, - _DEFAULT_MODEL, + load_unet, ) from .sampler import SimpleEulerSampler diff --git a/stable_diffusion/stable_diffusion/model_io.py b/stable_diffusion/stable_diffusion/model_io.py index 57879ef9..819754b7 100644 --- a/stable_diffusion/stable_diffusion/model_io.py +++ b/stable_diffusion/stable_diffusion/model_io.py @@ -3,20 +3,18 @@ import json from functools import partial +import mlx.core as mx import numpy as np from huggingface_hub import hf_hub_download +from mlx.utils import tree_unflatten from safetensors import safe_open as safetensor_open -import mlx.core as mx -from mlx.utils import tree_unflatten - from .clip import CLIPTextModel -from .config import UNetConfig, CLIPTextModelConfig, AutoencoderConfig, DiffusionConfig +from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig from .tokenizer import Tokenizer from .unet import UNetModel from .vae import Autoencoder - _DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base" _MODELS = { # See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license diff --git a/stable_diffusion/stable_diffusion/sampler.py b/stable_diffusion/stable_diffusion/sampler.py index a1edf931..0c7b80ee 100644 --- a/stable_diffusion/stable_diffusion/sampler.py +++ b/stable_diffusion/stable_diffusion/sampler.py @@ -1,9 +1,9 @@ # Copyright © 2023 Apple Inc. -from .config import DiffusionConfig - import mlx.core as mx +from .config import DiffusionConfig + def _linspace(a, b, num): x = mx.arange(0, num) / (num - 1) diff --git a/stable_diffusion/txt2image.py b/stable_diffusion/txt2image.py index 9ed3e6e2..a487cd22 100644 --- a/stable_diffusion/txt2image.py +++ b/stable_diffusion/txt2image.py @@ -1,14 +1,13 @@ # Copyright © 2023 Apple Inc. import argparse + +import mlx.core as mx from PIL import Image from tqdm import tqdm -import mlx.core as mx - from stable_diffusion import StableDiffusion - if __name__ == "__main__": parser = argparse.ArgumentParser( description="Generate images from a textual prompt using stable diffusion" diff --git a/t5/convert.py b/t5/convert.py index 089d262d..e2108a0c 100644 --- a/t5/convert.py +++ b/t5/convert.py @@ -1,6 +1,5 @@ -from transformers import T5ForConditionalGeneration import numpy as np - +from transformers import T5ForConditionalGeneration SHARED_REPLACEMENT_PATTERNS = [ (".block.", ".layers."), @@ -48,8 +47,7 @@ def convert(model_name, dtype): dtype = getattr(np, dtype) model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto") weights = { - replace_key(k): v.numpy().astype(dtype) - for k, v in model.state_dict().items() + replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items() } file_name = model_name.replace("/", "-") print(f"Saving weights to {file_name}.npz") diff --git a/t5/hf_t5.py b/t5/hf_t5.py index ddd99610..ce01e355 100644 --- a/t5/hf_t5.py +++ b/t5/hf_t5.py @@ -1,7 +1,7 @@ -from transformers import T5ForConditionalGeneration, T5EncoderModel, AutoTokenizer - import argparse +from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration + def embed(t5_model: str): batch = [ @@ -51,4 +51,3 @@ if __name__ == "__main__": embed(args.model) else: generate(args.model) - diff --git a/t5/t5.py b/t5/t5.py index f80c3cb3..29ec4291 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -1,11 +1,11 @@ import argparse -from typing import Optional, Tuple, List from time import perf_counter_ns +from typing import List, Optional, Tuple -import numpy as np import mlx.core as mx import mlx.nn as nn -from mlx.utils import tree_unflatten, tree_map +import numpy as np +from mlx.utils import tree_map, tree_unflatten from transformers import T5Config, T5Tokenizer @@ -166,7 +166,7 @@ class DenseActivation(nn.Module): self.act = nn.relu elif activation == "gelu": self.act = nn.gelu - elif activation == "silu": + elif activation == "silu": self.act = nn.silu else: raise ValueError(f"Unknown activation: {activation}") @@ -337,7 +337,7 @@ class Tokenizer: self._tokenizer = T5Tokenizer.from_pretrained( args.model, legacy=False, - model_max_length=getattr(config, 'n_positions', 512) + model_max_length=getattr(config, "n_positions", 512), ) @property diff --git a/transformer_lm/datasets.py b/transformer_lm/datasets.py index 78db713e..7b077ef3 100644 --- a/transformer_lm/datasets.py +++ b/transformer_lm/datasets.py @@ -2,10 +2,11 @@ import io import itertools -import numpy as np import os -from urllib import request import zipfile +from urllib import request + +import numpy as np def load_dataset(dataname): diff --git a/transformer_lm/jax_main.py b/transformer_lm/jax_main.py index 947543f2..7b622068 100644 --- a/transformer_lm/jax_main.py +++ b/transformer_lm/jax_main.py @@ -1,14 +1,14 @@ # Copyright © 2023 Apple Inc. import functools -import jax -import jax.numpy as jnp import math -import numpy as np import time from collections import namedtuple import datasets +import jax +import jax.numpy as jnp +import numpy as np from tree_utils import tree_flatten """ diff --git a/transformer_lm/main.py b/transformer_lm/main.py index b0d2966b..dea5a982 100644 --- a/transformer_lm/main.py +++ b/transformer_lm/main.py @@ -3,15 +3,13 @@ import math import time -import numpy as np - +import datasets import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim +import numpy as np from mlx.utils import tree_flatten -import datasets - class TransformerLM(nn.Module): def __init__(self, vocab_size: int, num_layers: int, dims: int, num_heads: int): diff --git a/transformer_lm/tf_main.py b/transformer_lm/tf_main.py index 58165d70..8fa0a95a 100644 --- a/transformer_lm/tf_main.py +++ b/transformer_lm/tf_main.py @@ -3,11 +3,10 @@ import math import time +import datasets import numpy as np import tensorflow as tf -import datasets - def to_samples(context_size, dataset): tokens = dataset.size diff --git a/transformer_lm/torch_main.py b/transformer_lm/torch_main.py index e81f744c..b39924a9 100644 --- a/transformer_lm/torch_main.py +++ b/transformer_lm/torch_main.py @@ -3,11 +3,10 @@ import math import time +import datasets import numpy as np import torch -import datasets - def to_samples(context_size, dataset): tokens = dataset.size diff --git a/whisper/benchmark.py b/whisper/benchmark.py index 228a3b36..10025952 100644 --- a/whisper/benchmark.py +++ b/whisper/benchmark.py @@ -5,10 +5,7 @@ import time import mlx.core as mx -from whisper import load_models -from whisper import audio -from whisper import decoding -from whisper import transcribe +from whisper import audio, decoding, load_models, transcribe audio_file = "whisper/assets/ls_test.flac" diff --git a/whisper/test.py b/whisper/test.py index 3e7630a9..13a1f91a 100644 --- a/whisper/test.py +++ b/whisper/test.py @@ -1,19 +1,18 @@ # Copyright © 2023 Apple Inc. +import os +import subprocess import unittest import mlx.core as mx import numpy as np -import os -import subprocess import torch import whisper import whisper.audio as audio +import whisper.decoding as decoding import whisper.load_models as load_models import whisper.torch_whisper as torch_whisper -import whisper.decoding as decoding - TEST_AUDIO = "whisper/assets/ls_test.flac" diff --git a/whisper/whisper/__init__.py b/whisper/whisper/__init__.py index 302e32ae..e234711c 100644 --- a/whisper/whisper/__init__.py +++ b/whisper/whisper/__init__.py @@ -1,6 +1,4 @@ # Copyright © 2023 Apple Inc. -from . import load_models -from . import audio -from . import decoding +from . import audio, decoding, load_models from .transcribe import transcribe diff --git a/whisper/whisper/decoding.py b/whisper/whisper/decoding.py index d5025444..b7786f92 100644 --- a/whisper/whisper/decoding.py +++ b/whisper/whisper/decoding.py @@ -1,12 +1,13 @@ # Copyright © 2023 Apple Inc. +import zlib from dataclasses import dataclass, field, replace from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union + import mlx.core as mx -from mlx.utils import tree_map import mlx.nn as nn import numpy as np -import zlib +from mlx.utils import tree_map from .audio import CHUNK_LENGTH from .tokenizer import Tokenizer, get_tokenizer diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py index ffdccf44..32e62119 100644 --- a/whisper/whisper/load_models.py +++ b/whisper/whisper/load_models.py @@ -7,13 +7,11 @@ import warnings from typing import List import mlx.core as mx -from mlx.utils import tree_map - import torch +from mlx.utils import tree_map from tqdm import tqdm -from . import whisper -from . import torch_whisper +from . import torch_whisper, whisper _MODELS = { "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", diff --git a/whisper/whisper/transcribe.py b/whisper/whisper/transcribe.py index 06f3c9ea..67232136 100644 --- a/whisper/whisper/transcribe.py +++ b/whisper/whisper/transcribe.py @@ -1,12 +1,12 @@ # Copyright © 2023 Apple Inc. +import sys +from typing import Optional, Tuple, Union + import mlx.core as mx import numpy as np -import sys -from typing import Optional, Tuple, Union import tqdm - from .audio import ( FRAMES_PER_SECOND, HOP_LENGTH, diff --git a/whisper/whisper/whisper.py b/whisper/whisper/whisper.py index 8ee6d7d9..dfeb1e73 100644 --- a/whisper/whisper/whisper.py +++ b/whisper/whisper/whisper.py @@ -2,13 +2,12 @@ import base64 import gzip -from dataclasses import dataclass import math +from dataclasses import dataclass from typing import Dict, Iterable, Optional import mlx.core as mx import mlx.nn as nn - import numpy as np from .decoding import decode as decode_function