mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Avoid upcasting and fix batch size > 1
This commit is contained in:
@@ -133,7 +133,7 @@ class Modulation(nn.Module):
|
|||||||
|
|
||||||
def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
|
def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
|
||||||
x = self.lin(nn.silu(x))
|
x = self.lin(nn.silu(x))
|
||||||
xs = mx.split(x, self.multiplier, axis=-1)
|
xs = mx.split(x[:, None, :], self.multiplier, axis=-1)
|
||||||
|
|
||||||
mod1 = ModulationOut(*xs[:3])
|
mod1 = ModulationOut(*xs[:3])
|
||||||
mod2 = ModulationOut(*xs[3:]) if self.is_double else None
|
mod2 = ModulationOut(*xs[3:]) if self.is_double else None
|
||||||
|
@@ -117,7 +117,7 @@ class Flux(nn.Module):
|
|||||||
txt = self.txt_in(txt)
|
txt = self.txt_in(txt)
|
||||||
|
|
||||||
ids = mx.concatenate([txt_ids, img_ids], axis=1)
|
ids = mx.concatenate([txt_ids, img_ids], axis=1)
|
||||||
pe = self.pe_embedder(ids)
|
pe = self.pe_embedder(ids).astype(img.dtype)
|
||||||
|
|
||||||
for block in self.double_blocks:
|
for block in self.double_blocks:
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||||
|
Reference in New Issue
Block a user