Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
bug fix in transformer_lm example
This commit is contained in:
parent 8b965b2e33
commit e488831e03
@@ -81,13 +81,13 @@ def main(args):
     optimizer = optim.SGD(learning_rate=args.learning_rate)
     loss_and_grad_fn = nn.value_and_grad(model, model.loss)

-    def eval_fn(params, dataset):
+    def eval_fn(model, dataset):
         inputs, targets = map(mx.array, to_samples(context_size, dataset))
         loss = 0
         for s in range(0, targets.shape[0], batch_size):
             bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
             bx, by = map(mx.array, (bx, by))
-            losses = self.loss(bx, by, reduce=False)
+            losses = model.loss(bx, by, reduce=False)
             loss += mx.sum(losses).item()
         return loss / len(targets)
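
The substance of the fix is in this hunk: `eval_fn` is a plain helper nested inside `main`, not a method on the model class, so the old body's `self.loss(...)` referenced a `self` that does not exist in that scope, and the `params` argument gave the function no handle on the model. The corrected version takes the `model` itself and calls `model.loss`. Below is a minimal, self-contained sketch of that pattern; `TinyLM` and the toy data are illustrative stand-ins for the example's TransformerLM and PTB batches, not code from the repository.

import mlx.core as mx
import mlx.nn as nn


class TinyLM(nn.Module):
    # Illustrative stand-in for the example's TransformerLM.
    def __init__(self, vocab_size, dims=16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dims)
        self.out = nn.Linear(dims, vocab_size)

    def __call__(self, x):
        return self.out(self.embedding(x))

    def loss(self, x, y, reduce=True):
        losses = nn.losses.cross_entropy(self(x), y)
        return mx.mean(losses) if reduce else mx.mean(losses, axis=-1)


def eval_fn(model, inputs, targets, batch_size=2):
    # The model is passed in explicitly; using `self.loss` here would raise
    # a NameError because this is a free function, not a method.
    loss = 0
    for s in range(0, targets.shape[0], batch_size):
        bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
        losses = model.loss(bx, by, reduce=False)
        loss += mx.sum(losses).item()
    return loss / len(targets)


model = TinyLM(vocab_size=10)
x = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [2, 4, 6]])
y = mx.array([[2, 3, 4], [5, 6, 7], [8, 9, 1], [4, 6, 8]])
print(eval_fn(model, x, y))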
@@ -110,9 +110,8 @@ def main(args):
             )
             losses = []
             tic = time.perf_counter()

         if (it + 1) % steps_per_eval == 0:
-            val_loss = eval_fn(params, valid)
+            val_loss = eval_fn(model, valid)
             toc = time.perf_counter()
             print(
                 f"Iter {it + 1}: "
@@ -123,7 +122,7 @@ def main(args):
             tic = time.perf_counter()

     if args.eval_test:
-        test_loss = eval_fn(params, test)
+        test_loss = eval_fn(model, test)
         test_ppl = math.exp(test_loss)
         print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
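
The remaining hunks simply bring the call sites in line with the new signature: evaluation is invoked as `eval_fn(model, valid)` during training and `eval_fn(model, test)` at the end. For context, here is a hedged sketch of how a training step can consume `nn.value_and_grad(model, model.loss)`; it continues the `TinyLM` sketch above and uses the standard MLX optimizer-update pattern rather than a verbatim excerpt from the example.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Continues the TinyLM sketch above: `model`, `x`, `y`, and `eval_fn`
# are assumed to be defined there.
optimizer = optim.SGD(learning_rate=0.1)

# Differentiates model.loss with respect to the model's trainable parameters.
loss_and_grad_fn = nn.value_and_grad(model, model.loss)

for it in range(10):
    loss, grads = loss_and_grad_fn(x, y)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)

# The fixed evaluation call passes the model itself, mirroring the diff.
print(f"Val-style loss {eval_fn(model, x, y):.3f}")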
|