mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 01:46:09 +08:00
Fix batched add_noise
This commit is contained in:
parent
67c68452bb
commit
d50ad3ec27
@ -50,6 +50,7 @@ class FluxSampler:
|
||||
if noise is not None
|
||||
else mx.random.normal(x.shape, dtype=x.dtype, key=key)
|
||||
)
|
||||
t = t.reshape([-1] + [1] * (x.ndim - 1))
|
||||
return x * (1 - t) + t * noise
|
||||
|
||||
def step(self, pred, x_t, t, t_prev):
|
||||
|
Loading…
Reference in New Issue
Block a user