mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 11:54:37 +08:00
Avoid upcasting and fix batch size > 1
This commit is contained in:
parent
070c58ed92
commit
aefe60e79d
@ -133,7 +133,7 @@ class Modulation(nn.Module):
|
||||
|
||||
def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
|
||||
x = self.lin(nn.silu(x))
|
||||
xs = mx.split(x, self.multiplier, axis=-1)
|
||||
xs = mx.split(x[:, None, :], self.multiplier, axis=-1)
|
||||
|
||||
mod1 = ModulationOut(*xs[:3])
|
||||
mod2 = ModulationOut(*xs[3:]) if self.is_double else None
|
||||
|
@ -117,7 +117,7 @@ class Flux(nn.Module):
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = mx.concatenate([txt_ids, img_ids], axis=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
pe = self.pe_embedder(ids).astype(img.dtype)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
Loading…
Reference in New Issue
Block a user