Add llms subdir + update README (#145)

* add llms subdir + update README

* nits

* use same pre-commit as mlx

* update readmes a bit

* format
This commit is contained in:
Awni Hannun 2023-12-20 10:22:25 -08:00 committed by GitHub
parent aed14618ca
commit 27c0a8c002
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
62 changed files with 164 additions and 146 deletions

View File

@ -1,5 +1,11 @@
repos:
- repo: https://github.com/psf/black
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 22.10.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
args:
- --profile=black

View File

@ -5,18 +5,37 @@ framework](https://github.com/ml-explore/mlx).
The [MNIST](mnist) example is a good starting point to learn how to use MLX.
Some more useful examples include:
Some more useful examples are listed below.
### Text Models
- [Transformer language model](transformer_lm) training.
- Large scale text generation with [LLaMA](llama), [Mistral](mistral) or [Phi](phi2).
- Mixture-of-experts (MoE) language model with [Mixtral 8x7B](mixtral)
- Large scale text generation with [LLaMA](llms/llama),
[Mistral](llms/mistral), [Phi-2](llms/phi2), and more in the [LLMs](llms)
directory.
- A mixture-of-experts (MoE) language model with [Mixtral 8x7B](llms/mixtral).
- Parameter efficient fine-tuning with [LoRA](lora).
- Text-to-text multi-task Transformers with [T5](t5).
- Bidirectional language understanding with [BERT](bert).
### Image Models
- Generating images with [Stable Diffusion](stable_diffusion).
### Audio Models
- Speech recognition with [OpenAI's Whisper](whisper).
- Bidirectional language understanding with [BERT](bert)
### Other Models
- Semi-supervised learning on graph-structured data with [GCN](gcn).
Note: You can now directly download a few converted checkpoints from the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face.
### Hugging Face
Note: You can now directly download a few converted checkpoints from the [MLX
Community](https://huggingface.co/mlx-community) organization on Hugging Face.
We encourage you to join the community and [contribute new
models](https://github.com/ml-explore/mlx-examples/issues/155).
## Contributing

View File

@ -1,7 +1,7 @@
from transformers import BertModel
import argparse
import numpy
from transformers import BertModel
def replace_key(key: str) -> str:

View File

@ -1,7 +1,7 @@
from transformers import AutoModel, AutoTokenizer
import argparse
from transformers import AutoModel, AutoTokenizer
def run(bert_model: str):
batch = [

View File

@ -1,13 +1,13 @@
import numpy as np
from typing import Optional
import argparse
from dataclasses import dataclass
from transformers import BertTokenizer
from mlx.utils import tree_unflatten
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
import argparse
import numpy
import numpy as np
from mlx.utils import tree_unflatten
from transformers import BertTokenizer
@dataclass

View File

@ -1,6 +1,7 @@
import math
import mlx.core as mx
from mlx.data.datasets import load_cifar10
import math
def get_cifar10(batch_size, root=None):

View File

@ -1,11 +1,11 @@
import argparse
import time
import resnet
import mlx.nn as nn
import mlx.core as mx
import mlx.optimizers as optim
from dataset import get_cifar10
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import resnet
from dataset import get_cifar10
parser = argparse.ArgumentParser(add_help=True)
parser.add_argument(

View File

@ -6,11 +6,11 @@ There's no BatchNorm is mlx==0.0.4, using LayerNorm instead.
"""
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten
__all__ = [
"ResNet",
"resnet20",

View File

@ -1,9 +1,9 @@
import os
import requests
import tarfile
import mlx.core as mx
import numpy as np
import requests
import scipy.sparse as sparse
"""

View File

@ -3,10 +3,10 @@ from argparse import ArgumentParser
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from datasets import download_cora, load_data, train_val_test_mask
from mlx.nn.losses import cross_entropy
from mlx.utils import tree_flatten
from datasets import download_cora, load_data, train_val_test_mask
from gcn import GCN

View File

@ -20,8 +20,9 @@ weights you will need to request access from Meta:
- [Request Llama v1](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
- [Request Llama v2](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
> [!TIP]
> Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step.
> [!TIP] Alternatively, you can also download a few converted checkpoints from
> the [MLX Community](https://huggingface.co/mlx-community) organization on
> Hugging Face and skip the conversion step.
You can download the TinyLlama models directly from [Hugging
Face](https://huggingface.co/TinyLlama).

View File

@ -4,8 +4,9 @@ import argparse
import collections
import glob
import json
import numpy as np
from pathlib import Path
import numpy as np
import torch

View File

@ -1,16 +1,16 @@
# Copyright © 2023 Apple Inc.
import argparse
from dataclasses import dataclass
import json
from pathlib import Path
from typing import Optional, Tuple, List
from sentencepiece import SentencePieceProcessor
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from sentencepiece import SentencePieceProcessor
@dataclass

View File

@ -29,7 +29,10 @@ python convert.py
The conversion script will save the converted weights in the same location.
> [!TIP]
> Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step.
> Alternatively, you can also download a few converted checkpoints from the
> [MLX Community](https://huggingface.co/mlx-community) organization on Hugging
> Face and skip the conversion step.
### Run

View File

@ -2,10 +2,10 @@
import argparse
import json
import numpy as np
from pathlib import Path
import torch
import numpy as np
import torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")

View File

@ -1,15 +1,15 @@
# Copyright © 2023 Apple Inc.
import argparse
from dataclasses import dataclass
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, List
from sentencepiece import SentencePieceProcessor
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map, tree_unflatten
from sentencepiece import SentencePieceProcessor
@dataclass

View File

@ -2,11 +2,10 @@
import unittest
import mistral
import mlx.core as mx
from mlx.utils import tree_map
import mistral
class TestMistral(unittest.TestCase):
def test_model(self):

View File

@ -3,8 +3,9 @@
import argparse
import glob
import json
import numpy as np
from pathlib import Path
import numpy as np
import torch
@ -45,7 +46,7 @@ if __name__ == "__main__":
args = json.load(fid)
args["model_type"] = "mixtral"
with open(model_path / "config.json", "w") as f:
json.dump(args, f, indent=4)
json.dump(args, f, indent=4)
torch_files = glob.glob(str(model_path / "consolidated.*.pt"))
torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2]))

View File

@ -1,17 +1,17 @@
# Copyright © 2023 Apple Inc.
import argparse
from dataclasses import dataclass
import glob
import json
import numpy as np
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, List
from sentencepiece import SentencePieceProcessor
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_map, tree_unflatten
from sentencepiece import SentencePieceProcessor
@dataclass

View File

@ -17,8 +17,10 @@ python convert.py
This will make the `weights.npz` file which MLX can read.
> [!TIP]
> Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step.
> [!TIP] Alternatively, you can also download a few converted checkpoints from
> the [MLX Community](https://huggingface.co/mlx-community) organization on
> Hugging Face and skip the conversion step.
## Generate

View File

@ -1,5 +1,5 @@
from transformers import AutoModelForCausalLM
import numpy as np
from transformers import AutoModelForCausalLM
def replace_key(key: str) -> str:

View File

@ -1,13 +1,13 @@
import argparse
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from dataclasses import dataclass
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
import mlx.core as mx
import mlx.nn as nn
import math
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
@dataclass
@ -167,7 +167,7 @@ def load_model(model_path: str):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Phi-2 inference script")
parser.add_argument(
"--model_path",
"--model-path",
type=str,
default="phi-2",
help="The path to the model weights",
@ -178,7 +178,7 @@ if __name__ == "__main__":
default="Write a detailed analogy between mathematics and a lighthouse.",
)
parser.add_argument(
"--max_tokens",
"--max-tokens",
"-m",
type=int,
default=100,

View File

@ -1,8 +1,9 @@
import argparse
from transformers import AutoModelForCausalLM
import json
import numpy as np
import torch
import json
from transformers import AutoModelForCausalLM
def replace_key(key: str) -> str:

View File

@ -1,6 +1,7 @@
import argparse
from dataclasses import dataclass
import json
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten

View File

@ -2,12 +2,12 @@
import argparse
import json
import numpy as np
from pathlib import Path
import shutil
import os
import torch
import shutil
from pathlib import Path
import numpy as np
import torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(

View File

@ -32,8 +32,8 @@ class WikiSQL:
def _maybe_download(self, data_dir):
if not os.path.exists(data_dir):
import io
from urllib import request
import tarfile
from urllib import request
url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2"
r = request.urlopen(url)

View File

@ -3,19 +3,17 @@
import argparse
import json
import math
import numpy as np
from pathlib import Path
from sentencepiece import SentencePieceProcessor
import time
from typing import Optional, Tuple, List
from pathlib import Path
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_map, tree_flatten, tree_unflatten
from models import ModelArgs, Model, LoRALinear
import numpy as np
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from models import LoRALinear, Model, ModelArgs
from sentencepiece import SentencePieceProcessor
def build_parser():

View File

@ -1,8 +1,8 @@
# Copyright © 2023 Apple Inc.
from dataclasses import dataclass
import math
from typing import Optional, Tuple, List
from dataclasses import dataclass
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn

View File

@ -1,9 +1,10 @@
# Copyright © 2023 Apple Inc.
import functools
import time
import jax
import jax.numpy as jnp
import functools
import time
import mnist

View File

@ -3,11 +3,10 @@
import argparse
import time
import numpy as np
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import mnist

View File

@ -1,11 +1,12 @@
# Copyright © 2023 Apple Inc.
import gzip
import numpy as np
import os
import pickle
from urllib import request
import numpy as np
def mnist(save_dir="/tmp"):
"""

View File

@ -1,9 +1,10 @@
# Copyright © 2023 Apple Inc.
import argparse
import torch
import time
import torch
import mnist

View File

@ -1,9 +1,9 @@
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten
__all__ = ["KWT", "kwt1", "kwt2", "kwt3"]

View File

@ -1,13 +1,13 @@
import argparse
import time
import kwt
import mlx.nn as nn
import mlx.data as dx
import mlx.core as mx
import mlx.optimizers as optim
from mlx.data.features import mfsc
from mlx.data.datasets import load_speechcommands
import kwt
import mlx.core as mx
import mlx.data as dx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.data.datasets import load_speechcommands
from mlx.data.features import mfsc
parser = argparse.ArgumentParser(add_help=True)
parser.add_argument(

View File

@ -6,12 +6,12 @@ from typing import Tuple
import mlx.core as mx
from .model_io import (
load_unet,
load_text_encoder,
_DEFAULT_MODEL,
load_autoencoder,
load_diffusion_config,
load_text_encoder,
load_tokenizer,
_DEFAULT_MODEL,
load_unet,
)
from .sampler import SimpleEulerSampler

View File

@ -3,20 +3,18 @@
import json
from functools import partial
import mlx.core as mx
import numpy as np
from huggingface_hub import hf_hub_download
from mlx.utils import tree_unflatten
from safetensors import safe_open as safetensor_open
import mlx.core as mx
from mlx.utils import tree_unflatten
from .clip import CLIPTextModel
from .config import UNetConfig, CLIPTextModelConfig, AutoencoderConfig, DiffusionConfig
from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
from .tokenizer import Tokenizer
from .unet import UNetModel
from .vae import Autoencoder
_DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"
_MODELS = {
# See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license

View File

@ -1,9 +1,9 @@
# Copyright © 2023 Apple Inc.
from .config import DiffusionConfig
import mlx.core as mx
from .config import DiffusionConfig
def _linspace(a, b, num):
x = mx.arange(0, num) / (num - 1)

View File

@ -1,14 +1,13 @@
# Copyright © 2023 Apple Inc.
import argparse
import mlx.core as mx
from PIL import Image
from tqdm import tqdm
import mlx.core as mx
from stable_diffusion import StableDiffusion
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using stable diffusion"

View File

@ -1,6 +1,5 @@
from transformers import T5ForConditionalGeneration
import numpy as np
from transformers import T5ForConditionalGeneration
SHARED_REPLACEMENT_PATTERNS = [
(".block.", ".layers."),
@ -48,8 +47,7 @@ def convert(model_name, dtype):
dtype = getattr(np, dtype)
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
weights = {
replace_key(k): v.numpy().astype(dtype)
for k, v in model.state_dict().items()
replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items()
}
file_name = model_name.replace("/", "-")
print(f"Saving weights to {file_name}.npz")

View File

@ -1,7 +1,7 @@
from transformers import T5ForConditionalGeneration, T5EncoderModel, AutoTokenizer
import argparse
from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration
def embed(t5_model: str):
batch = [
@ -51,4 +51,3 @@ if __name__ == "__main__":
embed(args.model)
else:
generate(args.model)

View File

@ -1,11 +1,11 @@
import argparse
from typing import Optional, Tuple, List
from time import perf_counter_ns
from typing import List, Optional, Tuple
import numpy as np
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten, tree_map
import numpy as np
from mlx.utils import tree_map, tree_unflatten
from transformers import T5Config, T5Tokenizer
@ -166,7 +166,7 @@ class DenseActivation(nn.Module):
self.act = nn.relu
elif activation == "gelu":
self.act = nn.gelu
elif activation == "silu":
elif activation == "silu":
self.act = nn.silu
else:
raise ValueError(f"Unknown activation: {activation}")
@ -337,7 +337,7 @@ class Tokenizer:
self._tokenizer = T5Tokenizer.from_pretrained(
args.model,
legacy=False,
model_max_length=getattr(config, 'n_positions', 512)
model_max_length=getattr(config, "n_positions", 512),
)
@property

View File

@ -2,10 +2,11 @@
import io
import itertools
import numpy as np
import os
from urllib import request
import zipfile
from urllib import request
import numpy as np
def load_dataset(dataname):

View File

@ -1,14 +1,14 @@
# Copyright © 2023 Apple Inc.
import functools
import jax
import jax.numpy as jnp
import math
import numpy as np
import time
from collections import namedtuple
import datasets
import jax
import jax.numpy as jnp
import numpy as np
from tree_utils import tree_flatten
"""

View File

@ -3,15 +3,13 @@
import math
import time
import numpy as np
import datasets
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx.utils import tree_flatten
import datasets
class TransformerLM(nn.Module):
def __init__(self, vocab_size: int, num_layers: int, dims: int, num_heads: int):

View File

@ -3,11 +3,10 @@
import math
import time
import datasets
import numpy as np
import tensorflow as tf
import datasets
def to_samples(context_size, dataset):
tokens = dataset.size

View File

@ -3,11 +3,10 @@
import math
import time
import datasets
import numpy as np
import torch
import datasets
def to_samples(context_size, dataset):
tokens = dataset.size

View File

@ -5,10 +5,7 @@ import time
import mlx.core as mx
from whisper import load_models
from whisper import audio
from whisper import decoding
from whisper import transcribe
from whisper import audio, decoding, load_models, transcribe
audio_file = "whisper/assets/ls_test.flac"

View File

@ -1,19 +1,18 @@
# Copyright © 2023 Apple Inc.
import os
import subprocess
import unittest
import mlx.core as mx
import numpy as np
import os
import subprocess
import torch
import whisper
import whisper.audio as audio
import whisper.decoding as decoding
import whisper.load_models as load_models
import whisper.torch_whisper as torch_whisper
import whisper.decoding as decoding
TEST_AUDIO = "whisper/assets/ls_test.flac"

View File

@ -1,6 +1,4 @@
# Copyright © 2023 Apple Inc.
from . import load_models
from . import audio
from . import decoding
from . import audio, decoding, load_models
from .transcribe import transcribe

View File

@ -1,12 +1,13 @@
# Copyright © 2023 Apple Inc.
import zlib
from dataclasses import dataclass, field, replace
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import mlx.core as mx
from mlx.utils import tree_map
import mlx.nn as nn
import numpy as np
import zlib
from mlx.utils import tree_map
from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer

View File

@ -7,13 +7,11 @@ import warnings
from typing import List
import mlx.core as mx
from mlx.utils import tree_map
import torch
from mlx.utils import tree_map
from tqdm import tqdm
from . import whisper
from . import torch_whisper
from . import torch_whisper, whisper
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",

View File

@ -1,12 +1,12 @@
# Copyright © 2023 Apple Inc.
import sys
from typing import Optional, Tuple, Union
import mlx.core as mx
import numpy as np
import sys
from typing import Optional, Tuple, Union
import tqdm
from .audio import (
FRAMES_PER_SECOND,
HOP_LENGTH,

View File

@ -2,13 +2,12 @@
import base64
import gzip
from dataclasses import dataclass
import math
from dataclasses import dataclass
from typing import Dict, Iterable, Optional
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .decoding import decode as decode_function