Avoid upcasting and fix batch size > 1

This commit is contained in:
Angelos Katharopoulos 2024-09-28 01:41:56 -07:00
parent 070c58ed92
commit aefe60e79d
2 changed files with 2 additions and 2 deletions

View File

@ -133,7 +133,7 @@ class Modulation(nn.Module):
def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
x = self.lin(nn.silu(x))
xs = mx.split(x, self.multiplier, axis=-1)
xs = mx.split(x[:, None, :], self.multiplier, axis=-1)
mod1 = ModulationOut(*xs[:3])
mod2 = ModulationOut(*xs[3:]) if self.is_double else None

View File

@ -117,7 +117,7 @@ class Flux(nn.Module):
txt = self.txt_in(txt)
ids = mx.concatenate([txt_ids, img_ids], axis=1)
pe = self.pe_embedder(ids)
pe = self.pe_embedder(ids).astype(img.dtype)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)