mlx-examples/segment_anything/segment_anything/image_encoder.py
Shiyu 8353bbbf93
Segment Anything Model (#552)
* add segment anything model

* add readme

* reorg file structure

* update

* lint

* minor updates

* ack

* fix weight loading

* simplify

* fix to run notebooks

* amg in mlx

* remove torch dependency

* nit in README

* return indices in nms

* simplify

* bugfix / simplify

* fix bug

* simplify

* fix notebook and remove output

* couple more nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-06-02 16:45:51 -07:00

423 lines
14 KiB
Python

from typing import Optional, Tuple, Type
import mlx.core as mx
import mlx.nn as nn
from .common import LayerNorm2d, MLPBlock
class ImageEncoderViT(nn.Module):
    """Vision Transformer image encoder.

    Pipeline: strided-conv patch embedding -> optional absolute positional
    embedding -> ``depth`` transformer blocks -> convolutional neck.
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size (int): Input image size the positional embedding and the
                relative-position tables are sized for.
            patch_size (int): Patch size (kernel and stride of the patch embedding).
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Number of transformer blocks.
            num_heads (int): Number of attention heads in each block.
            mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
            out_chans (int): Number of channels produced by the neck.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (Type[nn.Module]): Normalization layer class.
            act_layer (Type[nn.Module]): Activation layer class.
            use_abs_pos (bool): If True, create a zero-initialized absolute
                positional embedding of shape
                (1, img_size // patch_size, img_size // patch_size, embed_dim).
            use_rel_pos (bool): If True, add relative positional embeddings
                to the attention map.
            rel_pos_zero_init (bool): Forwarded unchanged to every block.
            window_size (int): Window size for windowed-attention blocks.
            global_attn_indexes (Tuple[int, ...]): Indexes of blocks that use
                global attention (window size 0) instead of windowed attention.
        """
        super().__init__()
        self.img_size = img_size
        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        # Side length of the patch grid at the configured image size.
        grid = img_size // patch_size

        # Absolute positional embedding, built for the pretraining image size.
        self.pos_embed = mx.zeros([1, grid, grid, embed_dim]) if use_abs_pos else None

        self.layers = [
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                # Blocks listed in global_attn_indexes attend globally.
                window_size=0 if idx in global_attn_indexes else window_size,
                input_size=(grid, grid),
            )
            for idx in range(depth)
        ]
        self.neck = Neck(embed_dim, out_chans)

    def __call__(self, x: mx.array) -> mx.array:
        """Encode ``x`` and return the neck's output features.

        The positional embedding is added without resizing, so the patch grid
        produced from ``x`` must be compatible with it.
        """
        tokens = self.patch_embed(x)
        if self.pos_embed is not None:
            tokens = tokens + self.pos_embed
        for layer in self.layers:
            tokens = layer(tokens)
        return self.neck(tokens)
class Neck(nn.Module):
    """Map encoder features to ``out_chans`` channels.

    Applies a 1x1 convolution and a 3x3 convolution (padding 1), both without
    bias, each followed by a ``LayerNorm2d``.
    """

    def __init__(self, embed_dim, out_chans):
        super().__init__()
        self.conv1 = nn.Conv2d(embed_dim, out_chans, kernel_size=1, bias=False)
        self.layer_norm1 = LayerNorm2d(out_chans)
        self.conv2 = nn.Conv2d(
            out_chans, out_chans, kernel_size=3, padding=1, bias=False
        )
        self.layer_norm2 = LayerNorm2d(out_chans)

    def __call__(self, x):
        # Stage 1: channel projection.
        hidden = self.layer_norm1(self.conv1(x))
        # Stage 2: 3x3 spatial mixing.
        return self.layer_norm2(self.conv2(hidden))
