diff --git a/flux/flux/flux.py b/flux/flux/flux.py index c97d7487..3fd044ac 100644 --- a/flux/flux/flux.py +++ b/flux/flux/flux.py @@ -208,10 +208,9 @@ class FluxPipeline: x_0, x_ids = self._prepare_latent_images(x_0) # Forward process - B, L = x_0.shape[:2] - t = self.sampler.random_timesteps(B, L, dtype=self.dtype) - eps = mx.random.normal(x_0.shape, dtype=x_0.dtype) - x_t = mx.stack([self.sampler.add_noise(x_0[idx], t[idx], noise=eps[idx]) for idx in range(B)]) + t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype) + eps = mx.random.normal(x_0.shape, dtype=self.dtype) + x_t = self.sampler.add_noise(x_0, t, noise=eps) x_t = mx.stop_gradient(x_t) # Do the denoising