From ceb0ae34166338140176cc506d4fa5bbf5e16d1d Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sun, 8 Dec 2024 21:36:16 -0800 Subject: [PATCH] Revert "Fix shapes mismatch error in FluxPipeline.training_loss when batch-size >= 2" This reverts commit 67c68452bb16b8c0297fb3f66b4c7fb9136d00ab. --- flux/flux/flux.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/flux/flux/flux.py b/flux/flux/flux.py index c97d7487..3fd044ac 100644 --- a/flux/flux/flux.py +++ b/flux/flux/flux.py @@ -208,10 +208,9 @@ class FluxPipeline: x_0, x_ids = self._prepare_latent_images(x_0) # Forward process - B, L = x_0.shape[:2] - t = self.sampler.random_timesteps(B, L, dtype=self.dtype) - eps = mx.random.normal(x_0.shape, dtype=x_0.dtype) - x_t = mx.stack([self.sampler.add_noise(x_0[idx], t[idx], noise=eps[idx]) for idx in range(B)]) + t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype) + eps = mx.random.normal(x_0.shape, dtype=self.dtype) + x_t = self.sampler.add_noise(x_0, t, noise=eps) x_t = mx.stop_gradient(x_t) # Do the denoising