mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
Fix flux training with batch size (#1135)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent
2211b27388
commit
1fd6aae871
@ -50,6 +50,7 @@ class FluxSampler:
|
||||
if noise is not None
|
||||
else mx.random.normal(x.shape, dtype=x.dtype, key=key)
|
||||
)
|
||||
t = t.reshape([-1] + [1] * (x.ndim - 1))
|
||||
return x * (1 - t) + t * noise
|
||||
|
||||
def step(self, pred, x_t, t, t_prev):
|
||||
|
Loading…
Reference in New Issue
Block a user