class Block(nn.Module):
    """Pre-norm transformer block with optional windowed self-attention.

    Attention and MLP sub-layers are each wrapped in a residual connection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            mlp_ratio (float): Ratio of MLP hidden dim to ``dim``.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (Type[nn.Module]): Normalization layer class.
            act_layer (Type[nn.Module]): Activation layer class passed to the MLP.
            use_rel_pos (bool): If True, add relative positional embeddings
                to the attention map.
            rel_pos_zero_init (bool): Forwarded to the attention layer.
            window_size (int): Side length of attention windows; 0 means
                global attention over the whole feature map.
            input_size (tuple(int, int) or None): Feature-map resolution used
                to size the relative-position tables for global attention.
                Ignored in favour of (window_size, window_size) when windowed.
        """
        super().__init__()
        # Windowed attention only ever sees window-sized inputs.
        attn_input_size = (
            (window_size, window_size) if window_size != 0 else input_size
        )

        self.layer_norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=attn_input_size,
        )
        self.layer_norm2 = norm_layer(dim)
        self.mlp = MLPBlock(
            embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
        )
        self.window_size = window_size

    def __call__(self, x: mx.array) -> mx.array:
        """Run attention then MLP, each with a residual connection."""
        residual = x
        hidden = self.layer_norm1(x)

        windowed = self.window_size > 0
        if windowed:
            # Remember the unpadded spatial size so it can be restored below.
            original_hw = (hidden.shape[1], hidden.shape[2])
            hidden, padded_hw = window_partition(hidden, self.window_size)

        hidden = self.attn(hidden)

        if windowed:
            hidden = window_unpartition(
                hidden, self.window_size, padded_hw, original_hw
            )

        x = residual + hidden
        return x + self.mlp(self.layer_norm2(x))
