# Copyright © 2023 Apple Inc.

import mlx.core as mx
import mlx.nn as nn

from .config import CLIPTextModelConfig


class CLIPEncoderLayer(nn.Module):
    """The transformer encoder layer from CLIP."""

    def __init__(self, model_dims: int, num_heads: int):
        super().__init__()

        self.layer_norm1 = nn.LayerNorm(model_dims)
        self.layer_norm2 = nn.LayerNorm(model_dims)

        self.attention = nn.MultiHeadAttention(model_dims, num_heads)
        # Add biases to the attention projections to match CLIP
        self.attention.query_proj.bias = mx.zeros(model_dims)
        self.attention.key_proj.bias = mx.zeros(model_dims)
        self.attention.value_proj.bias = mx.zeros(model_dims)
        self.attention.out_proj.bias = mx.zeros(model_dims)

        self.linear1 = nn.Linear(model_dims, 4 * model_dims)
        self.linear2 = nn.Linear(4 * model_dims, model_dims)

    def __call__(self, x, attn_mask=None):
        # Self-attention block with a pre-norm residual connection
        y = self.layer_norm1(x)
        y = self.attention(y, y, y, attn_mask)
        x = y + x

        # Feed-forward block with a pre-norm residual connection
        y = self.layer_norm2(x)
        y = self.linear1(y)
        y = nn.gelu_approx(y)
        y = self.linear2(y)
        x = y + x

        return x


class CLIPTextModel(nn.Module):
    """Implements the text encoder transformer from CLIP."""

    def __init__(self, config: CLIPTextModelConfig):
        super().__init__()

        self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
        self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
        self.layers = [
            CLIPEncoderLayer(config.model_dims, config.num_heads)
            for _ in range(config.num_layers)
        ]
        self.final_layer_norm = nn.LayerNorm(config.model_dims)

    def __call__(self, x):
        # Extract some shapes
        B, N = x.shape

        # Compute the embeddings
        x = self.token_embedding(x)
        x = x + self.position_embedding.weight[:N]

        # Compute the features from the transformer
        mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype)
        for layer in self.layers:
            x = layer(x, mask)

        # Apply the final layernorm and return
        return self.final_layer_norm(x)
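

# Minimal usage sketch of the text encoder above. It assumes that
# CLIPTextModelConfig accepts the fields this module reads (vocab_size,
# model_dims, max_length, num_heads, num_layers) as keyword arguments; the
# real config may have different defaults or extra fields. Because of the
# relative import at the top, run it as a module (e.g. `python -m <package>.clip`)
# rather than as a standalone script.
if __name__ == "__main__":
    config = CLIPTextModelConfig(
        vocab_size=49408,  # assumed CLIP BPE vocabulary size
        model_dims=512,
        max_length=77,
        num_heads=8,
        num_layers=4,  # small stack, just for the sketch
    )
    model = CLIPTextModel(config)

    # Placeholder token ids with shape (batch, sequence); a real pipeline
    # would produce these with a CLIP tokenizer.
    tokens = mx.zeros((1, 16), dtype=mx.int32)

    features = model(tokens)
    print(features.shape)  # expected: (1, 16, config.model_dims)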