From aefe60e79de88b31e63a87b6bcfa1ae013d0d3dd Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Sat, 28 Sep 2024 01:41:56 -0700
Subject: [PATCH] Avoid upcasting and fix batch size > 1

---
 flux/flux/layers.py | 2 +-
 flux/flux/model.py  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/flux/flux/layers.py b/flux/flux/layers.py
index 35ddac84..ca347fdb 100644
--- a/flux/flux/layers.py
+++ b/flux/flux/layers.py
@@ -133,7 +133,7 @@ class Modulation(nn.Module):
 
     def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
         x = self.lin(nn.silu(x))
-        xs = mx.split(x, self.multiplier, axis=-1)
+        xs = mx.split(x[:, None, :], self.multiplier, axis=-1)
 
         mod1 = ModulationOut(*xs[:3])
         mod2 = ModulationOut(*xs[3:]) if self.is_double else None
diff --git a/flux/flux/model.py b/flux/flux/model.py
index 52b3fe09..45385b65 100644
--- a/flux/flux/model.py
+++ b/flux/flux/model.py
@@ -117,7 +117,7 @@ class Flux(nn.Module):
         txt = self.txt_in(txt)
 
         ids = mx.concatenate([txt_ids, img_ids], axis=1)
-        pe = self.pe_embedder(ids)
+        pe = self.pe_embedder(ids).astype(img.dtype)
 
         for block in self.double_blocks:
             img, txt = block(img=img, txt=txt, vec=vec, pe=pe)