From 2ba5d3db142f19d050c2beb156762028ccdcbff0 Mon Sep 17 00:00:00 2001
From: AtomicVar
Date: Wed, 17 Jan 2024 05:42:56 +0800
Subject: [PATCH] Refactor activation function and loss calculation (#325)

---
 mnist/main.py                                 | 4 ++--
 stable_diffusion/stable_diffusion/__init__.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mnist/main.py b/mnist/main.py
index 47092041..14352df7 100644
--- a/mnist/main.py
+++ b/mnist/main.py
@@ -26,12 +26,12 @@ class MLP(nn.Module):
 
     def __call__(self, x):
         for l in self.layers[:-1]:
-            x = mx.maximum(l(x), 0.0)
+            x = nn.relu(l(x))
         return self.layers[-1](x)
 
 
 def loss_fn(model, X, y):
-    return mx.mean(nn.losses.cross_entropy(model(X), y))
+    return nn.losses.cross_entropy(model(X), y, reduction="mean")
 
 
 def eval_fn(model, X, y):
diff --git a/stable_diffusion/stable_diffusion/__init__.py b/stable_diffusion/stable_diffusion/__init__.py
index 21e22a14..f9325ae6 100644
--- a/stable_diffusion/stable_diffusion/__init__.py
+++ b/stable_diffusion/stable_diffusion/__init__.py
@@ -155,5 +155,5 @@ class StableDiffusion:
 
     def decode(self, x_t):
         x = self.autoencoder.decode(x_t)
-        x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
+        x = mx.clip(x / 2 + 0.5, 0, 1)
         return x
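
Not part of the patch: a minimal sketch, assuming MLX is installed with the
usual mlx.core / mlx.nn imports, checking that each refactored call is a
drop-in replacement for the expression it replaces (array shapes and values
below are arbitrary, for illustration only):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.random.normal((4, 8))

    # nn.relu computes mx.maximum(x, 0), so the MLP change preserves behavior.
    assert mx.array_equal(nn.relu(x), mx.maximum(x, 0.0))

    # mx.clip bounds values in one call, matching the nested minimum/maximum
    # used by StableDiffusion.decode.
    assert mx.array_equal(
        mx.clip(x / 2 + 0.5, 0, 1),
        mx.minimum(1, mx.maximum(0, x / 2 + 0.5)),
    )

    # reduction="mean" folds the explicit mx.mean into the loss call itself.
    logits = mx.random.normal((4, 10))
    targets = mx.array([0, 1, 2, 3])
    assert mx.allclose(
        mx.mean(nn.losses.cross_entropy(logits, targets)),
        nn.losses.cross_entropy(logits, targets, reduction="mean"),
    )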