mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-28 03:41:17 +08:00
nits
This commit is contained in:
parent
b7bc811507
commit
88ca747e9e
@ -159,7 +159,6 @@ def get_per_token_logps(model, inputs, lengths):
|
||||
logits = model(inputs).astype(mx.float16)
|
||||
logits = logits[:, :-1, :]
|
||||
targets = inputs[:, 1:]
|
||||
mx.eval(logits)
|
||||
per_token_logps = []
|
||||
for i in range(logits.shape[0]):
|
||||
seq_len = int(lengths[i]) - 1
|
||||
@ -172,6 +171,7 @@ def get_per_token_logps(model, inputs, lengths):
|
||||
axis=-1
|
||||
).squeeze(-1)
|
||||
per_token_logps.append(token_log_probs)
|
||||
mx.eval(logits)
|
||||
return per_token_logps
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user