remove simplify (#379)

This commit is contained in:
Awni Hannun
2024-01-26 13:54:49 -08:00
committed by GitHub
parent 0b57f0eae6
commit 5aa652d3c2
6 changed files with 6 additions and 17 deletions

View File

@@ -28,8 +28,6 @@ class TransformerLM(nn.Module):
def loss(self, x, y, reduce=True):
logits = self(x)
losses = nn.losses.cross_entropy(logits, y)
mx.simplify(losses)
return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
@@ -96,7 +94,6 @@ def main(args):
inputs, targets = map(mx.array, (inputs, targets))
loss, grads = loss_and_grad_fn(inputs, targets)
model.update(optimizer.apply_gradients(grads, model))
mx.simplify(loss, model.parameters())
mx.eval(loss, model.parameters())
losses.append(loss.item())
if (it + 1) % steps_per_report == 0: