# mlx-examples/flux/flux/model.py
# 2024-12-02 13:15:50 -08:00
#
# 137 lines
# 4.1 KiB
# Python

# Copyright © 2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from .layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
@dataclass
class FluxParams:
    """Hyper-parameters consumed by ``Flux.__init__`` to build the model.

    ``Flux`` rejects configurations where ``hidden_size`` is not divisible
    by ``num_heads`` or where ``sum(axes_dim)`` differs from
    ``hidden_size // num_heads``.
    """

    # Feature size of the image tokens fed to ``img_in``; also reused as the
    # output channel count of the final layer.
    in_channels: int
    # Input size of the ``vector_in`` embedder (applied to ``y``).
    vec_in_dim: int
    # Feature size of the text tokens fed to ``txt_in``.
    context_in_dim: int
    # Width shared by every block in the model.
    hidden_size: int
    # Forwarded as ``mlp_ratio`` to the double- and single-stream blocks.
    mlp_ratio: float
    # Number of attention heads; must divide ``hidden_size``.
    num_heads: int
    # Number of ``DoubleStreamBlock`` layers.
    depth: int
    # Number of ``SingleStreamBlock`` layers.
    depth_single_blocks: int
    # Per-axis sizes handed to ``EmbedND``; must sum to the per-head width.
    axes_dim: list[int]
    # Forwarded as ``theta`` to ``EmbedND``.
    theta: int
    # Forwarded as ``qkv_bias`` to the double-stream blocks.
    qkv_bias: bool
    # When True the model builds a guidance embedder and requires a
    # ``guidance`` argument at call time.
    guidance_embed: bool
class Flux(nn.Module):
    """Flux model: a stack of double-stream blocks followed by a stack of
    single-stream blocks over the concatenated text and image tokens.

    Image and text tokens are projected to ``hidden_size``, processed
    jointly by the blocks under a conditioning vector built from the
    timestep, the optional guidance value and ``y``, and the image tokens
    are finally mapped back to ``in_channels`` features.
    """

    def __init__(self, params: FluxParams):
        """Build every sub-module described by ``params``.

        Args:
            params: Model hyper-parameters.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by ``num_heads``
                or if ``sum(axes_dim)`` does not equal the per-head width
                ``hidden_size // num_heads``.
        """
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        # The model predicts as many channels as it receives.
        self.out_channels = self.in_channels

        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        # The positional embedding covers exactly one head's width.
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {params.axes_dim} but expected positional dim {pe_dim}"
            )

        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
        )

        # Input projections / embedders. The 256 here matches the width
        # requested from ``timestep_embedding`` in ``__call__``.
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
            if params.guidance_embed
            else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = [
            DoubleStreamBlock(
                self.hidden_size,
                self.num_heads,
                mlp_ratio=params.mlp_ratio,
                qkv_bias=params.qkv_bias,
            )
            for _ in range(params.depth)
        ]

        self.single_blocks = [
            SingleStreamBlock(
                self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
            )
            for _ in range(params.depth_single_blocks)
        ]

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def sanitize(self, weights):
        """Return a copy of ``weights`` with keys renamed for this module.

        Three rewrites are applied to every key, in this order:

        1. A leading ``"model.diffusion_model."`` is dropped.
        2. A trailing ``".scale"`` becomes ``".weight"``.
        3. For the first of ``img_mlp``, ``txt_mlp``, ``adaLN_modulation``
           that appears as a dotted path component, ``".layers"`` is
           inserted right after it (every occurrence of that component).

        Args:
            weights: Mapping from parameter name to array.

        Returns:
            A new dict with the renamed keys; the values are not touched.
        """
        prefix = "model.diffusion_model."
        # NOTE(review): presumably these sub-modules are sequential
        # containers that keep their children under ``layers`` - confirm
        # against ``.layers``.
        nested_names = ("img_mlp", "txt_mlp", "adaLN_modulation")

        new_weights = {}
        for k, w in weights.items():
            k = k.removeprefix(prefix)
            if k.endswith(".scale"):
                k = k.removesuffix(".scale") + ".weight"
            for seq in nested_names:
                if f".{seq}." in k:
                    k = k.replace(f".{seq}.", f".{seq}.layers.")
                    # Only the first matching name is rewritten.
                    break
            new_weights[k] = w
        return new_weights

    def __call__(
        self,
        img: mx.array,
        img_ids: mx.array,
        txt: mx.array,
        txt_ids: mx.array,
        timesteps: mx.array,
        y: mx.array,
        guidance: Optional[mx.array] = None,
    ) -> mx.array:
        """Run the model on image and text tokens.

        Args:
            img: Image tokens; must be rank 3, with tokens along axis 1.
            img_ids: Position ids for the image tokens.
            txt: Text tokens; must be rank 3, with tokens along axis 1.
            txt_ids: Position ids for the text tokens.
            timesteps: Values embedded via ``timestep_embedding``.
            y: Extra conditioning vector passed through ``vector_in``.
            guidance: Guidance values; required when the model was built
                with ``guidance_embed=True`` and ignored otherwise.

        Returns:
            The output of the final layer for the image tokens only.

        Raises:
            ValueError: If ``img`` or ``txt`` is not rank 3, or if
                ``guidance`` is missing for a guidance-embedding model.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        img = self.img_in(img)

        # Conditioning vector: timestep (+ guidance) + pooled vector ``y``.
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)

        txt = self.txt_in(txt)

        # Text ids come first, matching the [txt, img] token order below.
        ids = mx.concatenate([txt_ids, img_ids], axis=1)
        pe = self.pe_embedder(ids).astype(img.dtype)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        # The single-stream blocks operate on one joint sequence.
        img = mx.concatenate([txt, img], axis=1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        # Drop the leading text tokens so only image tokens remain.
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)

        return img