mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
remove simplify (#379)
This commit is contained in:
@@ -28,8 +28,6 @@ class TransformerLM(nn.Module):
|
||||
def loss(self, x, y, reduce=True):
|
||||
logits = self(x)
|
||||
losses = nn.losses.cross_entropy(logits, y)
|
||||
mx.simplify(losses)
|
||||
|
||||
return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
|
||||
|
||||
|
||||
@@ -96,7 +94,6 @@ def main(args):
|
||||
inputs, targets = map(mx.array, (inputs, targets))
|
||||
loss, grads = loss_and_grad_fn(inputs, targets)
|
||||
model.update(optimizer.apply_gradients(grads, model))
|
||||
mx.simplify(loss, model.parameters())
|
||||
mx.eval(loss, model.parameters())
|
||||
losses.append(loss.item())
|
||||
if (it + 1) % steps_per_report == 0:
|
||||
|
Reference in New Issue
Block a user