mlx-examples/stable_diffusion/stable_diffusion/clip.py
2023-11-30 11:08:53 -08:00

71 lines
2.2 KiB
Python

# Copyright © 2023 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from .config import CLIPTextModelConfig
class CLIPEncoderLayer(nn.Module):
    """A single encoder layer of the CLIP transformer.

    The layer applies layer normalization followed by self-attention and a
    residual addition, then layer normalization followed by a two-layer
    feed-forward network (hidden width ``4 * model_dims``, approximate GELU
    activation) and a second residual addition.
    """

    def __init__(self, model_dims: int, num_heads: int):
        """Build the sub-modules of the layer.

        Args:
            model_dims: Width used for the layer norms, the attention module
                and the input/output of the feed-forward network.
            num_heads: Number of heads handed to ``nn.MultiHeadAttention``.
        """
        super().__init__()

        self.layer_norm1 = nn.LayerNorm(model_dims)
        self.layer_norm2 = nn.LayerNorm(model_dims)

        self.attention = nn.MultiHeadAttention(model_dims, num_heads)
        # CLIP uses biased attention projections, so attach a zero-initialized
        # bias vector to each of the four projection layers.
        for proj_name in ("query_proj", "key_proj", "value_proj", "out_proj"):
            getattr(self.attention, proj_name).bias = mx.zeros(model_dims)

        self.linear1 = nn.Linear(model_dims, 4 * model_dims)
        self.linear2 = nn.Linear(4 * model_dims, model_dims)

    def __call__(self, x, attn_mask=None):
        """Run the layer on ``x`` and return the transformed features.

        Args:
            x: Input features; used as queries, keys and values of the
                self-attention.
            attn_mask: Optional mask forwarded unchanged as the fourth
                argument of the attention module.

        Returns:
            The output of the second residual addition.
        """
        # Self-attention branch: normalize, attend, add the residual.
        normed = self.layer_norm1(x)
        attended = self.attention(normed, normed, normed, attn_mask)
        x = attended + x

        # Feed-forward branch: normalize, expand, activate, project back.
        hidden = self.linear1(self.layer_norm2(x))
        hidden = nn.gelu_approx(hidden)
        projected = self.linear2(hidden)

        return projected + x
class CLIPTextModel(nn.Module):
    """The transformer text encoder from CLIP.

    Token ids are embedded, summed with learned position embeddings, passed
    through a stack of ``CLIPEncoderLayer`` modules under an additive causal
    mask, and finally layer-normalized.
    """

    def __init__(self, config: CLIPTextModelConfig):
        """Create the embeddings, encoder stack and final layer norm.

        Args:
            config: Supplies ``vocab_size``, ``max_length``, ``model_dims``,
                ``num_heads`` and ``num_layers``.
        """
        super().__init__()

        dims = config.model_dims

        self.token_embedding = nn.Embedding(config.vocab_size, dims)
        self.position_embedding = nn.Embedding(config.max_length, dims)
        self.layers = [
            CLIPEncoderLayer(dims, config.num_heads)
            for _ in range(config.num_layers)
        ]
        self.final_layer_norm = nn.LayerNorm(dims)

    def __call__(self, x):
        """Encode ``x`` and return the final layer-normalized features.

        Args:
            x: Two-dimensional array of token ids; its second axis is taken
                as the sequence length.

        Returns:
            The output of ``final_layer_norm`` applied to the last encoder
            layer's features.
        """
        # Only the sequence length is needed; the unpacking still requires
        # the input to have exactly two axes.
        _, seq_len = x.shape

        # Token embeddings plus the first ``seq_len`` position embeddings.
        hidden = self.token_embedding(x) + self.position_embedding.weight[:seq_len]

        # The same additive causal mask is shared by every encoder layer.
        causal_mask = nn.MultiHeadAttention.create_additive_causal_mask(
            seq_len, hidden.dtype
        )
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, causal_mask)

        return self.final_layer_norm(hidden)