mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Revert "Fix shapes mismatch error in FluxPipeline.training_loss when batch-size >= 2"
This reverts commit 67c68452bb
.
This commit is contained in:
parent
d50ad3ec27
commit
ceb0ae3416
@ -208,10 +208,9 @@ class FluxPipeline:
|
|||||||
x_0, x_ids = self._prepare_latent_images(x_0)
|
x_0, x_ids = self._prepare_latent_images(x_0)
|
||||||
|
|
||||||
# Forward process
|
# Forward process
|
||||||
B, L = x_0.shape[:2]
|
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
|
||||||
t = self.sampler.random_timesteps(B, L, dtype=self.dtype)
|
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
|
||||||
eps = mx.random.normal(x_0.shape, dtype=x_0.dtype)
|
x_t = self.sampler.add_noise(x_0, t, noise=eps)
|
||||||
x_t = mx.stack([self.sampler.add_noise(x_0[idx], t[idx], noise=eps[idx]) for idx in range(B)])
|
|
||||||
x_t = mx.stop_gradient(x_t)
|
x_t = mx.stop_gradient(x_t)
|
||||||
|
|
||||||
# Do the denoising
|
# Do the denoising
|
||||||
|
Loading…
Reference in New Issue
Block a user