Fix shapes mismatch error in FluxPipeline.training_loss when batch-size >= 2

This commit is contained in:
hehua2008 2024-12-04 22:37:04 +08:00
parent 1727959a27
commit 67c68452bb

View File

@ -208,9 +208,10 @@ class FluxPipeline:
x_0, x_ids = self._prepare_latent_images(x_0)
# Forward process
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
x_t = self.sampler.add_noise(x_0, t, noise=eps)
B, L = x_0.shape[:2]
t = self.sampler.random_timesteps(B, L, dtype=self.dtype)
eps = mx.random.normal(x_0.shape, dtype=x_0.dtype)
x_t = mx.stack([self.sampler.add_noise(x_0[idx], t[idx], noise=eps[idx]) for idx in range(B)])
x_t = mx.stop_gradient(x_t)
# Do the denoising