mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 01:46:09 +08:00
Revert "Fix shapes mismatch error in FluxPipeline.training_loss when batch-size >= 2"
This reverts commit 67c68452bb
.
This commit is contained in:
parent
d50ad3ec27
commit
ceb0ae3416
@ -208,10 +208,9 @@ class FluxPipeline:
|
||||
x_0, x_ids = self._prepare_latent_images(x_0)
|
||||
|
||||
# Forward process
|
||||
B, L = x_0.shape[:2]
|
||||
t = self.sampler.random_timesteps(B, L, dtype=self.dtype)
|
||||
eps = mx.random.normal(x_0.shape, dtype=x_0.dtype)
|
||||
x_t = mx.stack([self.sampler.add_noise(x_0[idx], t[idx], noise=eps[idx]) for idx in range(B)])
|
||||
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
|
||||
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
|
||||
x_t = self.sampler.add_noise(x_0, t, noise=eps)
|
||||
x_t = mx.stop_gradient(x_t)
|
||||
|
||||
# Do the denoising
|
||||
|
Loading…
Reference in New Issue
Block a user