Refactor activation function and loss calculation (#325)

AtomicVar 2024-01-17 05:42:56 +08:00 committed by GitHub
parent ce7b65e8c4
commit 2ba5d3db14
2 changed files with 3 additions and 3 deletions


@@ -26,12 +26,12 @@ class MLP(nn.Module):
 
     def __call__(self, x):
         for l in self.layers[:-1]:
-            x = mx.maximum(l(x), 0.0)
+            x = nn.relu(l(x))
         return self.layers[-1](x)
 
 
 def loss_fn(model, X, y):
-    return mx.mean(nn.losses.cross_entropy(model(X), y))
+    return nn.losses.cross_entropy(model(X), y, reduction="mean")
 
 
 def eval_fn(model, X, y):

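Both replacements in this hunk are drop-in equivalents of the old code. A minimal sketch checking that numerically; the shapes and toy inputs below are illustrative, not taken from the diff:

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal(shape=(4, 8))

# nn.relu computes the same elementwise max-with-zero the old code spelled out.
assert mx.allclose(nn.relu(x), mx.maximum(x, 0.0))

logits = mx.random.normal(shape=(4, 10))   # illustrative batch of 4, 10 classes
targets = mx.array([1, 3, 0, 7])

# cross_entropy defaults to reduction="none", which is why the old code wrapped
# it in mx.mean; reduction="mean" folds that step into the loss call itself.
wrapped = mx.mean(nn.losses.cross_entropy(logits, targets))
folded = nn.losses.cross_entropy(logits, targets, reduction="mean")
assert mx.allclose(wrapped, folded)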

@@ -155,5 +155,5 @@ class StableDiffusion:
 
     def decode(self, x_t):
         x = self.autoencoder.decode(x_t)
-        x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
+        x = mx.clip(x / 2 + 0.5, 0, 1)
         return x
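mx.clip is a one-call replacement for the nested minimum/maximum clamp. A quick equivalence check; the input shape is illustrative:

import mlx.core as mx

x = mx.random.normal(shape=(1, 64, 64, 3))

# Clamping x / 2 + 0.5 into [0, 1] either way gives identical results.
old = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
new = mx.clip(x / 2 + 0.5, 0, 1)
assert mx.allclose(old, new)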