Revert "Fix shapes mismatch error in FluxPipeline.training_loss when batch-size >= 2"

This reverts commit 67c68452bb.
This commit is contained in:
Angelos Katharopoulos 2024-12-08 21:36:16 -08:00
parent d50ad3ec27
commit ceb0ae3416

View File

@ -208,10 +208,9 @@ class FluxPipeline:
x_0, x_ids = self._prepare_latent_images(x_0)
# Forward process
B, L = x_0.shape[:2]
t = self.sampler.random_timesteps(B, L, dtype=self.dtype)
eps = mx.random.normal(x_0.shape, dtype=x_0.dtype)
x_t = mx.stack([self.sampler.add_noise(x_0[idx], t[idx], noise=eps[idx]) for idx in range(B)])
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
x_t = self.sampler.add_noise(x_0, t, noise=eps)
x_t = mx.stop_gradient(x_t)
# Do the denoising