mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Add llms subdir + update README (#145)
* add llms subdir + update README * nits * use same pre-commit as mlx * update readmes a bit * format
This commit is contained in:
parent
aed14618ca
commit
27c0a8c002
@ -1,5 +1,11 @@
|
|||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/psf/black
|
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||||
rev: 22.10.0
|
rev: 22.10.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
|
- repo: https://github.com/pycqa/isort
|
||||||
|
rev: 5.12.0
|
||||||
|
hooks:
|
||||||
|
- id: isort
|
||||||
|
args:
|
||||||
|
- --profile=black
|
||||||
|
29
README.md
29
README.md
@ -5,18 +5,37 @@ framework](https://github.com/ml-explore/mlx).
|
|||||||
|
|
||||||
The [MNIST](mnist) example is a good starting point to learn how to use MLX.
|
The [MNIST](mnist) example is a good starting point to learn how to use MLX.
|
||||||
|
|
||||||
Some more useful examples include:
|
Some more useful examples are listed below.
|
||||||
|
|
||||||
|
### Text Models
|
||||||
|
|
||||||
- [Transformer language model](transformer_lm) training.
|
- [Transformer language model](transformer_lm) training.
|
||||||
- Large scale text generation with [LLaMA](llama), [Mistral](mistral) or [Phi](phi2).
|
- Large scale text generation with [LLaMA](llms/llama),
|
||||||
- Mixture-of-experts (MoE) language model with [Mixtral 8x7B](mixtral)
|
[Mistral](llms/mistral), [Phi-2](llms/phi2), and more in the [LLMs](llms)
|
||||||
|
directory.
|
||||||
|
- A mixture-of-experts (MoE) language model with [Mixtral 8x7B](llms/mixtral).
|
||||||
- Parameter efficient fine-tuning with [LoRA](lora).
|
- Parameter efficient fine-tuning with [LoRA](lora).
|
||||||
|
- Text-to-text multi-task Transformers with [T5](t5).
|
||||||
|
- Bidirectional language understanding with [BERT](bert).
|
||||||
|
|
||||||
|
### Image Models
|
||||||
|
|
||||||
- Generating images with [Stable Diffusion](stable_diffusion).
|
- Generating images with [Stable Diffusion](stable_diffusion).
|
||||||
|
|
||||||
|
### Audio Models
|
||||||
|
|
||||||
- Speech recognition with [OpenAI's Whisper](whisper).
|
- Speech recognition with [OpenAI's Whisper](whisper).
|
||||||
- Bidirectional language understanding with [BERT](bert)
|
|
||||||
|
### Other Models
|
||||||
|
|
||||||
- Semi-supervised learning on graph-structured data with [GCN](gcn).
|
- Semi-supervised learning on graph-structured data with [GCN](gcn).
|
||||||
|
|
||||||
Note: You can now directly download a few converted checkpoints from the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face.
|
### Hugging Face
|
||||||
|
|
||||||
|
Note: You can now directly download a few converted checkpoints from the [MLX
|
||||||
|
Community](https://huggingface.co/mlx-community) organization on Hugging Face.
|
||||||
|
We encourage you to join the community and [contribute new
|
||||||
|
models](https://github.com/ml-explore/mlx-examples/issues/155).
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from transformers import BertModel
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
|
from transformers import BertModel
|
||||||
|
|
||||||
|
|
||||||
def replace_key(key: str) -> str:
|
def replace_key(key: str) -> str:
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from transformers import AutoModel, AutoTokenizer
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
from transformers import AutoModel, AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
def run(bert_model: str):
|
def run(bert_model: str):
|
||||||
batch = [
|
batch = [
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
import numpy as np
|
import argparse
|
||||||
from typing import Optional
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from transformers import BertTokenizer
|
from typing import Optional
|
||||||
from mlx.utils import tree_unflatten
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import argparse
|
|
||||||
import numpy
|
import numpy
|
||||||
|
import numpy as np
|
||||||
|
from mlx.utils import tree_unflatten
|
||||||
|
from transformers import BertTokenizer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx.data.datasets import load_cifar10
|
from mlx.data.datasets import load_cifar10
|
||||||
import math
|
|
||||||
|
|
||||||
|
|
||||||
def get_cifar10(batch_size, root=None):
|
def get_cifar10(batch_size, root=None):
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
import resnet
|
|
||||||
import mlx.nn as nn
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.optimizers as optim
|
|
||||||
from dataset import get_cifar10
|
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import mlx.optimizers as optim
|
||||||
|
import resnet
|
||||||
|
from dataset import get_cifar10
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(add_help=True)
|
parser = argparse.ArgumentParser(add_help=True)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -6,11 +6,11 @@ There's no BatchNorm in mlx==0.0.4, using LayerNorm instead.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ResNet",
|
"ResNet",
|
||||||
"resnet20",
|
"resnet20",
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import requests
|
|
||||||
import tarfile
|
import tarfile
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import requests
|
||||||
import scipy.sparse as sparse
|
import scipy.sparse as sparse
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -3,10 +3,10 @@ from argparse import ArgumentParser
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import mlx.optimizers as optim
|
import mlx.optimizers as optim
|
||||||
|
from datasets import download_cora, load_data, train_val_test_mask
|
||||||
from mlx.nn.losses import cross_entropy
|
from mlx.nn.losses import cross_entropy
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
|
|
||||||
from datasets import download_cora, load_data, train_val_test_mask
|
|
||||||
from gcn import GCN
|
from gcn import GCN
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,8 +20,9 @@ weights you will need to request access from Meta:
|
|||||||
- [Request Llama v1](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
|
- [Request Llama v1](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
|
||||||
- [Request Llama v2](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
|
- [Request Llama v2](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
> Alternatively, you can also download a few converted checkpoints from
|
||||||
> Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step.
|
> the [MLX Community](https://huggingface.co/mlx-community) organization on
|
||||||
|
> Hugging Face and skip the conversion step.
|
||||||
|
|
||||||
You can download the TinyLlama models directly from [Hugging
|
You can download the TinyLlama models directly from [Hugging
|
||||||
Face](https://huggingface.co/TinyLlama).
|
Face](https://huggingface.co/TinyLlama).
|
@ -4,8 +4,9 @@ import argparse
|
|||||||
import collections
|
import collections
|
||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
@ -1,16 +1,16 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from dataclasses import dataclass
|
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Tuple, List
|
|
||||||
from sentencepiece import SentencePieceProcessor
|
|
||||||
import time
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_unflatten
|
from mlx.utils import tree_unflatten
|
||||||
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
@ -29,7 +29,10 @@ python convert.py
|
|||||||
The conversion script will save the converted weights in the same location.
|
The conversion script will save the converted weights in the same location.
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step.
|
> Alternatively, you can also download a few converted checkpoints from the
|
||||||
|
> [MLX Community](https://huggingface.co/mlx-community) organization on Hugging
|
||||||
|
> Face and skip the conversion step.
|
||||||
|
|
||||||
|
|
||||||
### Run
|
### Run
|
||||||
|
|
@ -2,10 +2,10 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import torch
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
|
parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
|
@ -1,15 +1,15 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from dataclasses import dataclass
|
|
||||||
import json
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Tuple, List
|
from typing import List, Optional, Tuple
|
||||||
from sentencepiece import SentencePieceProcessor
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_map, tree_unflatten
|
from mlx.utils import tree_map, tree_unflatten
|
||||||
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
@ -2,11 +2,10 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import mistral
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx.utils import tree_map
|
from mlx.utils import tree_map
|
||||||
|
|
||||||
import mistral
|
|
||||||
|
|
||||||
|
|
||||||
class TestMistral(unittest.TestCase):
|
class TestMistral(unittest.TestCase):
|
||||||
def test_model(self):
|
def test_model(self):
|
@ -3,8 +3,9 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
@ -1,17 +1,17 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from dataclasses import dataclass
|
|
||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Tuple, List
|
from typing import List, Optional, Tuple
|
||||||
from sentencepiece import SentencePieceProcessor
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
import numpy as np
|
||||||
from mlx.utils import tree_map, tree_unflatten
|
from mlx.utils import tree_map, tree_unflatten
|
||||||
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
0
phi2/.gitignore → llms/phi2/.gitignore
vendored
0
phi2/.gitignore → llms/phi2/.gitignore
vendored
@ -17,8 +17,10 @@ python convert.py
|
|||||||
|
|
||||||
This will make the `weights.npz` file which MLX can read.
|
This will make the `weights.npz` file which MLX can read.
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
> Alternatively, you can also download a few converted checkpoints from
|
||||||
> Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step.
|
> the [MLX Community](https://huggingface.co/mlx-community) organization on
|
||||||
|
> Hugging Face and skip the conversion step.
|
||||||
|
|
||||||
|
|
||||||
## Generate
|
## Generate
|
||||||
|
|
@ -1,5 +1,5 @@
|
|||||||
from transformers import AutoModelForCausalLM
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
|
||||||
def replace_key(key: str) -> str:
|
def replace_key(key: str) -> str:
|
@ -1,13 +1,13 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from dataclasses import dataclass
|
|
||||||
from mlx.utils import tree_unflatten
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import math
|
from mlx.utils import tree_unflatten
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -167,7 +167,7 @@ def load_model(model_path: str):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Phi-2 inference script")
|
parser = argparse.ArgumentParser(description="Phi-2 inference script")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_path",
|
"--model-path",
|
||||||
type=str,
|
type=str,
|
||||||
default="phi-2",
|
default="phi-2",
|
||||||
help="The path to the model weights",
|
help="The path to the model weights",
|
||||||
@ -178,7 +178,7 @@ if __name__ == "__main__":
|
|||||||
default="Write a detailed analogy between mathematics and a lighthouse.",
|
default="Write a detailed analogy between mathematics and a lighthouse.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_tokens",
|
"--max-tokens",
|
||||||
"-m",
|
"-m",
|
||||||
type=int,
|
type=int,
|
||||||
default=100,
|
default=100,
|
@ -1,8 +1,9 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from transformers import AutoModelForCausalLM
|
import json
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import json
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
|
||||||
def replace_key(key: str) -> str:
|
def replace_key(key: str) -> str:
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from dataclasses import dataclass
|
|
||||||
import json
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_unflatten
|
from mlx.utils import tree_unflatten
|
||||||
|
@ -2,12 +2,12 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
|
||||||
import shutil
|
|
||||||
import os
|
import os
|
||||||
import torch
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
@ -32,8 +32,8 @@ class WikiSQL:
|
|||||||
def _maybe_download(self, data_dir):
|
def _maybe_download(self, data_dir):
|
||||||
if not os.path.exists(data_dir):
|
if not os.path.exists(data_dir):
|
||||||
import io
|
import io
|
||||||
from urllib import request
|
|
||||||
import tarfile
|
import tarfile
|
||||||
|
from urllib import request
|
||||||
|
|
||||||
url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2"
|
url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2"
|
||||||
r = request.urlopen(url)
|
r = request.urlopen(url)
|
||||||
|
14
lora/lora.py
14
lora/lora.py
@ -3,19 +3,17 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
|
||||||
from sentencepiece import SentencePieceProcessor
|
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Tuple, List
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import mlx.optimizers as optim
|
import mlx.optimizers as optim
|
||||||
from mlx.utils import tree_map, tree_flatten, tree_unflatten
|
import numpy as np
|
||||||
|
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||||
|
from models import LoRALinear, Model, ModelArgs
|
||||||
from models import ModelArgs, Model, LoRALinear
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
|
||||||
|
|
||||||
def build_parser():
|
def build_parser():
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional, Tuple, List
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import time
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import functools
|
|
||||||
import time
|
|
||||||
|
|
||||||
import mnist
|
import mnist
|
||||||
|
|
||||||
|
@ -3,11 +3,10 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import mlx.optimizers as optim
|
import mlx.optimizers as optim
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import mnist
|
import mnist
|
||||||
|
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import gzip
|
import gzip
|
||||||
import numpy as np
|
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from urllib import request
|
from urllib import request
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def mnist(save_dir="/tmp"):
|
def mnist(save_dir="/tmp"):
|
||||||
"""
|
"""
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import torch
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
import mnist
|
import mnist
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["KWT", "kwt1", "kwt2", "kwt3"]
|
__all__ = ["KWT", "kwt1", "kwt2", "kwt3"]
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
import kwt
|
|
||||||
import mlx.nn as nn
|
|
||||||
import mlx.data as dx
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.optimizers as optim
|
|
||||||
from mlx.data.features import mfsc
|
|
||||||
from mlx.data.datasets import load_speechcommands
|
|
||||||
|
|
||||||
|
import kwt
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.data as dx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import mlx.optimizers as optim
|
||||||
|
from mlx.data.datasets import load_speechcommands
|
||||||
|
from mlx.data.features import mfsc
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(add_help=True)
|
parser = argparse.ArgumentParser(add_help=True)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -6,12 +6,12 @@ from typing import Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
from .model_io import (
|
from .model_io import (
|
||||||
load_unet,
|
_DEFAULT_MODEL,
|
||||||
load_text_encoder,
|
|
||||||
load_autoencoder,
|
load_autoencoder,
|
||||||
load_diffusion_config,
|
load_diffusion_config,
|
||||||
|
load_text_encoder,
|
||||||
load_tokenizer,
|
load_tokenizer,
|
||||||
_DEFAULT_MODEL,
|
load_unet,
|
||||||
)
|
)
|
||||||
from .sampler import SimpleEulerSampler
|
from .sampler import SimpleEulerSampler
|
||||||
|
|
||||||
|
@ -3,20 +3,18 @@
|
|||||||
import json
|
import json
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
from mlx.utils import tree_unflatten
|
||||||
from safetensors import safe_open as safetensor_open
|
from safetensors import safe_open as safetensor_open
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
from mlx.utils import tree_unflatten
|
|
||||||
|
|
||||||
from .clip import CLIPTextModel
|
from .clip import CLIPTextModel
|
||||||
from .config import UNetConfig, CLIPTextModelConfig, AutoencoderConfig, DiffusionConfig
|
from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
|
||||||
from .tokenizer import Tokenizer
|
from .tokenizer import Tokenizer
|
||||||
from .unet import UNetModel
|
from .unet import UNetModel
|
||||||
from .vae import Autoencoder
|
from .vae import Autoencoder
|
||||||
|
|
||||||
|
|
||||||
_DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"
|
_DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"
|
||||||
_MODELS = {
|
_MODELS = {
|
||||||
# See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license
|
# See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
from .config import DiffusionConfig
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
|
from .config import DiffusionConfig
|
||||||
|
|
||||||
|
|
||||||
def _linspace(a, b, num):
|
def _linspace(a, b, num):
|
||||||
x = mx.arange(0, num) / (num - 1)
|
x = mx.arange(0, num) / (num - 1)
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
|
|
||||||
from stable_diffusion import StableDiffusion
|
from stable_diffusion import StableDiffusion
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Generate images from a textual prompt using stable diffusion"
|
description="Generate images from a textual prompt using stable diffusion"
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from transformers import T5ForConditionalGeneration
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from transformers import T5ForConditionalGeneration
|
||||||
|
|
||||||
SHARED_REPLACEMENT_PATTERNS = [
|
SHARED_REPLACEMENT_PATTERNS = [
|
||||||
(".block.", ".layers."),
|
(".block.", ".layers."),
|
||||||
@ -48,8 +47,7 @@ def convert(model_name, dtype):
|
|||||||
dtype = getattr(np, dtype)
|
dtype = getattr(np, dtype)
|
||||||
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
|
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
|
||||||
weights = {
|
weights = {
|
||||||
replace_key(k): v.numpy().astype(dtype)
|
replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items()
|
||||||
for k, v in model.state_dict().items()
|
|
||||||
}
|
}
|
||||||
file_name = model_name.replace("/", "-")
|
file_name = model_name.replace("/", "-")
|
||||||
print(f"Saving weights to {file_name}.npz")
|
print(f"Saving weights to {file_name}.npz")
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from transformers import T5ForConditionalGeneration, T5EncoderModel, AutoTokenizer
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration
|
||||||
|
|
||||||
|
|
||||||
def embed(t5_model: str):
|
def embed(t5_model: str):
|
||||||
batch = [
|
batch = [
|
||||||
@ -51,4 +51,3 @@ if __name__ == "__main__":
|
|||||||
embed(args.model)
|
embed(args.model)
|
||||||
else:
|
else:
|
||||||
generate(args.model)
|
generate(args.model)
|
||||||
|
|
||||||
|
10
t5/t5.py
10
t5/t5.py
@ -1,11 +1,11 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from typing import Optional, Tuple, List
|
|
||||||
from time import perf_counter_ns
|
from time import perf_counter_ns
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_unflatten, tree_map
|
import numpy as np
|
||||||
|
from mlx.utils import tree_map, tree_unflatten
|
||||||
from transformers import T5Config, T5Tokenizer
|
from transformers import T5Config, T5Tokenizer
|
||||||
|
|
||||||
|
|
||||||
@ -166,7 +166,7 @@ class DenseActivation(nn.Module):
|
|||||||
self.act = nn.relu
|
self.act = nn.relu
|
||||||
elif activation == "gelu":
|
elif activation == "gelu":
|
||||||
self.act = nn.gelu
|
self.act = nn.gelu
|
||||||
elif activation == "silu":
|
elif activation == "silu":
|
||||||
self.act = nn.silu
|
self.act = nn.silu
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown activation: {activation}")
|
raise ValueError(f"Unknown activation: {activation}")
|
||||||
@ -337,7 +337,7 @@ class Tokenizer:
|
|||||||
self._tokenizer = T5Tokenizer.from_pretrained(
|
self._tokenizer = T5Tokenizer.from_pretrained(
|
||||||
args.model,
|
args.model,
|
||||||
legacy=False,
|
legacy=False,
|
||||||
model_max_length=getattr(config, 'n_positions', 512)
|
model_max_length=getattr(config, "n_positions", 512),
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -2,10 +2,11 @@
|
|||||||
|
|
||||||
import io
|
import io
|
||||||
import itertools
|
import itertools
|
||||||
import numpy as np
|
|
||||||
import os
|
import os
|
||||||
from urllib import request
|
|
||||||
import zipfile
|
import zipfile
|
||||||
|
from urllib import request
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(dataname):
|
def load_dataset(dataname):
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
|
||||||
import time
|
import time
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
from tree_utils import tree_flatten
|
from tree_utils import tree_flatten
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -3,15 +3,13 @@
|
|||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import datasets
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import mlx.optimizers as optim
|
import mlx.optimizers as optim
|
||||||
|
import numpy as np
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
|
|
||||||
import datasets
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerLM(nn.Module):
|
class TransformerLM(nn.Module):
|
||||||
def __init__(self, vocab_size: int, num_layers: int, dims: int, num_heads: int):
|
def __init__(self, vocab_size: int, num_layers: int, dims: int, num_heads: int):
|
||||||
|
@ -3,11 +3,10 @@
|
|||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
import datasets
|
|
||||||
|
|
||||||
|
|
||||||
def to_samples(context_size, dataset):
|
def to_samples(context_size, dataset):
|
||||||
tokens = dataset.size
|
tokens = dataset.size
|
||||||
|
@ -3,11 +3,10 @@
|
|||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import datasets
|
|
||||||
|
|
||||||
|
|
||||||
def to_samples(context_size, dataset):
|
def to_samples(context_size, dataset):
|
||||||
tokens = dataset.size
|
tokens = dataset.size
|
||||||
|
@ -5,10 +5,7 @@ import time
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
from whisper import load_models
|
from whisper import audio, decoding, load_models, transcribe
|
||||||
from whisper import audio
|
|
||||||
from whisper import decoding
|
|
||||||
from whisper import transcribe
|
|
||||||
|
|
||||||
audio_file = "whisper/assets/ls_test.flac"
|
audio_file = "whisper/assets/ls_test.flac"
|
||||||
|
|
||||||
|
@ -1,19 +1,18 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import whisper
|
import whisper
|
||||||
import whisper.audio as audio
|
import whisper.audio as audio
|
||||||
|
import whisper.decoding as decoding
|
||||||
import whisper.load_models as load_models
|
import whisper.load_models as load_models
|
||||||
import whisper.torch_whisper as torch_whisper
|
import whisper.torch_whisper as torch_whisper
|
||||||
import whisper.decoding as decoding
|
|
||||||
|
|
||||||
|
|
||||||
TEST_AUDIO = "whisper/assets/ls_test.flac"
|
TEST_AUDIO = "whisper/assets/ls_test.flac"
|
||||||
|
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
from . import load_models
|
from . import audio, decoding, load_models
|
||||||
from . import audio
|
|
||||||
from . import decoding
|
|
||||||
from .transcribe import transcribe
|
from .transcribe import transcribe
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import zlib
|
||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import dataclass, field, replace
|
||||||
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx.utils import tree_map
|
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import zlib
|
from mlx.utils import tree_map
|
||||||
|
|
||||||
from .audio import CHUNK_LENGTH
|
from .audio import CHUNK_LENGTH
|
||||||
from .tokenizer import Tokenizer, get_tokenizer
|
from .tokenizer import Tokenizer, get_tokenizer
|
||||||
|
@ -7,13 +7,11 @@ import warnings
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx.utils import tree_map
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from mlx.utils import tree_map
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from . import whisper
|
from . import torch_whisper, whisper
|
||||||
from . import torch_whisper
|
|
||||||
|
|
||||||
_MODELS = {
|
_MODELS = {
|
||||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import sys
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
from .audio import (
|
from .audio import (
|
||||||
FRAMES_PER_SECOND,
|
FRAMES_PER_SECOND,
|
||||||
HOP_LENGTH,
|
HOP_LENGTH,
|
||||||
|
@ -2,13 +2,12 @@
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
import gzip
|
import gzip
|
||||||
from dataclasses import dataclass
|
|
||||||
import math
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Iterable, Optional
|
from typing import Dict, Iterable, Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .decoding import decode as decode_function
|
from .decoding import decode as decode_function
|
||||||
|
Loading…
Reference in New Issue
Block a user