Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-25 09:51:19 +08:00
Refactor activation function and loss calculation (#325)
commit 2ba5d3db14
parent ce7b65e8c4
@@ -26,12 +26,12 @@ class MLP(nn.Module):
 
     def __call__(self, x):
         for l in self.layers[:-1]:
-            x = mx.maximum(l(x), 0.0)
+            x = nn.relu(l(x))
         return self.layers[-1](x)
 
 
 def loss_fn(model, X, y):
-    return mx.mean(nn.losses.cross_entropy(model(X), y))
+    return nn.losses.cross_entropy(model(X), y, reduction="mean")
 
 
 def eval_fn(model, X, y):
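For context, a minimal sanity-check sketch (not part of the commit) of why both edits in this hunk are drop-in replacements; it assumes mlx.core and mlx.nn are importable from a recent MLX release and uses made-up toy inputs:

import mlx.core as mx
import mlx.nn as nn

# Illustration only: nn.relu applies an elementwise max(x, 0), so it matches
# the previous mx.maximum(l(x), 0.0) call.
x = mx.array([-1.0, 0.0, 2.5])
assert mx.allclose(nn.relu(x), mx.maximum(x, 0.0))

# Passing reduction="mean" asks cross_entropy to average the per-sample losses,
# which is what the old mx.mean(...) wrapper did explicitly.
logits = mx.random.normal((4, 3))   # toy batch: 4 samples, 3 classes
targets = mx.array([0, 2, 1, 0])
old_loss = mx.mean(nn.losses.cross_entropy(logits, targets))
new_loss = nn.losses.cross_entropy(logits, targets, reduction="mean")
assert mx.allclose(old_loss, new_loss)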
@@ -155,5 +155,5 @@ class StableDiffusion:
 
     def decode(self, x_t):
         x = self.autoencoder.decode(x_t)
-        x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
+        x = mx.clip(x / 2 + 0.5, 0, 1)
         return x
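Similarly, a small sketch (illustration only, with a made-up stand-in for the autoencoder output) showing that mx.clip is equivalent to the old nested minimum/maximum:

import mlx.core as mx

# Illustration only: mx.clip(a, 0, 1) bounds values to [0, 1], matching the
# previous mx.minimum(1, mx.maximum(0, a)) composition.
x_t = mx.array([-3.0, -0.2, 0.4, 1.8])   # stand-in for a decoded latent
rescaled = x_t / 2 + 0.5
old = mx.minimum(1, mx.maximum(0, rescaled))
new = mx.clip(rescaled, 0, 1)
assert mx.allclose(old, new)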