From 1fd6aae871e9e21613ae90624cb4a72bdf709cc6 Mon Sep 17 00:00:00 2001 From: hehua2008 Date: Mon, 9 Dec 2024 14:09:04 +0800 Subject: [PATCH] Fix flux training with batch size (#1135) Co-authored-by: Angelos Katharopoulos --- flux/flux/sampler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flux/flux/sampler.py b/flux/flux/sampler.py index 54c4fe35..e7a1080d 100644 --- a/flux/flux/sampler.py +++ b/flux/flux/sampler.py @@ -50,6 +50,7 @@ class FluxSampler: if noise is not None else mx.random.normal(x.shape, dtype=x.dtype, key=key) ) + t = t.reshape([-1] + [1] * (x.ndim - 1)) return x * (1 - t) + t * noise def step(self, pred, x_t, t, t_prev):