From d50ad3ec27aefd8c4544b8ee3fb14cf07ba730df Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sun, 8 Dec 2024 21:35:16 -0800 Subject: [PATCH] Fix batched add_noise --- flux/flux/sampler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flux/flux/sampler.py b/flux/flux/sampler.py index 54c4fe35..e7a1080d 100644 --- a/flux/flux/sampler.py +++ b/flux/flux/sampler.py @@ -50,6 +50,7 @@ class FluxSampler: if noise is not None else mx.random.normal(x.shape, dtype=x.dtype, key=key) ) + t = t.reshape([-1] + [1] * (x.ndim - 1)) return x * (1 - t) + t * noise def step(self, pred, x_t, t, t_prev):