mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-16 08:08:08 +08:00
nits
This commit is contained in:
@@ -159,7 +159,6 @@ def get_per_token_logps(model, inputs, lengths):
|
||||
logits = model(inputs).astype(mx.float16)
|
||||
logits = logits[:, :-1, :]
|
||||
targets = inputs[:, 1:]
|
||||
mx.eval(logits)
|
||||
per_token_logps = []
|
||||
for i in range(logits.shape[0]):
|
||||
seq_len = int(lengths[i]) - 1
|
||||
@@ -172,6 +171,7 @@ def get_per_token_logps(model, inputs, lengths):
|
||||
axis=-1
|
||||
).squeeze(-1)
|
||||
per_token_logps.append(token_log_probs)
|
||||
mx.eval(logits)
|
||||
return per_token_logps
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user