diff --git a/llms/MANIFEST.in b/llms/MANIFEST.in
new file mode 100644
index 00000000..05b93159
--- /dev/null
+++ b/llms/MANIFEST.in
@@ -0,0 +1,2 @@
+include mlx_lm/requirements.txt
+recursive-include mlx_lm/ *.py
diff --git a/llms/README.md b/llms/README.md
new file mode 100644
index 00000000..a066ad89
--- /dev/null
+++ b/llms/README.md
@@ -0,0 +1,110 @@
+## Generate Text with LLMs and MLX
+
+The easiest way to get started is to install the `mlx-lm` package:
+
+```shell
+pip install mlx-lm
+```
+
+### Python API
+
+You can use `mlx-lm` as a module:
+
+```python
+from mlx_lm import load, generate
+
+model, tokenizer = load("mistralai/Mistral-7B-v0.1")
+
+response = generate(model, tokenizer, prompt="hello", verbose=True)
+```
+
+To see a description of all the arguments, run:
+
+```
+>>> help(generate)
+```
+
+The `mlx-lm` package also comes with functionality to quantize and optionally
+upload models to the Hugging Face Hub.
+
+You can convert models in the Python API with:
+
+```python
+from mlx_lm import convert
+
+upload_repo = "mlx-community/My-Mistral-7B-v0.1-4bit"
+
+convert("mistralai/Mistral-7B-v0.1", quantize=True, upload_repo=upload_repo)
+```
+
+This will generate a 4-bit quantized Mistral-7B and upload it to the
+repo `mlx-community/My-Mistral-7B-v0.1-4bit`. It will also save the
+converted model in the path `mlx_model` by default.
+
+To see a description of all the arguments, run:
+
+```
+>>> help(convert)
+```
+
+### Command Line
+
+You can also use `mlx-lm` from the command line with:
+
+```
+python -m mlx_lm.generate --model mistralai/Mistral-7B-v0.1 --prompt "hello"
+```
+
+This will download a Mistral 7B model from the Hugging Face Hub and generate
+text using the given prompt.
+
+For a full list of options run:
+
+```
+python -m mlx_lm.generate --help
+```
+
+To quantize a model from the command line run:
+
+```
+python -m mlx_lm.convert --hf-path mistralai/Mistral-7B-v0.1 -q
+```
+
+For more options run:
+
+```
+python -m mlx_lm.convert --help
+```
+
+You can upload new models to Hugging Face by specifying `--upload-repo` to
+`convert`. For example, to upload a quantized Mistral-7B model to the
+[MLX Hugging Face community](https://huggingface.co/mlx-community) you can do:
+
+```
+python -m mlx_lm.convert \
+    --hf-path mistralai/Mistral-7B-v0.1 \
+    -q \
+    --upload-repo mlx-community/my-4bit-mistral
+```
+
+### Supported Models
+
+The example supports Hugging Face format Mistral, Llama, and Phi-2 style
+models. If the model you want to run is not supported, file an
+[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
+submit a pull request.
+
+Here are a few examples of Hugging Face models that work with this example:
+
+- [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
+- [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
+- [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct)
+- [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat)
+- [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
+
+Most
+[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
+[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending),
+and
+[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending)
+style models should work out of the box.
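The new README shows the minimal `load`/`generate` call. For readers skimming the diff, the `generate` helper added in `mlx_lm/utils.py` (further down in this change) also accepts `temp` and `max_tokens`. A short sketch of a fuller call; the repo id, prompt, and sampling values are purely illustrative:

```python
# Sketch only: the model repo, prompt, and sampling values are illustrative.
# The temp/max_tokens/verbose parameters come from the generate() signature
# added in mlx_lm/utils.py later in this diff.
from mlx_lm import load, generate

model, tokenizer = load("mistralai/Mistral-7B-v0.1")

# verbose=True streams tokens to stdout as they are decoded;
# the full generated string is also returned.
text = generate(
    model,
    tokenizer,
    prompt="Write a haiku about the ocean.",
    temp=0.7,
    max_tokens=100,
    verbose=True,
)
```

With `temp=0.0` (the default) decoding is greedy; a positive temperature samples from the scaled logits, as implemented in `generate_step` below.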
diff --git a/llms/hf_llm/.gitignore b/llms/hf_llm/.gitignore deleted file mode 100644 index 9666d735..00000000 --- a/llms/hf_llm/.gitignore +++ /dev/null @@ -1 +0,0 @@ -mlx_model \ No newline at end of file diff --git a/llms/hf_llm/README.md b/llms/hf_llm/README.md deleted file mode 100644 index e2734adb..00000000 --- a/llms/hf_llm/README.md +++ /dev/null @@ -1,86 +0,0 @@ -## Generate Text with MLX and :hugs: Hugging Face - -This an example of large language model text generation that can pull models from -the Hugging Face Hub. - -### Setup - -Install the dependencies: - -``` -pip install -r requirements.txt -``` - -### Run - -``` -python generate.py --model --prompt "hello" -``` - -For example: - -``` -python generate.py --model mistralai/Mistral-7B-v0.1 --prompt "hello" -``` - -will download the Mistral 7B model and generate text using the given prompt. - -The `` should be either a path to a local directory or a Hugging -Face repo with weights stored in `safetensors` format. If you use a repo from -the Hugging Face Hub, then the model will be downloaded and cached the first -time you run it. See the [Models](#models) section for a full list of supported models. - -Run `python generate.py --help` to see all the options. - - -### Models - -The example supports Hugging Face format Mistral, Llama, and Phi-2 style models. If the -model you want to run is not supported, file an -[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet, -submit a pull request. - -Here are a few examples of Hugging Face models that work with this example: - -- [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) -- [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) -- [TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T) -- [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct) -- [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat) -- [microsoft/phi-2](https://huggingface.co/microsoft/phi-2) - -Most -[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending), -[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending), -and -[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending) -style models should work out of the box. - -### Convert new models - -You can convert (change the data type or quantize) models using the -`convert.py` script. This script takes a Hugging Face repo as input and outputs -a model directory (which you can optionally also upload to Hugging Face). - -For example, to make a 4-bit quantized model, run: - -``` -python convert.py --hf-path -q -``` - -For more options run: - -``` -python convert.py --help -``` - -You can upload new models to Hugging Face by specifying `--upload-repo` to -`convert.py`. For example, to upload a quantized Mistral-7B model to the -[MLX Hugging Face community](https://huggingface.co/mlx-community) you can do: - -``` -python convert.py \ - --hf-path mistralai/Mistral-7B-v0.1 \ - -q \ - --upload mlx-community/my-4bit-mistral \ -``` diff --git a/llms/hf_llm/models.py b/llms/hf_llm/models.py deleted file mode 100644 index d0f71e1b..00000000 --- a/llms/hf_llm/models.py +++ /dev/null @@ -1,269 +0,0 @@ -# Copyright © 2023 Apple Inc. 
- -import glob -import inspect -import json -from dataclasses import dataclass -from pathlib import Path -from typing import Dict, Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn -from huggingface_hub import snapshot_download -from transformers import AutoTokenizer - - -@dataclass -class ModelArgs: - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - num_key_value_heads: int = None - rope_theta: float = 10000 - rope_traditional: bool = False - model_type: str = None - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - if self.rope_scaling: - required_keys = {"factor", "type"} - if not all(key in self.rope_scaling for key in required_keys): - raise ValueError(f"rope_scaling must contain keys {required_keys}") - - if self.rope_scaling["type"] != "linear": - raise ValueError("rope_scaling 'type' currently only supports 'linear'") - - @classmethod - def from_dict(cls, params): - return cls( - **{ - k: v - for k, v in params.items() - if k in inspect.signature(cls).parameters - } - ) - - -class RMSNorm(nn.Module): - def __init__(self, dims: int, eps: float = 1e-5): - super().__init__() - self.weight = mx.ones((dims,)) - self.eps = eps - - def _norm(self, x): - return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps) - - def __call__(self, x): - output = self._norm(x.astype(mx.float32)).astype(x.dtype) - return self.weight * output - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - self.repeats = n_heads // n_kv_heads - - head_dim = args.hidden_size // n_heads - self.scale = head_dim**-0.5 - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) - rope_scale = ( - 1 / args.rope_scaling["factor"] - if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" - else 1 - ) - self.rope = nn.RoPE( - head_dim, - traditional=args.rope_traditional, - base=args.rope_theta, - scale=rope_scale, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - def repeat(a): - a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) - return a.reshape([B, self.n_heads, L, -1]) - - if self.repeats > 1: - keys, values = map(repeat, (keys, values)) - - if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - scores = 
(queries * self.scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores += mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) - - def __call__(self, x) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.num_attention_heads = args.num_attention_heads - self.hidden_size = args.hidden_size - self.self_attn = Attention(args) - self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - r, cache = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out, cache - - -class LlamaModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - cache=None, - ): - h = self.embed_tokens(inputs) - - mask = None - if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) - mask = mask.astype(h.dtype) - - if cache is None: - cache = [None] * len(self.layers) - - for e, layer in enumerate(self.layers): - h, cache[e] = layer(h, mask, cache[e]) - - return self.norm(h), cache - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.model = LlamaModel(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - cache=None, - ): - out, cache = self.model(inputs, cache) - return self.lm_head(out), cache - - -def load(path_or_hf_repo: str): - # If the path exists, it will try to load model form it - # otherwise download and cache from the hf_repo and cache - model_path = Path(path_or_hf_repo) - if not model_path.exists(): - model_path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], - ) - ) - - with open(model_path / "config.json", "r") as f: - config = json.loads(f.read()) - quantization = config.get("quantization", None) - model_args = ModelArgs.from_dict(config) - - weight_files = glob.glob(str(model_path / "*.safetensors")) - if len(weight_files) == 0: - raise FileNotFoundError("No safetensors found in {}".format(model_path)) - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf).items()) - - model = Model(model_args) - if quantization is not None: - nn.QuantizedLinear.quantize_module(model, **quantization) - - 
model.load_weights(list(weights.items())) - - mx.eval(model.parameters()) - tokenizer = AutoTokenizer.from_pretrained(model_path) - return model, tokenizer - - -def generate(prompt: mx.array, model: Model, temp: float = 0.0): - def sample(logits): - if temp == 0: - return mx.argmax(logits, axis=-1) - else: - return mx.random.categorical(logits * (1 / temp)) - - y = prompt - cache = None - while True: - logits, cache = model(y[None], cache=cache) - logits = logits[:, -1, :] - y = sample(logits) - yield y diff --git a/llms/mlx_lm/README.md b/llms/mlx_lm/README.md new file mode 100644 index 00000000..f7fd1dad --- /dev/null +++ b/llms/mlx_lm/README.md @@ -0,0 +1,7 @@ +## Generate Text with MLX and :hugs: Hugging Face + +This an example of large language model text generation that can pull models from +the Hugging Face Hub. + +For more information on this example, see the +[README](../README.md) in the parent directory. diff --git a/llms/mlx_lm/UPLOAD.md b/llms/mlx_lm/UPLOAD.md new file mode 100644 index 00000000..f5de3655 --- /dev/null +++ b/llms/mlx_lm/UPLOAD.md @@ -0,0 +1,37 @@ +### Packaging for PyPI + +Install `build` and `twine`: + +``` +pip install --user --upgrade build +pip install --user --upgrade twine +``` + +Generate the source distribution and wheel: + +``` +python -m build +``` + +> [!warning] +> Use a test server first + +#### Test Upload + +Upload to test server: + +``` +python -m twine upload --repository testpypi dist/* +``` + +Install from test server and check that it works: + +``` +python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx-lm +``` + +#### Upload + +``` +python -m twine upload dist/* +``` diff --git a/llms/mlx_lm/__init__.py b/llms/mlx_lm/__init__.py new file mode 100644 index 00000000..21e66182 --- /dev/null +++ b/llms/mlx_lm/__init__.py @@ -0,0 +1,2 @@ +from .convert import convert +from .utils import generate, load diff --git a/llms/hf_llm/convert.py b/llms/mlx_lm/convert.py similarity index 81% rename from llms/hf_llm/convert.py rename to llms/mlx_lm/convert.py index 488c7213..ebd1ef4f 100644 --- a/llms/hf_llm/convert.py +++ b/llms/mlx_lm/convert.py @@ -9,7 +9,8 @@ import mlx.core as mx import mlx.nn as nn import transformers from mlx.utils import tree_flatten -from utils import get_model_path, load + +from .utils import get_model_path, load MAX_FILE_SIZE_GB = 15 @@ -73,26 +74,30 @@ def fetch_from_hub( return weights, config.to_dict(), tokenizer -def quantize(weights: dict, config: dict, args: argparse.Namespace) -> tuple: +def quantize_model( + weights: dict, config: dict, hf_path: str, q_group_size: int, q_bits: int +) -> tuple: """ Applies quantization to the model weights. Args: weights (dict): Model weights. config (dict): Model configuration. - args (argparse.Namespace): Command-line arguments. + hf_path (str): HF model path.. + q_group_size (int): Group size for quantization. + q_bits (int): Bits per weight for quantization. Returns: tuple: Tuple containing quantized weights and config. 
""" quantized_config = copy.deepcopy(config) - model, _ = load(args.hf_path) + model, _ = load(hf_path) model.load_weights(list(weights.items())) - nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits) + nn.QuantizedLinear.quantize_module(model, q_group_size, q_bits) quantized_config["quantization"] = { - "group_size": args.q_group_size, - "bits": args.q_bits, + "group_size": q_group_size, + "bits": q_bits, } quantized_weights = dict(tree_flatten(model.parameters())) @@ -148,7 +153,7 @@ Refer to the [original model card](https://huggingface.co/{hf_path}) for more de pip install mlx git clone https://github.com/ml-explore/mlx-examples.git cd mlx-examples/llms/hf_llm -python generate.py --model {repo_id} --prompt "My name is" +python generate.py --model {upload_repo} --prompt "My name is" ``` """ card.save(os.path.join(path, "README.md")) @@ -164,20 +169,24 @@ python generate.py --model {repo_id} --prompt "My name is" ) -if __name__ == "__main__": - parser = configure_parser() - args = parser.parse_args() - +def convert( + hf_path: str, + mlx_path: str = "mlx_model", + quantize: bool = False, + q_group_size: int = 64, + q_bits: int = 4, + dtype: str = "float16", + upload_repo: str = None, +): print("[INFO] Loading") - weights, config, tokenizer = fetch_from_hub(args.hf_path) - - dtype = mx.float16 if args.quantize else getattr(mx, args.dtype) + weights, config, tokenizer = fetch_from_hub(hf_path) + dtype = mx.float16 if quantize else getattr(mx, dtype) weights = {k: v.astype(dtype) for k, v in weights.items()} - if args.quantize: + if quantize: print("[INFO] Quantizing") - weights, config = quantize(weights, config, args) + weights, config = quantize_model(weights, config, hf_path, q_group_size, q_bits) - mlx_path = Path(args.mlx_path) + mlx_path = Path(mlx_path) mlx_path.mkdir(parents=True, exist_ok=True) shards = make_shards(weights) for i, shard in enumerate(shards): @@ -186,5 +195,11 @@ if __name__ == "__main__": with open(mlx_path / "config.json", "w") as fid: json.dump(config, fid, indent=4) - if args.upload_repo is not None: - upload_to_hub(mlx_path, args.upload_repo, args.hf_path) + if upload_repo is not None: + upload_to_hub(mlx_path, upload_repo, hf_path) + + +if __name__ == "__main__": + parser = configure_parser() + args = parser.parse_args() + convert(**vars(args)) diff --git a/llms/hf_llm/generate.py b/llms/mlx_lm/generate.py similarity index 93% rename from llms/hf_llm/generate.py rename to llms/mlx_lm/generate.py index 1e906c89..237fb056 100644 --- a/llms/hf_llm/generate.py +++ b/llms/mlx_lm/generate.py @@ -2,7 +2,8 @@ import argparse import time import mlx.core as mx -from utils import generate, load + +from .utils import generate_step, load DEFAULT_MODEL_PATH = "mlx_model" DEFAULT_PROMPT = "hello" @@ -47,7 +48,9 @@ def main(args): tic = time.time() tokens = [] skip = 0 - for token, n in zip(generate(prompt, model, args.temp), range(args.max_tokens)): + for token, n in zip( + generate_step(prompt, model, args.temp), range(args.max_tokens) + ): if token == tokenizer.eos_token_id: break if n == 0: diff --git a/llms/hf_llm/models/__init__.py b/llms/mlx_lm/models/__init__.py similarity index 100% rename from llms/hf_llm/models/__init__.py rename to llms/mlx_lm/models/__init__.py diff --git a/llms/hf_llm/models/base.py b/llms/mlx_lm/models/base.py similarity index 100% rename from llms/hf_llm/models/base.py rename to llms/mlx_lm/models/base.py diff --git a/llms/hf_llm/models/llama.py b/llms/mlx_lm/models/llama.py similarity index 100% rename from 
llms/hf_llm/models/llama.py rename to llms/mlx_lm/models/llama.py diff --git a/llms/hf_llm/models/phi2.py b/llms/mlx_lm/models/phi2.py similarity index 100% rename from llms/hf_llm/models/phi2.py rename to llms/mlx_lm/models/phi2.py diff --git a/llms/hf_llm/requirements.txt b/llms/mlx_lm/requirements.txt similarity index 50% rename from llms/hf_llm/requirements.txt rename to llms/mlx_lm/requirements.txt index 4447dc86..c78cefa2 100644 --- a/llms/hf_llm/requirements.txt +++ b/llms/mlx_lm/requirements.txt @@ -1,4 +1,4 @@ -mlx>=0.0.7 +mlx numpy transformers -protobuf \ No newline at end of file +protobuf diff --git a/llms/hf_llm/utils.py b/llms/mlx_lm/utils.py similarity index 75% rename from llms/hf_llm/utils.py rename to llms/mlx_lm/utils.py index 71d8941f..aa464a6e 100644 --- a/llms/hf_llm/utils.py +++ b/llms/mlx_lm/utils.py @@ -6,13 +6,12 @@ from typing import Generator, Tuple import mlx.core as mx import mlx.nn as nn +from huggingface_hub import snapshot_download +from transformers import AutoTokenizer, PreTrainedTokenizer # Local imports -import models.llama as llama -import models.phi2 as phi2 -from huggingface_hub import snapshot_download -from models.base import BaseModelArgs -from transformers import AutoTokenizer, PreTrainedTokenizer +from .models import llama, phi2 +from .models.base import BaseModelArgs # Constants MODEL_MAPPING = { @@ -64,11 +63,11 @@ def get_model_path(path_or_hf_repo: str) -> Path: return model_path -def generate( +def generate_step( prompt: mx.array, model: nn.Module, temp: float = 0.0 ) -> Generator[mx.array, None, None]: """ - Generate text based on the given prompt and model. + A generator producing text based on the given prompt from the model. Args: prompt (mx.array): The input prompt. @@ -76,7 +75,7 @@ def generate( temp (float): The temperature for sampling. If temp is 0, use max sampling. Yields: - mx.array: The generated text. + Generator[mx.array]: A generator producing one token per call. """ def sample(logits: mx.array) -> mx.array: @@ -95,6 +94,46 @@ def generate( yield y +def generate( + model: nn.Module, + tokenizer: PreTrainedTokenizer, + prompt: str, + temp: float = 0.0, + max_tokens: int = 100, + verbose: bool = False, +) -> str: + """ + Generate text from the model. + + Args: + model (nn.Module): The language model. + tokenizer (PreTrainedTokenizer): The tokenizer. + prompt (str): The string prompt. + temp (float): The temperature for sampling (default 0). + max_tokens (int): The maximum number of tokens (default 100). + """ + + prompt = mx.array(tokenizer.encode(prompt)) + + tokens = [] + skip = 0 + for token, _ in zip(generate_step(prompt, model, temp), range(max_tokens)): + if token == tokenizer.eos_token_id: + break + + tokens.append(token.item()) + + if verbose: + s = tokenizer.decode(tokens) + print(s[skip:], end="", flush=True) + skip = len(s) + + tokens = tokenizer.decode(tokens)[skip:] + if verbose: + print(tokens, flush=True) + return tokens + + def load(path_or_hf_repo: str) -> Tuple[nn.Module, PreTrainedTokenizer]: """ Load the model from a given path or a huggingface repository. 
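The `utils.py` hunk above splits the old `generate` into a low-level `generate_step` generator plus a new string-in/string-out `generate` wrapper. A minimal sketch of driving `generate_step` directly, which is essentially what `generate()` and `mlx_lm/generate.py` do in this diff; the model repo, prompt, and 100-token limit are illustrative choices:

```python
# Sketch of driving generate_step() directly, mirroring the token loop used
# by generate() and mlx_lm/generate.py in this diff.
import mlx.core as mx

from mlx_lm.utils import generate_step, load

model, tokenizer = load("mistralai/Mistral-7B-v0.1")
prompt = mx.array(tokenizer.encode("hello"))

tokens = []
# generate_step yields one token per iteration; cap at 100 tokens here.
for token, _ in zip(generate_step(prompt, model, temp=0.0), range(100)):
    if token == tokenizer.eos_token_id:  # stop at end-of-sequence
        break
    tokens.append(token.item())

print(tokenizer.decode(tokens))
```

Keeping the token loop in the caller is what lets the CLI and the `generate` helper share the same sampling code path.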
diff --git a/llms/setup.py b/llms/setup.py new file mode 100644 index 00000000..3b3fdb7e --- /dev/null +++ b/llms/setup.py @@ -0,0 +1,23 @@ +import sys +from pathlib import Path + +import pkg_resources +from setuptools import setup + +with open(Path(__file__).parent / "mlx_lm/requirements.txt") as fid: + requirements = [str(r) for r in pkg_resources.parse_requirements(fid)] +setup( + name="mlx-lm", + version="0.0.1", + description="LLMs on Apple silicon with MLX and the Hugging Face Hub", + long_description=open("README.md", encoding="utf-8").read(), + long_description_content_type="text/markdown", + readme="README.md", + author_email="mlx@group.apple.com", + author="MLX Contributors", + url="https://github.com/ml-explore/mlx-examples", + license="MIT", + install_requires=requirements, + packages=["mlx_lm", "mlx_lm.models"], + python_requires=">=3.8", +)
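Taken together with the `convert` and `load`/`generate` APIs added earlier in this diff, the packaged module supports a local quantize-then-generate workflow. A minimal end-to-end sketch, assuming `mlx-lm` is installed; the repo id is illustrative and `mlx_model` is simply `convert`'s default output directory in `mlx_lm/convert.py`:

```python
# End-to-end sketch: quantize a model locally, then load the converted copy
# and generate from it. Repo id and output path are illustrative.
from mlx_lm import convert, generate, load

convert("mistralai/Mistral-7B-v0.1", mlx_path="mlx_model", quantize=True)

# load() accepts either a local path or a Hugging Face repo id.
model, tokenizer = load("mlx_model")
print(generate(model, tokenizer, prompt="hello", max_tokens=50))
```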