mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-28 12:13:25 +08:00
nits
This commit is contained in:
parent
b7bc811507
commit
88ca747e9e
@ -159,7 +159,6 @@ def get_per_token_logps(model, inputs, lengths):
|
|||||||
logits = model(inputs).astype(mx.float16)
|
logits = model(inputs).astype(mx.float16)
|
||||||
logits = logits[:, :-1, :]
|
logits = logits[:, :-1, :]
|
||||||
targets = inputs[:, 1:]
|
targets = inputs[:, 1:]
|
||||||
mx.eval(logits)
|
|
||||||
per_token_logps = []
|
per_token_logps = []
|
||||||
for i in range(logits.shape[0]):
|
for i in range(logits.shape[0]):
|
||||||
seq_len = int(lengths[i]) - 1
|
seq_len = int(lengths[i]) - 1
|
||||||
@ -172,6 +171,7 @@ def get_per_token_logps(model, inputs, lengths):
|
|||||||
axis=-1
|
axis=-1
|
||||||
).squeeze(-1)
|
).squeeze(-1)
|
||||||
per_token_logps.append(token_log_probs)
|
per_token_logps.append(token_log_probs)
|
||||||
|
mx.eval(logits)
|
||||||
return per_token_logps
|
return per_token_logps
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user