diff --git a/flux/flux/sampler.py b/flux/flux/sampler.py index 54c4fe35..e7a1080d 100644 --- a/flux/flux/sampler.py +++ b/flux/flux/sampler.py @@ -50,6 +50,7 @@ class FluxSampler: if noise is not None else mx.random.normal(x.shape, dtype=x.dtype, key=key) ) + t = t.reshape([-1] + [1] * (x.ndim - 1)) return x * (1 - t) + t * noise def step(self, pred, x_t, t, t_prev):