Add llms subdir + update README (#145)

* add llms subdir + update README

* nits

* use same pre-commit as mlx

* update readmes a bit

* format
This commit is contained in:
Awni Hannun 2023-12-20 10:22:25 -08:00 committed by GitHub
parent aed14618ca
commit 27c0a8c002
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
62 changed files with 164 additions and 146 deletions

View File

@ -1,5 +1,11 @@
repos: repos:
- repo: https://github.com/psf/black - repo: https://github.com/psf/black-pre-commit-mirror
rev: 22.10.0 rev: 22.10.0
hooks: hooks:
- id: black - id: black
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
args:
- --profile=black

View File

@ -5,18 +5,37 @@ framework](https://github.com/ml-explore/mlx).
The [MNIST](mnist) example is a good starting point to learn how to use MLX. The [MNIST](mnist) example is a good starting point to learn how to use MLX.
Some more useful examples include: Some more useful examples are listed below.
### Text Models
- [Transformer language model](transformer_lm) training. - [Transformer language model](transformer_lm) training.
- Large scale text generation with [LLaMA](llama), [Mistral](mistral) or [Phi](phi2). - Large scale text generation with [LLaMA](llms/llama),
- Mixture-of-experts (MoE) language model with [Mixtral 8x7B](mixtral) [Mistral](llms/mistral), [Phi-2](llms/phi2), and more in the [LLMs](llms)
directory.
- A mixture-of-experts (MoE) language model with [Mixtral 8x7B](llms/mixtral).
- Parameter efficient fine-tuning with [LoRA](lora). - Parameter efficient fine-tuning with [LoRA](lora).
- Text-to-text multi-task Transformers with [T5](t5).
- Bidirectional language understanding with [BERT](bert).
### Image Models
- Generating images with [Stable Diffusion](stable_diffusion). - Generating images with [Stable Diffusion](stable_diffusion).
### Audio Models
- Speech recognition with [OpenAI's Whisper](whisper). - Speech recognition with [OpenAI's Whisper](whisper).
- Bidirectional language understanding with [BERT](bert)
### Other Models
- Semi-supervised learning on graph-structured data with [GCN](gcn). - Semi-supervised learning on graph-structured data with [GCN](gcn).
Note: You can now directly download a few converted checkpoints from the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face. ### Hugging Face
Note: You can now directly download a few converted checkpoints from the [MLX
Community](https://huggingface.co/mlx-community) organization on Hugging Face.
We encourage you to join the community and [contribute new
models](https://github.com/ml-explore/mlx-examples/issues/155).
## Contributing ## Contributing

View File

@ -1,7 +1,7 @@
from transformers import BertModel
import argparse import argparse
import numpy import numpy
from transformers import BertModel
def replace_key(key: str) -> str: def replace_key(key: str) -> str:

View File

@ -1,7 +1,7 @@
from transformers import AutoModel, AutoTokenizer
import argparse import argparse
from transformers import AutoModel, AutoTokenizer
def run(bert_model: str): def run(bert_model: str):
batch = [ batch = [

View File

@ -1,13 +1,13 @@
import numpy as np import argparse
from typing import Optional
from dataclasses import dataclass from dataclasses import dataclass
from transformers import BertTokenizer from typing import Optional
from mlx.utils import tree_unflatten
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import argparse
import numpy import numpy
import numpy as np
from mlx.utils import tree_unflatten
from transformers import BertTokenizer
@dataclass @dataclass

View File

@ -1,6 +1,7 @@
import math
import mlx.core as mx import mlx.core as mx
from mlx.data.datasets import load_cifar10 from mlx.data.datasets import load_cifar10
import math
def get_cifar10(batch_size, root=None): def get_cifar10(batch_size, root=None):

View File

@ -1,11 +1,11 @@
import argparse import argparse
import time import time
import resnet
import mlx.nn as nn
import mlx.core as mx
import mlx.optimizers as optim
from dataset import get_cifar10
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import resnet
from dataset import get_cifar10
parser = argparse.ArgumentParser(add_help=True) parser = argparse.ArgumentParser(add_help=True)
parser.add_argument( parser.add_argument(

View File

@ -6,11 +6,11 @@ There's no BatchNorm is mlx==0.0.4, using LayerNorm instead.
""" """
from typing import Any from typing import Any
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx.utils import tree_flatten from mlx.utils import tree_flatten
__all__ = [ __all__ = [
"ResNet", "ResNet",
"resnet20", "resnet20",

View File

@ -1,9 +1,9 @@
import os import os
import requests
import tarfile import tarfile
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
import requests
import scipy.sparse as sparse import scipy.sparse as sparse
""" """

View File

@ -3,10 +3,10 @@ from argparse import ArgumentParser
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import mlx.optimizers as optim import mlx.optimizers as optim
from datasets import download_cora, load_data, train_val_test_mask
from mlx.nn.losses import cross_entropy from mlx.nn.losses import cross_entropy
from mlx.utils import tree_flatten from mlx.utils import tree_flatten
from datasets import download_cora, load_data, train_val_test_mask
from gcn import GCN from gcn import GCN

View File

@ -20,8 +20,9 @@ weights you will need to request access from Meta:
- [Request Llama v1](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) - [Request Llama v1](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
- [Request Llama v2](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) - [Request Llama v2](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
> [!TIP] > [!TIP] Alternatively, you can also download a few converted checkpoints from
> Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step. > the [MLX Community](https://huggingface.co/mlx-community) organization on
> Hugging Face and skip the conversion step.
You can download the TinyLlama models directly from [Hugging You can download the TinyLlama models directly from [Hugging
Face](https://huggingface.co/TinyLlama). Face](https://huggingface.co/TinyLlama).

View File

@ -4,8 +4,9 @@ import argparse
import collections import collections
import glob import glob
import json import json
import numpy as np
from pathlib import Path from pathlib import Path
import numpy as np
import torch import torch

View File

@ -1,16 +1,16 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import argparse import argparse
from dataclasses import dataclass
import json import json
from pathlib import Path
from typing import Optional, Tuple, List
from sentencepiece import SentencePieceProcessor
import time import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx.utils import tree_unflatten from mlx.utils import tree_unflatten
from sentencepiece import SentencePieceProcessor
@dataclass @dataclass

View File

@ -29,7 +29,10 @@ python convert.py
The conversion script will save the converted weights in the same location. The conversion script will save the converted weights in the same location.
> [!TIP] > [!TIP]
> Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step. > Alternatively, you can also download a few converted checkpoints from the
> [MLX Community](https://huggingface.co/mlx-community) organization on Hugging
> Face and skip the conversion step.
### Run ### Run

View File

@ -2,10 +2,10 @@
import argparse import argparse
import json import json
import numpy as np
from pathlib import Path from pathlib import Path
import torch
import numpy as np
import torch
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.") parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")

View File

@ -1,15 +1,15 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import argparse import argparse
from dataclasses import dataclass
import json import json
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple, List from typing import List, Optional, Tuple
from sentencepiece import SentencePieceProcessor
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx.utils import tree_map, tree_unflatten from mlx.utils import tree_map, tree_unflatten
from sentencepiece import SentencePieceProcessor
@dataclass @dataclass

View File

@ -2,11 +2,10 @@
import unittest import unittest
import mistral
import mlx.core as mx import mlx.core as mx
from mlx.utils import tree_map from mlx.utils import tree_map
import mistral
class TestMistral(unittest.TestCase): class TestMistral(unittest.TestCase):
def test_model(self): def test_model(self):

View File

@ -3,8 +3,9 @@
import argparse import argparse
import glob import glob
import json import json
import numpy as np
from pathlib import Path from pathlib import Path
import numpy as np
import torch import torch

View File

@ -1,17 +1,17 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import argparse import argparse
from dataclasses import dataclass
import glob import glob
import json import json
import numpy as np from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple, List from typing import List, Optional, Tuple
from sentencepiece import SentencePieceProcessor
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np
from mlx.utils import tree_map, tree_unflatten from mlx.utils import tree_map, tree_unflatten
from sentencepiece import SentencePieceProcessor
@dataclass @dataclass

View File

@ -17,8 +17,10 @@ python convert.py
This will make the `weights.npz` file which MLX can read. This will make the `weights.npz` file which MLX can read.
> [!TIP] > [!TIP] Alternatively, you can also download a few converted checkpoints from
> Alternatively, you can also download a few converted checkpoints from the the [MLX Community](https://huggingface.co/mlx-community) organisation on Hugging Face and skip the conversion step. > the [MLX Community](https://huggingface.co/mlx-community) organization on
> Hugging Face and skip the conversion step.
## Generate ## Generate

View File

@ -1,5 +1,5 @@
from transformers import AutoModelForCausalLM
import numpy as np import numpy as np
from transformers import AutoModelForCausalLM
def replace_key(key: str) -> str: def replace_key(key: str) -> str:

View File

@ -1,13 +1,13 @@
import argparse import argparse
import math
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from dataclasses import dataclass
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import math from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
@dataclass @dataclass
@ -167,7 +167,7 @@ def load_model(model_path: str):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Phi-2 inference script") parser = argparse.ArgumentParser(description="Phi-2 inference script")
parser.add_argument( parser.add_argument(
"--model_path", "--model-path",
type=str, type=str,
default="phi-2", default="phi-2",
help="The path to the model weights", help="The path to the model weights",
@ -178,7 +178,7 @@ if __name__ == "__main__":
default="Write a detailed analogy between mathematics and a lighthouse.", default="Write a detailed analogy between mathematics and a lighthouse.",
) )
parser.add_argument( parser.add_argument(
"--max_tokens", "--max-tokens",
"-m", "-m",
type=int, type=int,
default=100, default=100,

View File

@ -1,8 +1,9 @@
import argparse import argparse
from transformers import AutoModelForCausalLM import json
import numpy as np import numpy as np
import torch import torch
import json from transformers import AutoModelForCausalLM
def replace_key(key: str) -> str: def replace_key(key: str) -> str:

View File

@ -1,6 +1,7 @@
import argparse import argparse
from dataclasses import dataclass
import json import json
from dataclasses import dataclass
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx.utils import tree_unflatten from mlx.utils import tree_unflatten

View File

@ -2,12 +2,12 @@
import argparse import argparse
import json import json
import numpy as np
from pathlib import Path
import shutil
import os import os
import torch import shutil
from pathlib import Path
import numpy as np
import torch
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(

View File

@ -32,8 +32,8 @@ class WikiSQL:
def _maybe_download(self, data_dir): def _maybe_download(self, data_dir):
if not os.path.exists(data_dir): if not os.path.exists(data_dir):
import io import io
from urllib import request
import tarfile import tarfile
from urllib import request
url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2" url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2"
r = request.urlopen(url) r = request.urlopen(url)

View File

@ -3,19 +3,17 @@
import argparse import argparse
import json import json
import math import math
import numpy as np
from pathlib import Path
from sentencepiece import SentencePieceProcessor
import time import time
from typing import Optional, Tuple, List from pathlib import Path
from typing import List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import mlx.optimizers as optim import mlx.optimizers as optim
from mlx.utils import tree_map, tree_flatten, tree_unflatten import numpy as np
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from models import LoRALinear, Model, ModelArgs
from models import ModelArgs, Model, LoRALinear from sentencepiece import SentencePieceProcessor
def build_parser(): def build_parser():

View File

@ -1,8 +1,8 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
from dataclasses import dataclass
import math import math
from typing import Optional, Tuple, List from dataclasses import dataclass
from typing import List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn

View File

@ -1,9 +1,10 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import functools
import time
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import functools
import time
import mnist import mnist

View File

@ -3,11 +3,10 @@
import argparse import argparse
import time import time
import numpy as np
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import mlx.optimizers as optim import mlx.optimizers as optim
import numpy as np
import mnist import mnist

View File

@ -1,11 +1,12 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import gzip import gzip
import numpy as np
import os import os
import pickle import pickle
from urllib import request from urllib import request
import numpy as np
def mnist(save_dir="/tmp"): def mnist(save_dir="/tmp"):
""" """

View File

@ -1,9 +1,10 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import argparse import argparse
import torch
import time import time
import torch
import mnist import mnist

View File

@ -1,9 +1,9 @@
from typing import Any from typing import Any
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx.utils import tree_flatten from mlx.utils import tree_flatten
__all__ = ["KWT", "kwt1", "kwt2", "kwt3"] __all__ = ["KWT", "kwt1", "kwt2", "kwt3"]

View File

@ -1,13 +1,13 @@
import argparse import argparse
import time import time
import kwt
import mlx.nn as nn
import mlx.data as dx
import mlx.core as mx
import mlx.optimizers as optim
from mlx.data.features import mfsc
from mlx.data.datasets import load_speechcommands
import kwt
import mlx.core as mx
import mlx.data as dx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.data.datasets import load_speechcommands
from mlx.data.features import mfsc
parser = argparse.ArgumentParser(add_help=True) parser = argparse.ArgumentParser(add_help=True)
parser.add_argument( parser.add_argument(

View File

@ -6,12 +6,12 @@ from typing import Tuple
import mlx.core as mx import mlx.core as mx
from .model_io import ( from .model_io import (
load_unet, _DEFAULT_MODEL,
load_text_encoder,
load_autoencoder, load_autoencoder,
load_diffusion_config, load_diffusion_config,
load_text_encoder,
load_tokenizer, load_tokenizer,
_DEFAULT_MODEL, load_unet,
) )
from .sampler import SimpleEulerSampler from .sampler import SimpleEulerSampler

View File

@ -3,20 +3,18 @@
import json import json
from functools import partial from functools import partial
import mlx.core as mx
import numpy as np import numpy as np
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from mlx.utils import tree_unflatten
from safetensors import safe_open as safetensor_open from safetensors import safe_open as safetensor_open
import mlx.core as mx
from mlx.utils import tree_unflatten
from .clip import CLIPTextModel from .clip import CLIPTextModel
from .config import UNetConfig, CLIPTextModelConfig, AutoencoderConfig, DiffusionConfig from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
from .tokenizer import Tokenizer from .tokenizer import Tokenizer
from .unet import UNetModel from .unet import UNetModel
from .vae import Autoencoder from .vae import Autoencoder
_DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base" _DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"
_MODELS = { _MODELS = {
# See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license # See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license

View File

@ -1,9 +1,9 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
from .config import DiffusionConfig
import mlx.core as mx import mlx.core as mx
from .config import DiffusionConfig
def _linspace(a, b, num): def _linspace(a, b, num):
x = mx.arange(0, num) / (num - 1) x = mx.arange(0, num) / (num - 1)

View File

@ -1,14 +1,13 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import argparse import argparse
import mlx.core as mx
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
import mlx.core as mx
from stable_diffusion import StableDiffusion from stable_diffusion import StableDiffusion
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using stable diffusion" description="Generate images from a textual prompt using stable diffusion"

View File

@ -1,6 +1,5 @@
from transformers import T5ForConditionalGeneration
import numpy as np import numpy as np
from transformers import T5ForConditionalGeneration
SHARED_REPLACEMENT_PATTERNS = [ SHARED_REPLACEMENT_PATTERNS = [
(".block.", ".layers."), (".block.", ".layers."),
@ -48,8 +47,7 @@ def convert(model_name, dtype):
dtype = getattr(np, dtype) dtype = getattr(np, dtype)
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto") model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
weights = { weights = {
replace_key(k): v.numpy().astype(dtype) replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items()
for k, v in model.state_dict().items()
} }
file_name = model_name.replace("/", "-") file_name = model_name.replace("/", "-")
print(f"Saving weights to {file_name}.npz") print(f"Saving weights to {file_name}.npz")

View File

@ -1,7 +1,7 @@
from transformers import T5ForConditionalGeneration, T5EncoderModel, AutoTokenizer
import argparse import argparse
from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration
def embed(t5_model: str): def embed(t5_model: str):
batch = [ batch = [
@ -51,4 +51,3 @@ if __name__ == "__main__":
embed(args.model) embed(args.model)
else: else:
generate(args.model) generate(args.model)

View File

@ -1,11 +1,11 @@
import argparse import argparse
from typing import Optional, Tuple, List
from time import perf_counter_ns from time import perf_counter_ns
from typing import List, Optional, Tuple
import numpy as np
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx.utils import tree_unflatten, tree_map import numpy as np
from mlx.utils import tree_map, tree_unflatten
from transformers import T5Config, T5Tokenizer from transformers import T5Config, T5Tokenizer
@ -337,7 +337,7 @@ class Tokenizer:
self._tokenizer = T5Tokenizer.from_pretrained( self._tokenizer = T5Tokenizer.from_pretrained(
args.model, args.model,
legacy=False, legacy=False,
model_max_length=getattr(config, 'n_positions', 512) model_max_length=getattr(config, "n_positions", 512),
) )
@property @property

View File

@ -2,10 +2,11 @@
import io import io
import itertools import itertools
import numpy as np
import os import os
from urllib import request
import zipfile import zipfile
from urllib import request
import numpy as np
def load_dataset(dataname): def load_dataset(dataname):

View File

@ -1,14 +1,14 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import functools import functools
import jax
import jax.numpy as jnp
import math import math
import numpy as np
import time import time
from collections import namedtuple from collections import namedtuple
import datasets import datasets
import jax
import jax.numpy as jnp
import numpy as np
from tree_utils import tree_flatten from tree_utils import tree_flatten
""" """

View File

@ -3,15 +3,13 @@
import math import math
import time import time
import numpy as np import datasets
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import mlx.optimizers as optim import mlx.optimizers as optim
import numpy as np
from mlx.utils import tree_flatten from mlx.utils import tree_flatten
import datasets
class TransformerLM(nn.Module): class TransformerLM(nn.Module):
def __init__(self, vocab_size: int, num_layers: int, dims: int, num_heads: int): def __init__(self, vocab_size: int, num_layers: int, dims: int, num_heads: int):

View File

@ -3,11 +3,10 @@
import math import math
import time import time
import datasets
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import datasets
def to_samples(context_size, dataset): def to_samples(context_size, dataset):
tokens = dataset.size tokens = dataset.size

View File

@ -3,11 +3,10 @@
import math import math
import time import time
import datasets
import numpy as np import numpy as np
import torch import torch
import datasets
def to_samples(context_size, dataset): def to_samples(context_size, dataset):
tokens = dataset.size tokens = dataset.size

View File

@ -5,10 +5,7 @@ import time
import mlx.core as mx import mlx.core as mx
from whisper import load_models from whisper import audio, decoding, load_models, transcribe
from whisper import audio
from whisper import decoding
from whisper import transcribe
audio_file = "whisper/assets/ls_test.flac" audio_file = "whisper/assets/ls_test.flac"

View File

@ -1,19 +1,18 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import os
import subprocess
import unittest import unittest
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
import os
import subprocess
import torch import torch
import whisper import whisper
import whisper.audio as audio import whisper.audio as audio
import whisper.decoding as decoding
import whisper.load_models as load_models import whisper.load_models as load_models
import whisper.torch_whisper as torch_whisper import whisper.torch_whisper as torch_whisper
import whisper.decoding as decoding
TEST_AUDIO = "whisper/assets/ls_test.flac" TEST_AUDIO = "whisper/assets/ls_test.flac"

View File

@ -1,6 +1,4 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
from . import load_models from . import audio, decoding, load_models
from . import audio
from . import decoding
from .transcribe import transcribe from .transcribe import transcribe

View File

@ -1,12 +1,13 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import zlib
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import mlx.core as mx import mlx.core as mx
from mlx.utils import tree_map
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
import zlib from mlx.utils import tree_map
from .audio import CHUNK_LENGTH from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer from .tokenizer import Tokenizer, get_tokenizer

View File

@ -7,13 +7,11 @@ import warnings
from typing import List from typing import List
import mlx.core as mx import mlx.core as mx
from mlx.utils import tree_map
import torch import torch
from mlx.utils import tree_map
from tqdm import tqdm from tqdm import tqdm
from . import whisper from . import torch_whisper, whisper
from . import torch_whisper
_MODELS = { _MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",

View File

@ -1,12 +1,12 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import sys
from typing import Optional, Tuple, Union
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
import sys
from typing import Optional, Tuple, Union
import tqdm import tqdm
from .audio import ( from .audio import (
FRAMES_PER_SECOND, FRAMES_PER_SECOND,
HOP_LENGTH, HOP_LENGTH,

View File

@ -2,13 +2,12 @@
import base64 import base64
import gzip import gzip
from dataclasses import dataclass
import math import math
from dataclasses import dataclass
from typing import Dict, Iterable, Optional from typing import Dict, Iterable, Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from .decoding import decode as decode_function from .decoding import decode as decode_function