mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Add llms subdir + update README (#145)
* add llms subdir + update README * nits * use same pre-commit as mlx * update readmes a bit * format
This commit is contained in:
57
llms/llama/README.md
Normal file
57
llms/llama/README.md
Normal file
@@ -0,0 +1,57 @@
|
||||
# Llama
|
||||
|
||||
An example of generating text with Llama (1 or 2) using MLX.
|
||||
|
||||
Llama is a set of open source language models from Meta AI Research[^1][^2]
|
||||
ranging from 7B to 70B parameters. This example also supports Meta's Llama Chat
|
||||
and Code Llama models, as well as the 1.1B TinyLlama models from SUTD.[^3]
|
||||
|
||||
### Setup
|
||||
|
||||
Install the dependencies:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Next, download and convert the model. If you do not have access to the model
|
||||
weights you will need to request access from Meta:
|
||||
|
||||
- [Request Llama v1](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
|
||||
- [Request Llama v2](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
|
||||
|
||||
> [!TIP] Alternatively, you can also download a few converted checkpoints from
|
||||
> the [MLX Community](https://huggingface.co/mlx-community) organization on
|
||||
> Hugging Face and skip the conversion step.
|
||||
|
||||
You can download the TinyLlama models directly from [Hugging
|
||||
Face](https://huggingface.co/TinyLlama).
|
||||
|
||||
Convert the weights with:
|
||||
|
||||
```
|
||||
python convert.py --model-path <path_to_torch_model>
|
||||
```
|
||||
|
||||
For TinyLlama use
|
||||
|
||||
```
|
||||
python convert.py --model-path <path_to_torch_model> --model-name tiny_llama
|
||||
```
|
||||
|
||||
The conversion script will save the converted weights in the same location.
|
||||
|
||||
### Run
|
||||
|
||||
Once you've converted the weights to MLX format, you can interact with the
|
||||
LlaMA model:
|
||||
|
||||
```
|
||||
python llama.py <path_to_model> <path_to_tokenizer.model> --prompt "hello"
|
||||
```
|
||||
|
||||
Run `python llama.py --help` for more details.
|
||||
|
||||
[^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
|
||||
[^2]: For Llama v2 refer to the [blob post](https://ai.meta.com/llama/)
|
||||
[^3]: For TinyLlama refer to the [gihub repository](https://github.com/jzhang38/TinyLlama?tab=readme-ov-file)
|
139
llms/llama/convert.py
Normal file
139
llms/llama/convert.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import glob
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def llama(model_path):
|
||||
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
|
||||
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
|
||||
SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
|
||||
|
||||
def shard_key(k):
|
||||
keys = k.split(".")
|
||||
if len(keys) < 2:
|
||||
return None
|
||||
return keys[-2]
|
||||
|
||||
def unshard(k, v):
|
||||
wn = shard_key(k)
|
||||
if wn not in SHARD_WEIGHTS:
|
||||
return v
|
||||
elif wn in SHARD_FIRST:
|
||||
axis = 0
|
||||
elif wn in SHARD_SECOND:
|
||||
axis = 1
|
||||
else:
|
||||
raise ValueError("Invalid weight name")
|
||||
return np.concatenate(v, axis=axis)
|
||||
|
||||
torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
|
||||
weights = collections.defaultdict(list)
|
||||
for wf in torch_files:
|
||||
state = torch.load(wf, map_location=torch.device("cpu"))
|
||||
for k, v in state.items():
|
||||
v = v.to(torch.float16).numpy()
|
||||
if shard_key(k) in SHARD_WEIGHTS:
|
||||
weights[k].append(v)
|
||||
else:
|
||||
weights[k] = v
|
||||
|
||||
for k, v in weights.items():
|
||||
weights[k] = unshard(k, v)
|
||||
return weights, None
|
||||
|
||||
|
||||
def tiny_llama(model_path):
|
||||
try:
|
||||
import transformers
|
||||
except ImportError as e:
|
||||
print("The transformers package must be installed for this model conversion:")
|
||||
print("pip install transformers")
|
||||
import sys
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(
|
||||
str(model_path)
|
||||
).state_dict()
|
||||
config = transformers.AutoConfig.from_pretrained(model_path)
|
||||
|
||||
# things to change
|
||||
# 1. there's no "model." in the weight names
|
||||
model = {k.replace("model.", ""): v for k, v in model.items()}
|
||||
|
||||
# 2. mlp is called feed_forward
|
||||
model = {k.replace("mlp", "feed_forward"): v for k, v in model.items()}
|
||||
|
||||
# 3. up_proj, down_proj, gate_proj
|
||||
model = {k.replace("down_proj", "w2"): v for k, v in model.items()}
|
||||
model = {k.replace("up_proj", "w3"): v for k, v in model.items()}
|
||||
model = {k.replace("gate_proj", "w1"): v for k, v in model.items()}
|
||||
|
||||
# 4. layernorms
|
||||
model = {
|
||||
k.replace("input_layernorm", "attention_norm"): v for k, v in model.items()
|
||||
}
|
||||
model = {
|
||||
k.replace("post_attention_layernorm", "ffn_norm"): v for k, v in model.items()
|
||||
}
|
||||
|
||||
# 5. lm head
|
||||
model = {k.replace("lm_head", "output"): v for k, v in model.items()}
|
||||
|
||||
# 6. token emb
|
||||
model = {k.replace("embed_tokens", "tok_embeddings"): v for k, v in model.items()}
|
||||
|
||||
# 7. attention
|
||||
model = {k.replace("self_attn", "attention"): v for k, v in model.items()}
|
||||
model = {k.replace("q_proj", "wq"): v for k, v in model.items()}
|
||||
model = {k.replace("k_proj", "wk"): v for k, v in model.items()}
|
||||
model = {k.replace("v_proj", "wv"): v for k, v in model.items()}
|
||||
model = {k.replace("o_proj", "wo"): v for k, v in model.items()}
|
||||
|
||||
params = {}
|
||||
params["dim"] = config.hidden_size
|
||||
params["hidden_dim"] = config.intermediate_size
|
||||
params["n_heads"] = config.num_attention_heads
|
||||
if hasattr(config, "num_key_value_heads"):
|
||||
params["n_kv_heads"] = config.num_key_value_heads
|
||||
params["n_layers"] = config.num_hidden_layers
|
||||
params["vocab_size"] = config.vocab_size
|
||||
params["norm_eps"] = config.rms_norm_eps
|
||||
params["rope_traditional"] = False
|
||||
weights = {k: v.to(torch.float16).numpy() for k, v in model.items()}
|
||||
|
||||
return weights, params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
help="Path to the model. The MLX weights will also be saved there.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
help=(
|
||||
"Name of the model to convert. Use 'llama' for models in the "
|
||||
"Llama family distributed by Meta including Llama 1, Llama 2, "
|
||||
"Coda Llama, and Llama chat."
|
||||
),
|
||||
choices=["tiny_llama", "llama"],
|
||||
default="llama",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = Path(args.model_path)
|
||||
weights, params = globals()[args.model_name](model_path)
|
||||
np.savez(str(model_path / "weights.npz"), **weights)
|
||||
if params is not None:
|
||||
with open(model_path / "params.json", "w") as fid:
|
||||
json.dump(params, fid, indent=4)
|
390
llms/llama/llama.py
Normal file
390
llms/llama/llama.py
Normal file
@@ -0,0 +1,390 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_unflatten
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
dim: int
|
||||
n_layers: int
|
||||
head_dim: int
|
||||
hidden_dim: int
|
||||
n_heads: int
|
||||
n_kv_heads: int
|
||||
norm_eps: float
|
||||
vocab_size: int
|
||||
rope_theta: float
|
||||
rope_traditional: bool = True
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||
|
||||
def __call__(self, x):
|
||||
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||
return self.weight * output
|
||||
|
||||
|
||||
class RoPE(nn.RoPE):
|
||||
def __init__(self, dims: int, traditional: bool = False, base: float = 10000):
|
||||
super().__init__(dims, traditional)
|
||||
self.base = base
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
shape = x.shape
|
||||
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
|
||||
N = x.shape[1] + offset
|
||||
costheta, sintheta = RoPE.create_cos_sin_theta(
|
||||
N, self.dims, offset=offset, base=self.base, dtype=x.dtype
|
||||
)
|
||||
|
||||
rope = (
|
||||
self._compute_traditional_rope if self.traditional else self._compute_rope
|
||||
)
|
||||
rx = rope(costheta, sintheta, x)
|
||||
|
||||
return mx.reshape(rx, shape)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
self.n_heads: int = args.n_heads
|
||||
self.n_kv_heads: int = args.n_kv_heads
|
||||
|
||||
self.repeats = self.n_heads // self.n_kv_heads
|
||||
|
||||
self.scale = self.args.head_dim**-0.5
|
||||
|
||||
self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
|
||||
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
|
||||
self.rope = RoPE(
|
||||
args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
def repeat(a):
|
||||
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
|
||||
return a.reshape([B, self.n_heads, L, -1])
|
||||
|
||||
keys, values = map(repeat, (keys, values))
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.wo(output), (keys, values)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
|
||||
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_heads = args.n_heads
|
||||
self.dim = args.dim
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = FeedForward(args=args)
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r, cache = self.attention(self.attention_norm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.feed_forward(self.ffn_norm(h))
|
||||
out = h + r
|
||||
return out, cache
|
||||
|
||||
|
||||
class Llama(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||||
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
|
||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(self.tok_embeddings.weight.dtype)
|
||||
|
||||
x = self.tok_embeddings(x)
|
||||
for l in self.layers:
|
||||
x, _ = l(x, mask)
|
||||
x = self.norm(x)
|
||||
return self.output(x)
|
||||
|
||||
def generate(self, x, temp=1.0):
|
||||
cache = []
|
||||
|
||||
# Make an additive causal mask. We will need that to process the prompt.
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(self.tok_embeddings.weight.dtype)
|
||||
|
||||
# First we process the prompt x the same was as in __call__ but
|
||||
# save the caches in cache
|
||||
x = self.tok_embeddings(x)
|
||||
for l in self.layers:
|
||||
x, c = l(x, mask=mask)
|
||||
# We store the per layer cache in a simple python list
|
||||
cache.append(c)
|
||||
x = self.norm(x)
|
||||
# We only care about the last logits that generate the next token
|
||||
y = self.output(x[:, -1])
|
||||
y = mx.random.categorical(y * (1 / temp))
|
||||
|
||||
# y now has size [1]
|
||||
# Since MLX is lazily evaluated nothing is computed yet.
|
||||
# Calling y.item() would force the computation to happen at
|
||||
# this point but we can also choose not to do that and let the
|
||||
# user choose when to start the computation.
|
||||
yield y
|
||||
|
||||
# Now we parsed the prompt and generated the first token we
|
||||
# need to feed it back into the model and loop to generate the
|
||||
# rest.
|
||||
while True:
|
||||
# Unsqueezing the last dimension to add a sequence length
|
||||
# dimension of 1
|
||||
x = y[:, None]
|
||||
|
||||
x = self.tok_embeddings(x)
|
||||
for i in range(len(cache)):
|
||||
# We are overwriting the arrays in the cache list. When
|
||||
# the computation will happen, MLX will be discarding the
|
||||
# old cache the moment it is not needed anymore.
|
||||
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
|
||||
x = self.norm(x)
|
||||
y = self.output(x[:, -1])
|
||||
y = mx.random.categorical(y * (1 / temp))
|
||||
|
||||
yield y
|
||||
|
||||
|
||||
def tic():
|
||||
return time.time()
|
||||
|
||||
|
||||
def toc(msg, start):
|
||||
end = time.time()
|
||||
return f"[INFO] {msg}: {end - start:.3f} s"
|
||||
|
||||
|
||||
def generate(args):
|
||||
|
||||
input("Press enter to start generation")
|
||||
print("------")
|
||||
print(args.prompt)
|
||||
x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
|
||||
skip = 0
|
||||
prompt_processing = None
|
||||
tokens = []
|
||||
start = tic()
|
||||
for token in model.generate(x, args.temp):
|
||||
tokens.append(token)
|
||||
|
||||
if len(tokens) == 1:
|
||||
# Actually perform the computation to measure the prompt processing time
|
||||
mx.eval(token)
|
||||
prompt_processing = toc("Prompt processing", start)
|
||||
|
||||
if len(tokens) >= args.max_tokens:
|
||||
break
|
||||
|
||||
elif (len(tokens) % args.write_every) == 0:
|
||||
# It is perfectly ok to eval things we have already eval-ed.
|
||||
mx.eval(tokens)
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s[skip:], end="", flush=True)
|
||||
skip = len(s)
|
||||
|
||||
mx.eval(tokens)
|
||||
full_gen = toc("Full generation", start)
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s[skip:], flush=True)
|
||||
print("------")
|
||||
print(prompt_processing)
|
||||
print(full_gen)
|
||||
|
||||
|
||||
def few_shot_generate(args):
|
||||
def possible_end(s):
|
||||
word = "[Instruction]"
|
||||
for i in range(len(word) - 1, 0, -1):
|
||||
if s[-i:] == word[:i]:
|
||||
return 0
|
||||
if s[-len(word) :] == word:
|
||||
return 1
|
||||
return -1
|
||||
|
||||
def generate(question):
|
||||
x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(question)])
|
||||
skip = 0
|
||||
prompt_processing = None
|
||||
tokens = []
|
||||
start = tic()
|
||||
for token in model.generate(x, args.temp):
|
||||
tokens.append(token)
|
||||
|
||||
if len(tokens) == 1:
|
||||
# Actually perform the computation to measure the prompt processing time
|
||||
mx.eval(token)
|
||||
prompt_processing = toc("Prompt processing", start)
|
||||
|
||||
if len(tokens) >= args.max_tokens:
|
||||
break
|
||||
|
||||
mx.eval(tokens)
|
||||
token_list = [t.item() for t in tokens]
|
||||
s = tokenizer.decode(token_list)
|
||||
|
||||
end = possible_end(s)
|
||||
if end == 0:
|
||||
continue
|
||||
if end == 1:
|
||||
skip = len(s)
|
||||
break
|
||||
|
||||
print(s[skip:], end="", flush=True)
|
||||
skip = len(s)
|
||||
if token_list[-1] == tokenizer.eos_id():
|
||||
break
|
||||
|
||||
mx.eval(tokens)
|
||||
full_gen = toc("Full generation", start)
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s[skip:], end="", flush=True)
|
||||
|
||||
print("[INFO] Loading few-shot examples from: {}".format(args.few_shot))
|
||||
prompt = open(args.few_shot).read().strip()
|
||||
while True:
|
||||
question = input("Ask a question: ")
|
||||
generate(prompt.replace("{}", question))
|
||||
print()
|
||||
|
||||
|
||||
def load_model(model_path):
|
||||
model_path = Path(model_path)
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
with open(model_path / "params.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
n_heads = config["n_heads"]
|
||||
if "n_kv_heads" not in config:
|
||||
config["n_kv_heads"] = n_heads
|
||||
if "head_dim" not in config:
|
||||
config["head_dim"] = config["dim"] // n_heads
|
||||
if "hidden_dim" not in config:
|
||||
config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
|
||||
if config.get("vocab_size", -1) < 0:
|
||||
config["vocab_size"] = weights["output.weight"].shape[-1]
|
||||
if "rope_theta" not in config:
|
||||
config["rope_theta"] = 10000
|
||||
unused = ["multiple_of", "ffn_dim_multiplier"]
|
||||
for k in unused:
|
||||
if k in config:
|
||||
config.pop(k)
|
||||
model = Llama(ModelArgs(**config))
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Llama inference script")
|
||||
parser.add_argument(
|
||||
"model", help="Path to the model directory containing the MLX weights"
|
||||
)
|
||||
parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
help="The message to be processed by the model. Ignored when --few-shot is provided.",
|
||||
default="In the beginning the Universe was created.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--few-shot",
|
||||
help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens", "-m", type=int, default=100, help="How many tokens to generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--write-every", type=int, default=1, help="After how many tokens to detokenize"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp", type=float, default=0.8, help="The sampling temperature"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
|
||||
print("[INFO] Loading model from disk.")
|
||||
model = load_model(args.model)
|
||||
if args.few_shot:
|
||||
few_shot_generate(args)
|
||||
else:
|
||||
generate(args)
|
3
llms/llama/requirements.txt
Normal file
3
llms/llama/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
mlx
|
||||
sentencepiece
|
||||
torch
|
23
llms/llama/sample_prompt.txt
Normal file
23
llms/llama/sample_prompt.txt
Normal file
@@ -0,0 +1,23 @@
|
||||
[Instruction] Give the list of U.S. states bordering Canada
|
||||
[Answer] OK, here is the list of U.S. states located on the border with Canada:
|
||||
- Alaska
|
||||
- Michigan
|
||||
- Maine
|
||||
- Minnesota
|
||||
- Montana
|
||||
- New York
|
||||
- Washington
|
||||
- North Dakota
|
||||
- Ohio
|
||||
- Vermont
|
||||
- New Hampshire
|
||||
- Idaho
|
||||
- Pennsylvania
|
||||
[Instruction] Write a paragraph about "functional analysis"
|
||||
[Answer] OK, here is a paragraph on the topic of functional analysis:
|
||||
Functional analysis is a branch of mathematical analysis, the core of which is formed by the study of vector spaces endowed with some kind of limit-related structure (for example, inner product, norm, or topology) and the linear functions defined on these spaces and suitably respecting these structures. The historical roots of functional analysis lie in the study of spaces of functions and the formulation of properties of transformations of functions such as the Fourier transform as transformations defining, for example, continuous or unitary operators between function spaces. This point of view turned out to be particularly useful for the study of differential and integral equations.
|
||||
[Instruction] I am starting a new dog walking business. Can you help me find 2 possible names for the business?
|
||||
[Answer] OK, here are two possible names for a new dog walking business:
|
||||
The first option is "Paws on Patrol", and the second option is "The Dog Whisperer".
|
||||
[Instruction] {}
|
||||
[Answer]
|
Reference in New Issue
Block a user