Fix flux training with batch size (#1135)

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
hehua2008 2024-12-09 14:09:04 +08:00 committed by GitHub
parent 2211b27388
commit 1fd6aae871
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -50,6 +50,7 @@ class FluxSampler:
if noise is not None
else mx.random.normal(x.shape, dtype=x.dtype, key=key)
)
t = t.reshape([-1] + [1] * (x.ndim - 1))
return x * (1 - t) + t * noise
def step(self, pred, x_t, t, t_prev):