Fix batched add_noise

This commit is contained in:
Angelos Katharopoulos 2024-12-08 21:35:16 -08:00
parent 67c68452bb
commit d50ad3ec27

View File

@ -50,6 +50,7 @@ class FluxSampler:
if noise is not None
else mx.random.normal(x.shape, dtype=x.dtype, key=key)
)
t = t.reshape([-1] + [1] * (x.ndim - 1))
return x * (1 - t) + t * noise
def step(self, pred, x_t, t, t_prev):