Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 04:14:38 +08:00.
Refactor activation function and loss calculation (#325)
This commit is contained in:
@@ -26,12 +26,12 @@ class MLP(nn.Module):

     def __call__(self, x):
         for l in self.layers[:-1]:
-            x = mx.maximum(l(x), 0.0)
+            x = nn.relu(l(x))
         return self.layers[-1](x)


 def loss_fn(model, X, y):
-    return mx.mean(nn.losses.cross_entropy(model(X), y))
+    return nn.losses.cross_entropy(model(X), y, reduction="mean")


 def eval_fn(model, X, y):
Reference in New Issue
Block a user