mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
Refactor activation function and loss calculation (#325)
This commit is contained in:
parent
ce7b65e8c4
commit
2ba5d3db14
@ -26,12 +26,12 @@ class MLP(nn.Module):
|
||||
|
||||
def __call__(self, x):
|
||||
for l in self.layers[:-1]:
|
||||
x = mx.maximum(l(x), 0.0)
|
||||
x = nn.relu(l(x))
|
||||
return self.layers[-1](x)
|
||||
|
||||
|
||||
def loss_fn(model, X, y):
|
||||
return mx.mean(nn.losses.cross_entropy(model(X), y))
|
||||
return nn.losses.cross_entropy(model(X), y, reduction="mean")
|
||||
|
||||
|
||||
def eval_fn(model, X, y):
|
||||
|
@ -155,5 +155,5 @@ class StableDiffusion:
|
||||
|
||||
def decode(self, x_t):
|
||||
x = self.autoencoder.decode(x_t)
|
||||
x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
|
||||
x = mx.clip(x / 2 + 0.5, 0, 1)
|
||||
return x
|
||||
|
Loading…
Reference in New Issue
Block a user