support for tiny llama (#129)

Author: Awni Hannun
Date: 2023-12-18 07:47:55 -08:00
Committed by: GitHub
Parent: 08e862336a
Commit: 44b546d446
3 changed files with 140 additions and 45 deletions

@@ -24,6 +24,7 @@ class ModelArgs:
     norm_eps: float
     vocab_size: int
     rope_theta: float
+    rope_traditional: bool = True
 
 
 class RMSNorm(nn.Module):
@@ -77,7 +78,9 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
-        self.rope = RoPE(args.head_dim, traditional=True, base=args.rope_theta)
+        self.rope = RoPE(
+            args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
+        )
 
     def __call__(
         self,
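The new `rope_traditional` field defaults to `True`, preserving the original Llama behavior, while converted configs for TinyLlama-style checkpoints can set it to `False` to get the half-split rotary variant. A minimal sketch of the difference, assuming `RoPE` here is `mlx.nn.RoPE`:

    # Sketch only: the two rotary-embedding conventions mlx.nn.RoPE supports.
    import mlx.core as mx
    import mlx.nn as nn

    head_dim = 64
    x = mx.random.normal((1, 8, head_dim))  # (batch, sequence, head_dim)

    # traditional=True rotates adjacent pairs (x[..., 2i], x[..., 2i + 1]);
    # traditional=False rotates split halves (x[..., i], x[..., i + head_dim // 2]),
    # matching Hugging Face style checkpoints such as TinyLlama.
    rope_trad = nn.RoPE(head_dim, traditional=True, base=10000)
    rope_half = nn.RoPE(head_dim, traditional=False, base=10000)

    # Same output shape either way; the values differ, so the flag has to
    # match how the weights were trained.
    print(rope_trad(x).shape, rope_half(x).shape)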
@@ -234,7 +237,7 @@ def generate(args):
     input("Press enter to start generation")
     print("------")
     print(args.prompt)
     x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
     skip = 0
     prompt_processing = None
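For context, the encoding line above builds the model input by hand: SentencePiece's `encode` does not prepend BOS, and the outer list adds a batch dimension. A self-contained sketch of the same pattern (the tokenizer path is a placeholder):

    # Sketch only: turning a prompt string into a (1, sequence_length) batch.
    import mlx.core as mx
    import sentencepiece as spm

    tokenizer = spm.SentencePieceProcessor(model_file="tokenizer.model")  # placeholder
    ids = [tokenizer.bos_id()] + tokenizer.encode("In the beginning the Universe was created.")
    x = mx.array([ids])  # outer brackets add the batch dimension
    print(x.shape)       # (1, len(ids))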
@@ -248,7 +251,7 @@ def generate(args):
             mx.eval(token)
             prompt_processing = toc("Prompt processing", start)
 
-        if len(tokens) >= args.num_tokens:
+        if len(tokens) >= args.max_tokens:
             break
 
         elif (len(tokens) % args.write_every) == 0:
@@ -261,8 +264,7 @@ def generate(args):
     mx.eval(tokens)
     full_gen = toc("Full generation", start)
     s = tokenizer.decode([t.item() for t in tokens])
-    print(s[skip:], end="", flush=True)
-    print()
+    print(s[skip:], flush=True)
     print("------")
     print(prompt_processing)
     print(full_gen)
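The change above collapses a `print(...)`/`print()` pair into a single flushing call for the final chunk. The surrounding loop is the usual incremental-detokenization pattern: re-decode the whole token list each time and print only the suffix past `skip`, so multi-byte pieces never print half-finished. A stripped-down sketch, with `token_stream` and `tokenizer` standing in for the script's `model.generate(...)` output and SentencePiece tokenizer:

    # Sketch only: streaming decode with a skip pointer.
    def stream_print(token_stream, tokenizer, max_tokens, write_every=1):
        tokens, skip = [], 0
        for token in token_stream:
            tokens.append(int(token))
            if len(tokens) >= max_tokens:
                break
            if len(tokens) % write_every == 0:
                s = tokenizer.decode(tokens)
                print(s[skip:], end="", flush=True)  # only the unseen suffix
                skip = len(s)
        s = tokenizer.decode(tokens)
        print(s[skip:], flush=True)  # final flush, as in the diff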
@@ -354,14 +356,18 @@ if __name__ == "__main__":
         "model", help="Path to the model directory containing the MLX weights"
     )
     parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
-    parser.add_argument("prompt", help="The message to be processed by the model")
+    parser.add_argument(
+        "--prompt",
+        help="The message to be processed by the model",
+        default="In the beginning the Universe was created.",
+    )
+    parser.add_argument(
+        "--few-shot",
+        action="store_true",
+        help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
+    )
     parser.add_argument(
-        "--num-tokens", "-n", type=int, default=100, help="How many tokens to generate"
+        "--max-tokens", "-m", type=int, default=100, help="How many tokens to generate"
     )
     parser.add_argument(
         "--write-every", type=int, default=1, help="After how many tokens to detokenize"
     )