import math
from typing import List, Tuple, Type, Union

import mlx.core as mx
import mlx.nn as nn

from .common import LayerNorm2d


class MaskDecoder(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Args:
            transformer_dim (int): the channel dimension of the transformer
            transformer (nn.Module): the transformer used to predict masks
            num_multimask_outputs (int): the number of masks to predict
                when disambiguating masks
            activation (nn.Module): the type of activation to use when
                upscaling masks
            iou_head_depth (int): the depth of the MLP used to predict
                mask quality
            iou_head_hidden_dim (int): the hidden dimension of the MLP
                used to predict mask quality
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        # Two transposed convolutions upscale the mask embeddings 4x.
        # padding=1 is the padding of the underlying convolution in the naive
        # ConvTranspose2d below, which is equivalent to a kernel_size=2,
        # stride=2 transposed convolution with no padding.
        self.upscale_conv1 = ConvTranspose2d(
            transformer_dim,
            transformer_dim // 4,
            kernel_size=2,
            stride=2,
            padding=1,
        )
        self.upscale_layer_norm = LayerNorm2d(transformer_dim // 4)
        self.activation = activation()
        self.upscale_conv2 = ConvTranspose2d(
            transformer_dim // 4,
            transformer_dim // 8,
            kernel_size=2,
            stride=2,
            padding=1,
        )
        self.output_hypernetworks_mlps = [
            MLP(transformer_dim, transformer_dim, transformer_dim // 8, 1)
            for i in range(self.num_mask_tokens)
        ]

        # MLP adds an input and an output projection around its hidden
        # layers, so `iou_head_depth - 2` hidden layers yield a head with
        # `iou_head_depth` linear layers in total.
        self.iou_prediction_head = MLP(
            transformer_dim,
            iou_head_hidden_dim,
            self.num_mask_tokens,
            iou_head_depth - 2,
        )

    def __call__(
        self,
        image_embeddings: mx.array,
        image_pe: mx.array,
        sparse_prompt_embeddings: mx.array,
        dense_prompt_embeddings: mx.array,
        multimask_output: bool,
    ) -> Tuple[mx.array, mx.array]:
        """
        Predict masks given image and prompt embeddings.

        Args:
            image_embeddings (mx.array): the embeddings from the image encoder
            image_pe (mx.array): positional encoding with the shape of image_embeddings
            sparse_prompt_embeddings (mx.array): the embeddings of the points and boxes
            dense_prompt_embeddings (mx.array): the embeddings of the mask inputs
            multimask_output (bool): whether to return multiple masks or a single
                mask

        Returns:
            mx.array: batched predicted masks
            mx.array: batched predictions of mask quality
        """
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output. Masks are channels-last
        # (B, H, W, num_mask_tokens): token 0 is the single-mask output and
        # tokens 1..num_multimask_outputs are the multi-mask outputs.
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, :, :, mask_slice]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: mx.array,
        image_pe: mx.array,
        sparse_prompt_embeddings: mx.array,
        dense_prompt_embeddings: mx.array,
    ) -> Tuple[mx.array, mx.array]:
        """Predicts masks. See '__call__' for more details."""
        # Concatenate output tokens
        output_tokens = mx.concatenate(
            [self.iou_token.weight, self.mask_tokens.weight], axis=0
        )
        output_tokens = mx.broadcast_to(
            output_tokens[None],
            [
                sparse_prompt_embeddings.shape[0],
                output_tokens.shape[0],
                output_tokens.shape[1],
            ],
        )
        tokens = mx.concatenate((output_tokens, sparse_prompt_embeddings), axis=1)

        # Expand per-image data in batch direction to be per-mask
        src = mx.repeat(image_embeddings, repeats=tokens.shape[0], axis=0)
        src = src + dense_prompt_embeddings
        b, h, w, c = src.shape

        # Run the transformer
        hs, src = self.transformer(src, image_pe, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.reshape(b, h, w, c)
        src = self.upscale_conv1(src)
        src = self.upscale_layer_norm(src)
        src = self.activation(src)
        src = self.upscale_conv2(src)
        upscaled_embedding = self.activation(src)
        hyper_in_list: List[mx.array] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(
                self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
            )
        hyper_in = mx.stack(hyper_in_list, axis=1)
        b, h, w, c = upscaled_embedding.shape

        # (B, num_mask_tokens, C') x (B, C', H*W) -> (B, num_mask_tokens, H*W),
        # then back to channels-last (B, H, W, num_mask_tokens)
        masks = (
            (hyper_in @ upscaled_embedding.reshape(b, h * w, c).transpose(0, 2, 1))
            .transpose(0, 2, 1)
            .reshape(b, h, w, -1)
        )

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred
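

# A minimal wiring sketch (illustrative, not part of the upstream file): a
# stand-in "transformer" with the same call signature as the real two-way
# transformer makes it possible to exercise the decoder's shapes in isolation.
# The sizes (transformer_dim=256, a 64x64 embedding grid, 2 sparse prompt
# tokens) are assumptions for the example only.
#
#     class _DummyTransformer(nn.Module):
#         def __call__(self, src, pos, tokens):
#             b, h, w, c = src.shape
#             return tokens, src.reshape(b, h * w, c)
#
#     decoder = MaskDecoder(transformer_dim=256, transformer=_DummyTransformer())
#     masks, iou_pred = decoder(
#         image_embeddings=mx.zeros((1, 64, 64, 256)),
#         image_pe=mx.zeros((1, 64, 64, 256)),
#         sparse_prompt_embeddings=mx.zeros((1, 2, 256)),
#         dense_prompt_embeddings=mx.zeros((1, 64, 64, 256)),
#         multimask_output=True,
#     )
#     # masks: (1, 256, 256, 3) after 4x upscaling, iou_pred: (1, 3)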


class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.proj_in = nn.Linear(input_dim, hidden_dim)
        self.layers = [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)]
        self.proj_out = nn.Linear(hidden_dim, output_dim)
        self.sigmoid_output = sigmoid_output

    def __call__(self, x):
        x = nn.relu(self.proj_in(x))
        for layer in self.layers:
            x = nn.relu(layer(x))
        x = self.proj_out(x)
        if self.sigmoid_output:
            x = mx.sigmoid(x)
        return x


# TODO: Naive implementation. Replace when mlx.nn supports conv_transpose.
class ConvTranspose2d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, tuple],
        stride: Union[int, tuple] = 1,
        padding: Union[int, tuple] = 0,
        dilation: Union[int, tuple] = 1,
        bias: bool = True,
    ):
        super().__init__()

        kernel_size, stride, padding = map(
            lambda x: (x, x) if isinstance(x, int) else x,
            (kernel_size, stride, padding),
        )
        scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(out_channels, *kernel_size, in_channels),
        )
        if bias:
            self.bias = mx.zeros((out_channels,))

        self.padding = padding
        self.stride = stride
        self.dilation = dilation

    def _extra_repr(self):
        return (
            f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
            f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
            f"padding={self.padding}, dilation={self.dilation}, "
            f"bias={'bias' in self}"
        )

    def __call__(self, x):
        # Emulate a transposed convolution with a regular convolution:
        # dilating the input by `stride` and flipping the kernel reproduces
        # the transposed (fractionally strided) convolution, with `padding`
        # interpreted directly as the padding of this underlying convolution.
        y = mx.conv_general(
            x,
            self.weight,
            stride=1,
            padding=self.padding,
            kernel_dilation=self.dilation,
            input_dilation=self.stride,
            flip=True,
        )
        if "bias" in self:
            y = y + self.bias
        return y
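

# A small shape sanity check (an addition for illustration, not in the
# upstream file): with kernel_size=2, stride=2 and the pre-converted conv
# padding of 1 used by MaskDecoder above, the naive ConvTranspose2d should
# double the spatial resolution of an NHWC input.
if __name__ == "__main__":
    x = mx.zeros((1, 64, 64, 256))
    up = ConvTranspose2d(256, 64, kernel_size=2, stride=2, padding=1)
    print(up(x).shape)  # expected spatial doubling: 64 -> 128, channels 256 -> 64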