Utils to compare encoder output

This commit is contained in:
Juarez Bochi
2023-12-17 07:20:24 -05:00
parent 7e42349f4c
commit 4ec2b6eec3
2 changed files with 53 additions and 0 deletions

View File

@@ -355,6 +355,12 @@ if __name__ == "__main__":
help="",
default="translate English to German: That is good.",
)
parser.add_argument(
"--encode-only",
action='store_true',
default=False,
help="Whether to decode or not",
)
parser.add_argument(
"--max_tokens",
"-m",
@@ -384,6 +390,14 @@ if __name__ == "__main__":
prompt = mx.array(prompt)
if args.encode_only:
print("[INFO] Encoding with T5...", flush=True)
print(args.prompt, end="", flush=True)
embeddings = model.wte(prompt)
encoder_output = model.encoder(embeddings, mask=None)
print(encoder_output, flush=True)
exit(0)
print("[INFO] Generating with T5...", flush=True)
print(args.prompt, end="", flush=True)