mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-04 15:54:34 +08:00
Utils to compare encoder output
This commit is contained in:
14
t5/t5.py
14
t5/t5.py
@@ -355,6 +355,12 @@ if __name__ == "__main__":
|
||||
help="",
|
||||
default="translate English to German: That is good.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encode-only",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Whether to decode or not",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_tokens",
|
||||
"-m",
|
||||
@@ -384,6 +390,14 @@ if __name__ == "__main__":
|
||||
|
||||
prompt = mx.array(prompt)
|
||||
|
||||
if args.encode_only:
|
||||
print("[INFO] Encoding with T5...", flush=True)
|
||||
print(args.prompt, end="", flush=True)
|
||||
embeddings = model.wte(prompt)
|
||||
encoder_output = model.encoder(embeddings, mask=None)
|
||||
print(encoder_output, flush=True)
|
||||
exit(0)
|
||||
|
||||
print("[INFO] Generating with T5...", flush=True)
|
||||
print(args.prompt, end="", flush=True)
|
||||
|
||||
|
Reference in New Issue
Block a user