mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
bug fix in transformer_lm example
This commit is contained in:
parent
8b965b2e33
commit
e488831e03
@ -81,13 +81,13 @@ def main(args):
|
||||
optimizer = optim.SGD(learning_rate=args.learning_rate)
|
||||
loss_and_grad_fn = nn.value_and_grad(model, model.loss)
|
||||
|
||||
def eval_fn(params, dataset):
|
||||
def eval_fn(model, dataset):
|
||||
inputs, targets = map(mx.array, to_samples(context_size, dataset))
|
||||
loss = 0
|
||||
for s in range(0, targets.shape[0], batch_size):
|
||||
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
|
||||
bx, by = map(mx.array, (bx, by))
|
||||
losses = self.loss(bx, by, reduce=False)
|
||||
losses = model.loss(bx, by, reduce=False)
|
||||
loss += mx.sum(losses).item()
|
||||
return loss / len(targets)
|
||||
|
||||
@ -110,9 +110,8 @@ def main(args):
|
||||
)
|
||||
losses = []
|
||||
tic = time.perf_counter()
|
||||
|
||||
if (it + 1) % steps_per_eval == 0:
|
||||
val_loss = eval_fn(params, valid)
|
||||
val_loss = eval_fn(model, valid)
|
||||
toc = time.perf_counter()
|
||||
print(
|
||||
f"Iter {it + 1}: "
|
||||
@ -123,7 +122,7 @@ def main(args):
|
||||
tic = time.perf_counter()
|
||||
|
||||
if args.eval_test:
|
||||
test_loss = eval_fn(params, test)
|
||||
test_loss = eval_fn(model, test)
|
||||
test_ppl = math.exp(test_loss)
|
||||
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user