Refactor activation function and loss calculation (#325)

This commit is contained in:
AtomicVar
2024-01-17 05:42:56 +08:00
committed by GitHub
parent ce7b65e8c4
commit 2ba5d3db14
2 changed files with 3 additions and 3 deletions

View File

@@ -26,12 +26,12 @@ class MLP(nn.Module):
def __call__(self, x):
for l in self.layers[:-1]:
x = mx.maximum(l(x), 0.0)
x = nn.relu(l(x))
return self.layers[-1](x)
def loss_fn(model, X, y):
return mx.mean(nn.losses.cross_entropy(model(X), y))
return nn.losses.cross_entropy(model(X), y, reduction="mean")
def eval_fn(model, X, y):