Mirror of https://github.com/ml-explore/mlx-examples.git
support for tiny llama (#129)
@@ -24,6 +24,7 @@ class ModelArgs:
     norm_eps: float
     vocab_size: int
     rope_theta: float
+    rope_traditional: bool = True
 
 
 class RMSNorm(nn.Module):
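The new `rope_traditional` field defaults to True, so existing Llama configs keep the RoPE layout they were trained with, while TinyLlama-style checkpoints can override it. A minimal sketch of such a config follows; the two extra fields (`n_layers`, `hidden_dim`) and all numeric values are assumed TinyLlama-1.1B hyperparameters added for illustration, not taken from this commit:

# Sketch (not from this commit): a TinyLlama-style config overriding the new flag.
from dataclasses import dataclass

@dataclass
class ModelArgs:
    dim: int
    n_layers: int      # assumed field name, for illustration
    head_dim: int
    hidden_dim: int    # assumed field name, for illustration
    n_heads: int
    n_kv_heads: int
    norm_eps: float
    vocab_size: int
    rope_theta: float
    rope_traditional: bool = True  # the field this commit adds

args = ModelArgs(
    dim=2048, n_layers=22, head_dim=64, hidden_dim=5632,
    n_heads=32, n_kv_heads=4, norm_eps=1e-5, vocab_size=32000,
    rope_theta=10000.0,
    rope_traditional=False,  # TinyLlama presumably needs the non-traditional layout
)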
@@ -77,7 +78,9 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = RoPE(args.head_dim, traditional=True, base=args.rope_theta)
+        self.rope = RoPE(
+            args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
+        )
 
     def __call__(
         self,
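Previously `traditional=True` was hard-coded here; the change threads the config flag through instead. The two settings pair dimensions differently and are not interchangeable, so the flag must match how the checkpoint was trained. A self-contained sketch (illustrative shapes, assuming a current `mlx` install) showing that the two `mlx.nn.RoPE` modes produce different encodings:

import mlx.core as mx
import mlx.nn as nn

# mlx.nn.RoPE exposes `traditional` and `base` keyword arguments.
x = mx.random.normal((1, 32, 16, 64))  # (batch, n_heads, seq_len, head_dim), example shape
y_trad = nn.RoPE(64, traditional=True, base=10000)(x)
y_hf = nn.RoPE(64, traditional=False, base=10000)(x)
print(mx.allclose(y_trad, y_hf))  # False: the two layouts differ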
@@ -234,7 +237,7 @@ def generate(args):
 
     input("Press enter to start generation")
     print("------")
-
+    print(args.prompt)
    x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
     skip = 0
     prompt_processing = None
@@ -248,7 +251,7 @@ def generate(args):
             mx.eval(token)
             prompt_processing = toc("Prompt processing", start)
 
-        if len(tokens) >= args.num_tokens:
+        if len(tokens) >= args.max_tokens:
             break
 
         elif (len(tokens) % args.write_every) == 0:
@@ -261,8 +264,7 @@ def generate(args):
     mx.eval(tokens)
     full_gen = toc("Full generation", start)
     s = tokenizer.decode([t.item() for t in tokens])
-    print(s[skip:], end="", flush=True)
-    print()
+    print(s[skip:], flush=True)
     print("------")
     print(prompt_processing)
     print(full_gen)
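The loop above streams output by re-decoding the whole token list every `write_every` tokens and printing only the new suffix tracked by `skip`; the changed final line now flushes the remainder and the newline in a single `print`. A toy illustration of that pattern, with a hypothetical stand-in for `tokenizer.decode` (not repo code):

def fake_decode(tokens):
    # Stand-in for tokenizer.decode, for illustration only.
    return "".join(chr(t) for t in tokens)

tokens, skip, write_every = [], 0, 2
for t in map(ord, "hello world"):
    tokens.append(t)
    if len(tokens) % write_every == 0:
        s = fake_decode(tokens)
        print(s[skip:], end="", flush=True)  # print only the new suffix
        skip = len(s)

s = fake_decode(tokens)
print(s[skip:], flush=True)  # final flush with newline, as in the changed line above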
@@ -354,14 +356,18 @@ if __name__ == "__main__":
         "model", help="Path to the model directory containing the MLX weights"
     )
     parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
-    parser.add_argument("prompt", help="The message to be processed by the model")
+    parser.add_argument(
+        "--prompt",
+        help="The message to be processed by the model",
+        default="In the beginning the Universe was created.",
+    )
+    parser.add_argument(
+        "--few-shot",
+        action="store_true",
+        help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
+    )
     parser.add_argument(
-        "--num-tokens", "-n", type=int, default=100, help="How many tokens to generate"
+        "--max-tokens", "-m", type=int, default=100, help="How many tokens to generate"
    )
     parser.add_argument(
         "--write-every", type=int, default=1, help="After how many tokens to detokenize"
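The prompt changes from a positional argument to the optional `--prompt` with a default, `--few-shot` is added, and `-n`/`--num-tokens` is renamed to `-m`/`--max-tokens`. A minimal argparse-only check of the new interface; it mirrors only the arguments visible in the diff above, and the placeholder paths are illustrative, not real files:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("model", help="Path to the model directory containing the MLX weights")
parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
parser.add_argument(
    "--prompt",
    help="The message to be processed by the model",
    default="In the beginning the Universe was created.",
)
parser.add_argument(
    "--few-shot",
    action="store_true",
    help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
)
parser.add_argument(
    "--max-tokens", "-m", type=int, default=100, help="How many tokens to generate"
)

# The prompt no longer has to be supplied on the command line.
args = parser.parse_args(["weights/", "tokenizer.model"])  # placeholder paths
print(args.prompt)      # falls back to the default prompt
print(args.max_tokens)  # 100
print(args.few_shot)    # False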