diff --git a/flux/flux/flux.py b/flux/flux/flux.py index 3fd044ac..c97d7487 100644 --- a/flux/flux/flux.py +++ b/flux/flux/flux.py @@ -208,9 +208,10 @@ class FluxPipeline: x_0, x_ids = self._prepare_latent_images(x_0) # Forward process - t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype) - eps = mx.random.normal(x_0.shape, dtype=self.dtype) - x_t = self.sampler.add_noise(x_0, t, noise=eps) + B, L = x_0.shape[:2] + t = self.sampler.random_timesteps(B, L, dtype=self.dtype) + eps = mx.random.normal(x_0.shape, dtype=x_0.dtype) + x_t = mx.stack([self.sampler.add_noise(x_0[idx], t[idx], noise=eps[idx]) for idx in range(B)]) x_t = mx.stop_gradient(x_t) # Do the denoising