mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Fix batched add_noise
This commit is contained in:
parent
67c68452bb
commit
d50ad3ec27
@ -50,6 +50,7 @@ class FluxSampler:
|
|||||||
if noise is not None
|
if noise is not None
|
||||||
else mx.random.normal(x.shape, dtype=x.dtype, key=key)
|
else mx.random.normal(x.shape, dtype=x.dtype, key=key)
|
||||||
)
|
)
|
||||||
|
t = t.reshape([-1] + [1] * (x.ndim - 1))
|
||||||
return x * (1 - t) + t * noise
|
return x * (1 - t) + t * noise
|
||||||
|
|
||||||
def step(self, pred, x_t, t, t_prev):
|
def step(self, pred, x_t, t, t_prev):
|
||||||
|
Loading…
Reference in New Issue
Block a user