class Attention(nn.Module):
    """Multi-head self-attention with optional decomposed relative position bias."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels; split into ``num_heads``
                heads of ``dim // num_heads`` channels each.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add decomposed relative positional
                embeddings to the attention map.
            rel_pos_zero_init (bool): Currently unused. The relative position
                tables are always created with ``mx.zeros`` (presumably
                overwritten by loaded weights).
            input_size (tuple(int, int) or None): Spatial size (H, W) the
                relative position tables are sized for. Required when
                ``use_rel_pos`` is True.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # One row per relative offset in [-(size - 1), size - 1].
            self.rel_pos_h = mx.zeros(shape=(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = mx.zeros(shape=(2 * input_size[1] - 1, head_dim))

    def __call__(self, x: mx.array) -> mx.array:
        """Attend over ``x`` of shape (B, H, W, dim); returns (B, H, W, dim)."""
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = (
            self.qkv(x)
            .reshape(B, H * W, 3, self.num_heads, -1)
            .transpose(2, 0, 3, 1, 4)
        )
        # q, k, v with shape (B * nHead, H * W, C)
        qkv = qkv.reshape(3, B * self.num_heads, H * W, -1)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q * self.scale) @ k.transpose(0, 2, 1)
        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(
                attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
            )
        attn = mx.softmax(attn, axis=-1)

        # Merge heads back into the channel axis: (B, H, W, nHead * C).
        x = (
            (attn @ v)
            .reshape(B, self.num_heads, H, W, -1)
            .transpose(0, 2, 3, 1, 4)
            .reshape(B, H, W, -1)
        )
        return self.proj(x)
def window_partition(x: mx.array, window_size: int) -> Tuple[mx.array, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed.

    H and W are zero-padded at the bottom/right up to the next multiple of
    ``window_size`` before partitioning.

    Args:
        x (mx.array): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        # Axis 1 is H and axis 2 is W: pad_h must go on axis 1, pad_w on axis 2.
        x = mx.pad(x, ((0, 0), (0, pad_h), (0, pad_w), (0, 0)))
    Hp, Wp = H + pad_h, W + pad_w
    x = x.reshape(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    # Bring the two window-grid axes together so each window is contiguous.
    windows = x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
    return windows, (Hp, Wp)
def window_unpartition(
    windows: mx.array,
    window_size: int,
    pad_hw: Tuple[int, int],
    hw: Tuple[int, int],
) -> mx.array:
    """
    Reassemble windows into full feature maps and crop away padding.

    Inverse of ``window_partition``.

    Args:
        windows (mx.array): window tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: feature maps with [B, H, W, C].
    """
    padded_h, padded_w = pad_hw
    orig_h, orig_w = hw

    windows_per_image = padded_h * padded_w // window_size // window_size
    batch = windows.shape[0] // windows_per_image

    grid = windows.reshape(
        batch,
        padded_h // window_size,
        padded_w // window_size,
        window_size,
        window_size,
        -1,
    )
    # Interleave window rows with in-window rows to rebuild each image row.
    merged = grid.transpose(0, 1, 3, 2, 4, 5).reshape(batch, padded_h, padded_w, -1)

    was_padded = padded_h > orig_h or padded_w > orig_w
    return merged[:, :orig_h, :orig_w, :] if was_padded else merged
def _resize_rel_pos(rel_pos: mx.array, size: int) -> mx.array:
    """Linearly resample the rows of ``rel_pos`` (L, C) to ``size`` rows.

    Matches ``torch.nn.functional.interpolate(mode="linear",
    align_corners=False)`` applied along the first axis: half-pixel source
    coordinates, clamped at both ends. The result keeps ``rel_pos``'s dtype.
    """
    length = rel_pos.shape[0]
    scale = length / size
    # Source coordinate for every output row; negatives clamp to row 0.
    src = (mx.arange(size, dtype=mx.float32) + 0.5) * scale - 0.5
    src = mx.maximum(src, 0.0)
    lo = mx.minimum(mx.floor(src).astype(mx.int32), length - 1)
    hi = mx.minimum(lo + 1, length - 1)
    frac = (src - lo.astype(mx.float32))[:, None]
    resized = rel_pos[lo] * (1.0 - frac) + rel_pos[hi] * frac
    return resized.astype(rel_pos.dtype)


def get_rel_pos(q_size: int, k_size: int, rel_pos: mx.array) -> mx.array:
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.

    If ``rel_pos`` has a different number of rows than the
    ``2 * max(q_size, k_size) - 1`` offsets needed, it is first linearly
    resampled to that length.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (mx.array): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings with shape (q_size, k_size, C).
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    if rel_pos.shape[0] != max_rel_dist:
        # Table was built for another resolution: resample its rows.
        rel_pos_resized = _resize_rel_pos(rel_pos, max_rel_dist)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = mx.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = mx.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    # Shift so the most negative offset maps to row 0 of the table.
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
    return rel_pos_resized[relative_coords.astype(mx.int64)]
def add_decomposed_rel_pos(
    attn,
    q,
    rel_pos_h: mx.array,
    rel_pos_w: mx.array,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> mx.array:
    """
    Add decomposed relative positional bias (from :paper:`mvitv2`) to ``attn``.

    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950

    Args:
        attn (mx.array): attention map, reshapeable to (B, q_h, q_w, k_h, k_w).
        q (mx.array): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (mx.array): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (mx.array): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (mx.array): attention map of shape (B, q_h * q_w, k_h * k_w) with
        the height and width biases added.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    table_h = get_rel_pos(q_h, k_h, rel_pos_h)  # (q_h, k_h, C)
    table_w = get_rel_pos(q_w, k_w, rel_pos_w)  # (q_w, k_w, C)

    batch, _, channels = q.shape
    queries = q.reshape(batch, q_h, q_w, channels)

    # Batched-matmul equivalent of einsum("bhwc,hkc->bhwk", queries, table_h).
    bias_h = queries @ table_h.transpose(0, 2, 1)
    # Batched-matmul equivalent of einsum("bhwc,wkc->bhwk", queries, table_w):
    # swap H and W so W lines up with table_w's leading axis, then swap back.
    bias_w = (
        queries.transpose(0, 2, 1, 3) @ table_w.transpose(0, 2, 1)
    ).transpose(0, 2, 1, 3)

    grid = attn.reshape(batch, q_h, q_w, k_h, k_w)
    grid = grid + bias_h[..., None] + bias_w[..., None, :]
    return grid.reshape(batch, q_h * q_w, k_h * k_w)
class PatchEmbed(nn.Module):
    """Embed image patches with a single ``nn.Conv2d`` projection."""

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection convolution.
            stride (Tuple): stride of the projection convolution.
            padding (Tuple): padding of the projection convolution.
            in_chans (int): Number of input image channels.
            embed_dim (int): Number of output (embedding) channels.
        """
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.projection = nn.Conv2d(in_chans, embed_dim, **conv_kwargs)

    def __call__(self, x: mx.array) -> mx.array:
        """Project ``x`` with the patch convolution.

        Presumably channels-last (B, H, W, C) input, matching the layout of
        ``ImageEncoderViT.pos_embed`` — TODO confirm.
        """
        return self.projection(x)