Segment Anything Model (#552)
* add segment anything model
* add readme
* reorg file structure
* update
* lint
* minor updates
* ack
* fix weight loading
* simplify
* fix to run notebooks
* amg in mlx
* remove torch dependency
* nit in README
* return indices in nms
* simplify
* bugfix / simplify
* fix bug
* simplify
* fix notebook and remove output
* couple more nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Parent: 89b0b75250
Commit: 8353bbbf93
ACKNOWLEDGMENTS.md
@@ -13,3 +13,4 @@ MLX Examples was developed with contributions from the following individuals:
- Gabrijel Boduljak: Implemented `CLIP`.
- Markus Enzweiler: Added the `cvae` examples.
- Prince Canuma: Helped add support for `Starcoder2` models.
- Shiyu Li: Added the `Segment Anything Model`.
segment_anything/README.md (new file, 39 lines)
@@ -0,0 +1,39 @@
# Segment Anything

An implementation of the Segment Anything Model (SAM) in MLX. See the original
repo by Meta AI for more details.[^1]

## Installation

```bash
pip install -r requirements.txt
```

## Convert

```bash
python convert.py --hf-path facebook/sam-vit-base --mlx-path sam-vit-base
```

The `safetensors` weight file and configs are downloaded from Hugging Face,
converted, and saved in the directory specified by `--mlx-path`.

The model sizes are:

- `facebook/sam-vit-base`
- `facebook/sam-vit-large`
- `facebook/sam-vit-huge`
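A quick way to sanity-check a conversion is to reopen the saved weights with
`mx.load`; a minimal sketch (the path assumes the default `--mlx-path` above):

```python
import mlx.core as mx

# Load the converted weights and spot-check a few entries.
weights = mx.load("sam-vit-base/model.safetensors")
print(len(weights), "tensors")
for name in list(weights)[:3]:
    print(name, weights[name].shape, weights[name].dtype)
```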

## Run

See the example notebooks `notebooks/predictor_example.ipynb` and
`notebooks/automatic_mask_generator_example.ipynb` to try the Segment Anything
Model with MLX.

You can also generate masks from the command line:

```bash
python main.py --model <path/to/model> --input <image_or_folder> --output <path/to/output>
```
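For programmatic use, the predictor API boils down to a few calls. The sketch
below is condensed from `notebooks/predictor_example.ipynb`; the checkpoint
directory and point coordinates are placeholders:

```python
import cv2
import mlx.core as mx

from segment_anything.sam import load
from segment_anything.predictor import SamPredictor

sam = load("sam-vit-base")  # directory produced by convert.py
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread("images/truck.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)  # compute and cache the image embedding

# One foreground point prompt, in (x, y) pixel coordinates.
point = mx.array([[500, 375]])
label = mx.array([1])
masks, scores, logits = predictor.predict(
    point_coords=point[None],
    point_labels=label[None],
    multimask_output=True,
)
```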

[^1]: The original Segment Anything [GitHub repo](https://github.com/facebookresearch/segment-anything/tree/main).
segment_anything/convert.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import argparse
import json
import shutil
from pathlib import Path
from typing import Dict, Union

import mlx.core as mx
from huggingface_hub import snapshot_download


def save_weights(save_path: Union[str, Path], weights: Dict[str, mx.array]) -> None:
    """Save model weights into the specified directory."""
    if isinstance(save_path, str):
        save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    total_size = sum(v.nbytes for v in weights.values())
    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}

    model_path = save_path / "model.safetensors"
    mx.save_safetensors(str(model_path), weights)

    for weight_name in weights.keys():
        index_data["weight_map"][weight_name] = "model.safetensors"

    index_data["weight_map"] = {
        k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
    }

    with open(save_path / "model.safetensors.index.json", "w") as f:
        json.dump(index_data, f, indent=4)


def download(hf_repo):
    return Path(
        snapshot_download(
            repo_id=hf_repo,
            allow_patterns=["*.safetensors", "*.json"],
            resume_download=True,
        )
    )


def convert(model_path):
    weight_file = str(model_path / "model.safetensors")
    weights = mx.load(weight_file)

    mlx_weights = dict()
    for k, v in weights.items():
        # PyTorch stores conv kernels as (out, in, H, W); MLX convolutions
        # expect channels-last (out, H, W, in).
        if k in {
            "vision_encoder.patch_embed.projection.weight",
            "vision_encoder.neck.conv1.weight",
            "vision_encoder.neck.conv2.weight",
            "prompt_encoder.mask_embed.conv1.weight",
            "prompt_encoder.mask_embed.conv2.weight",
            "prompt_encoder.mask_embed.conv3.weight",
        }:
            v = v.transpose(0, 2, 3, 1)
        # PyTorch transposed-conv kernels are (in, out, H, W); move to
        # (out, H, W, in) for MLX.
        if k in {
            "mask_decoder.upscale_conv1.weight",
            "mask_decoder.upscale_conv2.weight",
        }:
            v = v.transpose(1, 2, 3, 0)
        mlx_weights[k] = v
    return mlx_weights


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Meta SAM weights to MLX")
    parser.add_argument(
        "--hf-path",
        default="facebook/sam-vit-base",
        type=str,
        help="Path to the Hugging Face model repo.",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="sam-vit-base",
        help="Path to save the MLX model.",
    )
    args = parser.parse_args()

    model_path = download(args.hf_path)

    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    mlx_weights = convert(model_path)
    save_weights(mlx_path, mlx_weights)
    shutil.copy(model_path / "config.json", mlx_path / "config.json")
segment_anything/main.py (new file, 225 lines)
@@ -0,0 +1,225 @@
import argparse
import json
import os
import sys
from typing import Any, Dict, List

import cv2

from segment_anything import SamAutomaticMaskGenerator, sam

parser = argparse.ArgumentParser(
    description=(
        "Runs automatic mask generation on an input image or directory of images, "
        "and outputs masks as either PNGs or COCO-style RLEs. Requires OpenCV, "
        "as well as pycocotools if saving in RLE format."
    )
)

parser.add_argument(
    "--input",
    type=str,
    required=True,
    help="Path to either a single input image or folder of images.",
)

parser.add_argument(
    "--output",
    type=str,
    required=True,
    help=(
        "Path to the directory where masks will be output. Output will be either a folder "
        "of PNGs per image or a single json with COCO-style masks."
    ),
)

parser.add_argument(
    "--model",
    type=str,
    required=True,
    help="The path to the SAM model to use for mask generation.",
)

parser.add_argument(
    "--convert-to-rle",
    action="store_true",
    help=(
        "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
        "Requires pycocotools."
    ),
)

amg_settings = parser.add_argument_group("AMG Settings")

amg_settings.add_argument(
    "--points-per-side",
    type=int,
    default=None,
    help="Generate masks by sampling a grid over the image with this many points to a side.",
)

amg_settings.add_argument(
    "--points-per-batch",
    type=int,
    default=None,
    help="How many input points to process simultaneously in one batch.",
)

amg_settings.add_argument(
    "--pred-iou-thresh",
    type=float,
    default=None,
    help="Exclude masks with a predicted score from the model that is lower than this threshold.",
)

amg_settings.add_argument(
    "--stability-score-thresh",
    type=float,
    default=None,
    help="Exclude masks with a stability score lower than this threshold.",
)

amg_settings.add_argument(
    "--stability-score-offset",
    type=float,
    default=None,
    help="Larger values perturb the mask more when measuring stability score.",
)

amg_settings.add_argument(
    "--box-nms-thresh",
    type=float,
    default=None,
    help="The overlap threshold for excluding a duplicate mask.",
)

amg_settings.add_argument(
    "--crop-n-layers",
    type=int,
    default=None,
    help=(
        "If >0, mask generation is run on smaller crops of the image to generate more masks. "
        "The value sets how many different scales to crop at."
    ),
)

amg_settings.add_argument(
    "--crop-nms-thresh",
    type=float,
    default=None,
    help="The overlap threshold for excluding duplicate masks across different crops.",
)

amg_settings.add_argument(
    "--crop-overlap-ratio",
    type=int,
    default=None,
    help="Larger numbers mean image crops will overlap more.",
)

amg_settings.add_argument(
    "--crop-n-points-downscale-factor",
    type=int,
    default=None,
    help="The number of points-per-side in each layer of crop is reduced by this factor.",
)

amg_settings.add_argument(
    "--min-mask-region-area",
    type=int,
    default=None,
    help=(
        "Disconnected mask regions or holes with area smaller than this value "
        "in pixels are removed by postprocessing."
    ),
)


def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
    header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h"  # noqa
    metadata = [header]
    for i, mask_data in enumerate(masks):
        mask = mask_data["segmentation"]
        filename = f"{i}.png"
        cv2.imwrite(os.path.join(path, filename), mask * 255)
        mask_metadata = [
            str(i),
            str(mask_data["area"]),
            *[str(x) for x in mask_data["bbox"]],
            *[str(x) for x in mask_data["point_coords"][0]],
            str(mask_data["predicted_iou"]),
            str(mask_data["stability_score"]),
            *[str(x) for x in mask_data["crop_box"]],
        ]
        row = ",".join(mask_metadata)
        metadata.append(row)
    metadata_path = os.path.join(path, "metadata.csv")
    with open(metadata_path, "w") as f:
        f.write("\n".join(metadata))

    return


def get_amg_kwargs(args):
    amg_kwargs = {
        "points_per_side": args.points_per_side,
        "points_per_batch": args.points_per_batch,
        "pred_iou_thresh": args.pred_iou_thresh,
        "stability_score_thresh": args.stability_score_thresh,
        "stability_score_offset": args.stability_score_offset,
        "box_nms_thresh": args.box_nms_thresh,
        "crop_n_layers": args.crop_n_layers,
        "crop_nms_thresh": args.crop_nms_thresh,
        "crop_overlap_ratio": args.crop_overlap_ratio,
        "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor,
        "min_mask_region_area": args.min_mask_region_area,
    }
    # Drop unset options so the generator's own defaults apply.
    amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None}
    return amg_kwargs


def main(args: argparse.Namespace) -> None:
    print("Loading model...")
    model = sam.load(args.model)
    output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
    amg_kwargs = get_amg_kwargs(args)
    generator = SamAutomaticMaskGenerator(model, output_mode=output_mode, **amg_kwargs)

    if not os.path.isdir(args.input):
        targets = [args.input]
    else:
        targets = [
            f
            for f in os.listdir(args.input)
            if not os.path.isdir(os.path.join(args.input, f))
        ]
        targets = [os.path.join(args.input, f) for f in targets]

    os.makedirs(args.output, exist_ok=True)

    for t in targets:
        print(f"Processing '{t}'...")
        image = cv2.imread(t)
        if image is None:
            print(f"Could not load '{t}' as an image, skipping...")
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        masks = generator.generate(image)

        base = os.path.basename(t)
        base = os.path.splitext(base)[0]
        save_base = os.path.join(args.output, base)
        if output_mode == "binary_mask":
            os.makedirs(save_base, exist_ok=False)
            write_masks_to_folder(masks, save_base)
        else:
            save_file = save_base + ".json"
            with open(save_file, "w") as f:
                json.dump(masks, f)
    print("Done!")


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
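For concreteness, a typical invocation and the resulting output layout
(paths are illustrative; the behavior follows `write_masks_to_folder` and
`main` above):

```bash
python main.py --model sam-vit-base --input images/ --output out/
# Default (binary_mask) mode, per readable input image name.jpg:
#   out/name/0.png, out/name/1.png, ...   one PNG per mask
#   out/name/metadata.csv                 one row per mask, columns as in the header
# With --convert-to-rle:
#   out/name.json                         COCO-style RLEs in a single file
```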
segment_anything/notebooks/automatic_mask_generator_example.ipynb (new file, 257 lines)
@@ -0,0 +1,257 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Automatically generating object masks with SAM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook walks through how to automatically segment objects in an image. It is modified from the [original SAM GitHub repo](https://github.com/facebookresearch/segment-anything/).\n",
    "\n",
    "Since SAM can efficiently process prompts, masks for the entire image can be generated by sampling a large number of prompts over an image. This method was used to generate the dataset SA-1B.\n",
    "\n",
    "The class `SamAutomaticMaskGenerator` implements this. It samples single-point input prompts in a grid over the image, from each of which SAM then predicts multiple masks. The masks are filtered for quality and deduplicated using non-max suppression. Additional options allow for further improvement of mask quality and quantity, such as running prediction on multiple crops of the image or postprocessing masks to remove small disconnected regions and holes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set-up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_anns(anns):\n",
    "    if len(anns) == 0:\n",
    "        return\n",
    "    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)\n",
    "    ax = plt.gca()\n",
    "    ax.set_autoscale_on(False)\n",
    "\n",
    "    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))\n",
    "    img[:,:,3] = 0\n",
    "    for ann in sorted_anns:\n",
    "        m = ann['segmentation']\n",
    "        color_mask = np.concatenate([np.random.random(3), [0.35]])\n",
    "        img[m] = color_mask\n",
    "    ax.imshow(img)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image = cv2.imread('images/dog.jpg')\n",
    "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20,20))\n",
    "plt.imshow(image)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Automatic mask generation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run automatic mask generation, provide a SAM model to the `SamAutomaticMaskGenerator` class. Set the path below to the SAM checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from segment_anything import SamAutomaticMaskGenerator\n",
    "from segment_anything.sam import load\n",
    "\n",
    "sam_checkpoint = \"../sam-vit-base\"\n",
    "sam = load(sam_checkpoint)\n",
    "\n",
    "mask_generator = SamAutomaticMaskGenerator(sam)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To generate masks, run `generate` on an image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks = mask_generator.generate(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Mask generation returns a list over masks. Each item is a dictionary with keys:\n",
    "* `segmentation` : the mask\n",
    "* `area` : the area of the mask in pixels\n",
    "* `bbox` : the bounding box of the mask in XYWH format\n",
    "* `predicted_iou` : the model's own prediction for the quality of the mask\n",
    "* `point_coords` : the sampled input point that generated this mask\n",
    "* `stability_score` : an additional measure of mask quality\n",
    "* `crop_box` : the crop of the image used to generate this mask in XYWH format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(masks))\n",
    "print(masks[0].keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Show all the masks overlaid on the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20,20))\n",
    "plt.imshow(image)\n",
    "show_anns(masks)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Automatic mask generation options"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are several tunable parameters in automatic mask generation that control how densely points are sampled and what the thresholds are for removing low quality or duplicate masks. Generation can be automatically run on crops of the image to get better results for smaller objects. Post-processing can remove stray pixels and holes. Here is an example configuration that samples more masks:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mask_generator_2 = SamAutomaticMaskGenerator(\n",
    "    model=sam,\n",
    "    points_per_side=32,\n",
    "    pred_iou_thresh=0.86,\n",
    "    stability_score_thresh=0.92,\n",
    "    crop_n_layers=1,\n",
    "    crop_n_points_downscale_factor=2,\n",
    "    min_mask_region_area=100,  # Requires open-cv to run post-processing\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks2 = mask_generator_2.generate(image)\n",
    "len(masks2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20,20))\n",
    "plt.imshow(image)\n",
    "show_anns(masks2)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
segment_anything/notebooks/images/dog.jpg (new binary file, 98 KiB)
segment_anything/notebooks/images/groceries.jpg (new binary file, 164 KiB)
segment_anything/notebooks/images/truck.jpg (new binary file, 265 KiB)
segment_anything/notebooks/predictor_example.ipynb (new file, 629 lines)
@@ -0,0 +1,629 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Segmenting from Prompts\n",
    "\n",
    "This notebook walks through predicting object segmentations from a provided prompt. It uses the `Predictor` class. It is modified from the [original SAM GitHub repo](https://github.com/facebookresearch/segment-anything/).\n",
    "\n",
    "### Setup\n",
    "Necessary imports and helper functions for displaying points, boxes, and masks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import mlx.core as mx\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_mask(mask, ax, random_color=False):\n",
    "    if random_color:\n",
    "        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n",
    "    else:\n",
    "        color = np.array([30/255, 144/255, 255/255, 0.6])\n",
    "    h, w = mask.shape[:2]\n",
    "    mask_image = np.array(mask).reshape(h, w, 1) * color.reshape(1, 1, -1)\n",
    "    ax.imshow(mask_image)\n",
    "\n",
    "def show_points(coords, labels, ax, marker_size=375):\n",
    "    pos_points = np.array(coords)[labels==1]\n",
    "    neg_points = np.array(coords)[labels==0]\n",
    "    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
    "    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
    "\n",
    "def show_box(box, ax):\n",
    "    box = box.tolist()\n",
    "    x0, y0 = box[0], box[1]\n",
    "    w, h = box[2] - box[0], box[3] - box[1]\n",
    "    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image = cv2.imread('images/truck.jpg')\n",
    "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,10))\n",
    "plt.imshow(image)\n",
    "plt.axis('on')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Selecting objects with SAM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from segment_anything.sam import load\n",
    "from segment_anything.predictor import SamPredictor\n",
    "\n",
    "sam_checkpoint = \"../sam-vit-base\"\n",
    "sam = load(sam_checkpoint)\n",
    "predictor = SamPredictor(sam)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Process the image to produce an image embedding by calling `SamPredictor.set_image`. `SamPredictor` remembers this embedding and will use it for subsequent mask prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.set_image(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To select the truck, choose a point on it. Points are input to the model in (x,y) format and come with labels 1 (foreground point) or 0 (background point). Multiple points can be input; here we use only one. The chosen point will be shown as a star on the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_point = mx.array([[500, 375]])\n",
    "input_label = mx.array([1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,10))\n",
    "plt.imshow(image)\n",
    "show_points(input_point, input_label, plt.gca())\n",
    "plt.axis('on')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Predict with `SamPredictor.predict`. The model returns masks, quality predictions for those masks, and low resolution mask logits that can be passed to the next iteration of prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, scores, logits = predictor.predict(\n",
    "    point_coords=input_point[None],\n",
    "    point_labels=input_label[None],\n",
    "    multimask_output=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `multimask_output=True` (the default setting), SAM outputs 3 masks, where `scores` gives the model's own estimation of the quality of these masks. This setting is intended for ambiguous input prompts, and helps the model disambiguate different objects consistent with the prompt. When `False`, it will return a single mask. For ambiguous prompts such as a single point, it is recommended to use `multimask_output=True` even if only a single mask is desired; the best single mask can be chosen by picking the one with the highest score returned in `scores`. This will often result in a better mask."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(masks.shape[-1]):\n",
    "    mask = masks[..., i]\n",
    "    score = scores[..., i].item()\n",
    "    plt.figure(figsize=(10,10))\n",
    "    plt.imshow(image)\n",
    "    show_mask(mask[0], plt.gca())\n",
    "    show_points(input_point, input_label, plt.gca())\n",
    "    plt.title(f\"Mask {i+1}, Score: {score:.3f}\", fontsize=18)\n",
    "    plt.axis('off')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Specifying a specific object with additional points"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The single input point is ambiguous, and the model has returned multiple objects consistent with it. To obtain a single object, multiple points can be provided. If available, a mask from a previous iteration can also be supplied to the model to aid in prediction. When specifying a single object with multiple prompts, a single mask can be requested by setting `multimask_output=False`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_point = mx.array([[500, 375], [1125, 625]])\n",
    "input_label = mx.array([1, 1])\n",
    "mask_input = logits[..., mx.argmax(scores)]  # Choose the model's best mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=input_point[None],\n",
    "    point_labels=input_label[None],\n",
    "    mask_input=mask_input[..., None],\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,10))\n",
    "plt.imshow(image)\n",
    "show_mask(masks[0], plt.gca())\n",
    "show_points(input_point, input_label, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To exclude the car and specify just the window, a background point (with label 0, here shown in red) can be supplied."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_point = mx.array([[500, 375], [1125, 625]])\n",
    "input_label = mx.array([1, 0])\n",
    "mask_input = logits[..., mx.argmax(scores)]  # Choose the model's best mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=input_point[None],\n",
    "    point_labels=input_label[None],\n",
    "    mask_input=mask_input[..., None],\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(image)\n",
    "show_mask(masks[0], plt.gca())\n",
    "show_points(input_point, input_label, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Specifying a specific object with a box"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model can also take a box as input, provided in xyxy format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_box = mx.array([425, 600, 700, 875])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=None,\n",
    "    point_labels=None,\n",
    "    box=input_box[None, :],\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(image)\n",
    "show_mask(masks[0, ..., 0], plt.gca())\n",
    "show_box(input_box, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Combining points and boxes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Points and boxes may be combined, just by including both types of prompts to the predictor. Here this can be used to select just the truck's tire, instead of the entire wheel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_box = mx.array([425, 600, 700, 875])\n",
    "input_point = mx.array([[575, 750]])\n",
    "input_label = mx.array([0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=input_point[None],\n",
    "    point_labels=input_label[None],\n",
    "    box=input_box,\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(image)\n",
    "show_mask(masks[0, ..., 0], plt.gca())\n",
    "show_box(input_box, plt.gca())\n",
    "show_points(input_point, input_label, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batched prompt inputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`SamPredictor` can take multiple input prompts for the same image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_boxes = mx.array([\n",
    "    [75, 275, 1725, 850],\n",
    "    [425, 600, 700, 875],\n",
    "    [1375, 550, 1650, 800],\n",
    "    [1240, 675, 1400, 750],\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=None,\n",
    "    point_labels=None,\n",
    "    box=input_boxes,\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(image)\n",
    "for mask in masks:\n",
    "    show_mask(mask, plt.gca(), random_color=True)\n",
    "for box in input_boxes:\n",
    "    show_box(box, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## End-to-end batched inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If all prompts are available in advance, it is possible to run SAM directly in an end-to-end fashion. This also allows batching over images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image1 = image  # truck.jpg from above\n",
    "image1_boxes = mx.array([\n",
    "    [75, 275, 1725, 850],\n",
    "    [425, 600, 700, 875],\n",
    "    [1375, 550, 1650, 800],\n",
    "    [1240, 675, 1400, 750],\n",
    "])\n",
    "\n",
    "image2 = cv2.imread('images/groceries.jpg')\n",
    "image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)\n",
    "image2_boxes = mx.array([\n",
    "    [450, 170, 520, 350],\n",
    "    [350, 190, 450, 350],\n",
    "    [500, 170, 580, 350],\n",
    "    [580, 170, 640, 350],\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both images and prompts are input as MLX arrays that are already transformed to the correct frame. Inputs are packaged as a list over images, where each element is a dict that takes the following keys:\n",
    "* `image`: The input image as an MLX array in HWC format.\n",
    "* `original_size`: The size of the image before transforming for input to SAM, in (H, W) format.\n",
    "* `point_coords`: Batched coordinates of point prompts.\n",
    "* `point_labels`: Batched labels of point prompts.\n",
    "* `boxes`: Batched input boxes.\n",
    "* `mask_inputs`: Batched input masks.\n",
    "\n",
    "If a prompt is not present, the key can be excluded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from segment_anything.utils.transforms import ResizeLongestSide\n",
    "resize_transform = ResizeLongestSide(sam.vision_encoder.img_size)\n",
    "\n",
    "def prepare_image(image, transform, device):\n",
    "    image = transform.apply_image(image)\n",
    "    image = mx.array(image)\n",
    "    return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batched_input = [\n",
    "    {\n",
    "        'image': prepare_image(image1, resize_transform, sam),\n",
    "        'boxes': resize_transform.apply_boxes(image1_boxes, image1.shape[:2]),\n",
    "        'original_size': image1.shape[:2]\n",
    "    },\n",
    "    {\n",
    "        'image': prepare_image(image2, resize_transform, sam),\n",
    "        'boxes': resize_transform.apply_boxes(image2_boxes, image2.shape[:2]),\n",
    "        'original_size': image2.shape[:2]\n",
    "    }\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batched_output = sam(batched_input, multimask_output=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output is a list over results for each input image, where list elements are dictionaries with the following keys:\n",
    "* `masks`: A batched MLX array of predicted binary masks, the size of the original image.\n",
    "* `iou_predictions`: The model's prediction of the quality for each mask.\n",
    "* `low_res_logits`: Low res logits for each mask, which can be passed back to the model as mask input on a later iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 2, figsize=(20, 20))\n",
    "\n",
    "ax[0].imshow(image1)\n",
    "for mask in batched_output[0]['masks']:\n",
    "    show_mask(np.array(mask), ax[0], random_color=True)\n",
    "for box in image1_boxes:\n",
    "    show_box(np.array(box), ax[0])\n",
    "ax[0].axis('off')\n",
    "\n",
    "ax[1].imshow(image2)\n",
    "for mask in batched_output[1]['masks']:\n",
    "    show_mask(np.array(mask), ax[1], random_color=True)\n",
    "for box in image2_boxes:\n",
    "    show_box(np.array(box), ax[1])\n",
    "ax[1].axis('off')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
segment_anything/requirements.txt (new file, 3 lines)
@@ -0,0 +1,3 @@
matplotlib
opencv-python
huggingface_hub
segment_anything/segment_anything/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .automatic_mask_generator import SamAutomaticMaskGenerator
segment_anything/segment_anything/automatic_mask_generator.py (new file, 425 lines)
@ -0,0 +1,425 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
from .predictor import SamPredictor
|
||||
from .sam import Sam
|
||||
from .utils.amg import (
|
||||
MaskData,
|
||||
area_from_rle,
|
||||
batch_iterator,
|
||||
batched_mask_to_box,
|
||||
box_xyxy_to_xywh,
|
||||
build_all_layer_point_grids,
|
||||
calculate_stability_score,
|
||||
coco_encode_rle,
|
||||
generate_crop_boxes,
|
||||
is_box_near_crop_edge,
|
||||
mask_to_rle_mlx,
|
||||
remove_small_regions,
|
||||
rle_to_mask,
|
||||
uncrop_boxes_xyxy,
|
||||
uncrop_masks,
|
||||
uncrop_points,
|
||||
)
|
||||
|
||||
|
||||
class SamAutomaticMaskGenerator:
|
||||
def __init__(
|
||||
self,
|
||||
model: Sam,
|
||||
points_per_side: Optional[int] = 32,
|
||||
points_per_batch: int = 64,
|
||||
pred_iou_thresh: float = 0.88,
|
||||
stability_score_thresh: float = 0.95,
|
||||
stability_score_offset: float = 1.0,
|
||||
box_nms_thresh: float = 0.7,
|
||||
crop_n_layers: int = 0,
|
||||
crop_nms_thresh: float = 0.7,
|
||||
crop_overlap_ratio: float = 512 / 1500,
|
||||
crop_n_points_downscale_factor: int = 1,
|
||||
point_grids: Optional[List[mx.array]] = None,
|
||||
min_mask_region_area: int = 0,
|
||||
output_mode: str = "binary_mask",
|
||||
) -> None:
|
||||
"""
|
||||
Using a SAM model, generates masks for the entire image.
|
||||
Generates a grid of point prompts over the image, then filters
|
||||
low quality and duplicate masks. The default settings are chosen
|
||||
for SAM with a ViT-H backbone.
|
||||
|
||||
Arguments:
|
||||
model (Sam): The SAM model to use for mask prediction.
|
||||
points_per_side (int or None): The number of points to be sampled
|
||||
along one side of the image. The total number of points is
|
||||
points_per_side**2. If None, 'point_grids' must provide explicit
|
||||
point sampling.
|
||||
points_per_batch (int): Sets the number of points run simultaneously
|
||||
by the model. Higher numbers may be faster but use more GPU memory.
|
||||
pred_iou_thresh (float): A filtering threshold in [0,1], using the
|
||||
model's predicted mask quality.
|
||||
stability_score_thresh (float): A filtering threshold in [0,1], using
|
||||
the stability of the mask under changes to the cutoff used to binarize
|
||||
the model's mask predictions.
|
||||
stability_score_offset (float): The amount to shift the cutoff when
|
||||
calculated the stability score.
|
||||
box_nms_thresh (float): The box IoU cutoff used by non-maximal
|
||||
suppression to filter duplicate masks.
|
||||
crop_n_layers (int): If >0, mask prediction will be run again on
|
||||
crops of the image. Sets the number of layers to run, where each
|
||||
layer has 2**i_layer number of image crops.
|
||||
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
|
||||
suppression to filter duplicate masks between different crops.
|
||||
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
||||
In the first crop layer, crops will overlap by this fraction of
|
||||
the image length. Later layers with more crops scale down this overlap.
|
||||
crop_n_points_downscale_factor (int): The number of points-per-side
|
||||
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
||||
point_grids (list(mx.array) or None): A list over explicit grids
|
||||
of points used for sampling, normalized to [0,1]. The nth grid in the
|
||||
list is used in the nth crop layer. Exclusive with points_per_side.
|
||||
min_mask_region_area (int): If >0, postprocessing will be applied
|
||||
to remove disconnected regions and holes in masks with area smaller
|
||||
than min_mask_region_area. Requires opencv.
|
||||
output_mode (str): The form masks are returned in. Can be 'binary_mask',
|
||||
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
|
||||
For large resolutions, 'binary_mask' may consume large amounts of
|
||||
memory.
|
||||
"""
|
||||
|
||||
assert (points_per_side is None) != (
|
||||
point_grids is None
|
||||
), "Exactly one of points_per_side or point_grid must be provided."
|
||||
if points_per_side is not None:
|
||||
self.point_grids = build_all_layer_point_grids(
|
||||
points_per_side,
|
||||
crop_n_layers,
|
||||
crop_n_points_downscale_factor,
|
||||
)
|
||||
elif point_grids is not None:
|
||||
self.point_grids = point_grids
|
||||
else:
|
||||
raise ValueError("Can't have both points_per_side and point_grid be None.")
|
||||
|
||||
assert output_mode in [
|
||||
"binary_mask",
|
||||
"uncompressed_rle",
|
||||
"coco_rle",
|
||||
], f"Unknown output_mode {output_mode}."
|
||||
if output_mode == "coco_rle":
|
||||
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
|
||||
|
||||
if min_mask_region_area > 0:
|
||||
import cv2 # type: ignore # noqa: F401
|
||||
|
||||
self.predictor = SamPredictor(model)
|
||||
self.points_per_batch = points_per_batch
|
||||
self.pred_iou_thresh = pred_iou_thresh
|
||||
self.stability_score_thresh = stability_score_thresh
|
||||
self.stability_score_offset = stability_score_offset
|
||||
self.box_nms_thresh = box_nms_thresh
|
||||
self.crop_n_layers = crop_n_layers
|
||||
self.crop_nms_thresh = crop_nms_thresh
|
||||
self.crop_overlap_ratio = crop_overlap_ratio
|
||||
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
|
||||
self.min_mask_region_area = min_mask_region_area
|
||||
self.output_mode = output_mode
|
||||
|
||||
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Generates masks for the given image.
|
||||
|
||||
Arguments:
|
||||
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
|
||||
|
||||
Returns:
|
||||
list(dict(str, any)): A list over records for masks. Each record is
|
||||
a dict containing the following keys:
|
||||
segmentation (dict(str, any) or np.ndarray): The mask. If
|
||||
output_mode='binary_mask', is an array of shape HW. Otherwise,
|
||||
is a dictionary containing the RLE.
|
||||
bbox (list(float)): The box around the mask, in XYWH format.
|
||||
area (int): The area in pixels of the mask.
|
||||
predicted_iou (float): The model's own prediction of the mask's
|
||||
quality. This is filtered by the pred_iou_thresh parameter.
|
||||
point_coords (list(list(float))): The point coordinates input
|
||||
to the model to generate this mask.
|
||||
stability_score (float): A measure of the mask's quality. This
|
||||
is filtered on using the stability_score_thresh parameter.
|
||||
crop_box (list(float)): The crop of the image used to generate
|
||||
the mask, given in XYWH format.
|
||||
"""
|
||||
|
||||
# Generate masks
|
||||
mask_data = self._generate_masks(image)
|
||||
|
||||
# Filter small disconnected regions and holes in masks
|
||||
if self.min_mask_region_area > 0:
|
||||
mask_data = self.postprocess_small_regions(
|
||||
mask_data,
|
||||
self.min_mask_region_area,
|
||||
max(self.box_nms_thresh, self.crop_nms_thresh),
|
||||
)
|
||||
|
||||
# Encode masks
|
||||
if self.output_mode == "coco_rle":
|
||||
mask_data["segmentations"] = [
|
||||
coco_encode_rle(rle) for rle in mask_data["rles"]
|
||||
]
|
||||
elif self.output_mode == "binary_mask":
|
||||
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
||||
else:
|
||||
mask_data["segmentations"] = mask_data["rles"]
|
||||
|
||||
# Write mask records
|
||||
curr_anns = []
|
||||
for idx in range(len(mask_data["segmentations"])):
|
||||
ann = {
|
||||
"segmentation": mask_data["segmentations"][idx],
|
||||
"area": area_from_rle(mask_data["rles"][idx]),
|
||||
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
||||
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
||||
"point_coords": [mask_data["points"][idx].tolist()],
|
||||
"stability_score": mask_data["stability_score"][idx].item(),
|
||||
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
||||
}
|
||||
curr_anns.append(ann)
|
||||
|
||||
return curr_anns
|
||||
|
||||
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
||||
orig_size = image.shape[:2]
|
||||
crop_boxes, layer_idxs = generate_crop_boxes(
|
||||
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
||||
)
|
||||
|
||||
# Iterate over image crops
|
||||
data = MaskData()
|
||||
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
||||
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
|
||||
data.cat(crop_data)
|
||||
|
||||
# Remove duplicate masks between crops
|
||||
if len(crop_boxes) > 1:
|
||||
# Prefer masks from smaller crops
|
||||
scores = 1 / box_area(data["crop_boxes"])
|
||||
keep_by_nms = non_max_supression(
|
||||
data["boxes"].astype(mx.float32),
|
||||
scores,
|
||||
iou_threshold=self.crop_nms_thresh,
|
||||
)
|
||||
data.filter(keep_by_nms)
|
||||
|
||||
data.to_numpy()
|
||||
return data
|
||||
|
||||
def _process_crop(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
crop_box: List[int],
|
||||
crop_layer_idx: int,
|
||||
orig_size: Tuple[int, ...],
|
||||
) -> MaskData:
|
||||
# Crop the image and calculate embeddings
|
||||
x0, y0, x1, y1 = crop_box
|
||||
cropped_im = image[y0:y1, x0:x1, :]
|
||||
cropped_im_size = cropped_im.shape[:2]
|
||||
self.predictor.set_image(cropped_im)
|
||||
|
||||
# Get points for this crop
|
||||
points_scale = mx.array(cropped_im_size[::-1])[None]
|
||||
points_for_image = self.point_grids[crop_layer_idx] * points_scale
|
||||
|
||||
# Generate masks for this crop in batches
|
||||
data = MaskData()
|
||||
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
||||
batch_data = self._process_batch(
|
||||
points, cropped_im_size, crop_box, orig_size
|
||||
)
|
||||
data.cat(batch_data)
|
||||
del batch_data
|
||||
self.predictor.reset_image()
|
||||
|
||||
# Remove duplicates within this crop.
|
||||
keep_by_nms = non_max_supression(
|
||||
data["boxes"].astype(mx.float32),
|
||||
data["iou_preds"],
|
||||
iou_threshold=self.box_nms_thresh,
|
||||
)
|
||||
data.filter(keep_by_nms)
|
||||
|
||||
# Return to the original image frame
|
||||
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
|
||||
data["points"] = uncrop_points(data["points"], crop_box)
|
||||
data["crop_boxes"] = mx.array([crop_box for _ in range(len(data["rles"]))])
|
||||
return data
|
||||
|
||||
def _process_batch(
|
||||
self,
|
||||
points: np.ndarray,
|
||||
im_size: Tuple[int, ...],
|
||||
crop_box: List[int],
|
||||
orig_size: Tuple[int, ...],
|
||||
) -> MaskData:
|
||||
orig_h, orig_w = orig_size
|
||||
|
||||
masks, iou_preds, _ = self.predictor.predict(
|
||||
points[:, None, :],
|
||||
mx.ones((points.shape[0], 1), dtype=mx.int64),
|
||||
multimask_output=True,
|
||||
return_logits=True,
|
||||
)
|
||||
masks = masks.transpose(0, 3, 1, 2)
|
||||
# Serialize predictions and store in MaskData
|
||||
data = MaskData(
|
||||
masks=masks.flatten(0, 1),
|
||||
iou_preds=iou_preds.flatten(0, 1),
|
||||
points=mx.repeat(points, masks.shape[1], axis=0),
|
||||
)
|
||||
del masks
|
||||
|
||||
# Filter by predicted IoU
|
||||
if self.pred_iou_thresh > 0.0:
|
||||
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
||||
data.filter(keep_mask)
|
||||
|
||||
# Calculate stability score
|
||||
data["stability_score"] = calculate_stability_score(
|
||||
data["masks"],
|
||||
self.predictor.model.mask_threshold,
|
||||
self.stability_score_offset,
|
||||
)
|
||||
if self.stability_score_thresh > 0.0:
|
||||
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
||||
data.filter(keep_mask)
|
||||
|
||||
# Threshold masks and calculate boxes
|
||||
data["masks"] = data["masks"] > self.predictor.model.mask_threshold
|
||||
data["boxes"] = batched_mask_to_box(data["masks"])
|
||||
|
||||
# Filter boxes that touch crop boundaries
|
||||
keep_mask = ~is_box_near_crop_edge(
|
||||
data["boxes"], crop_box, [0, 0, orig_w, orig_h]
|
||||
)
|
||||
if not mx.all(keep_mask):
|
||||
data.filter(keep_mask)
|
||||
|
||||
# Compress to RLE
|
||||
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
|
||||
data["rles"] = mask_to_rle_mlx(data["masks"])
|
||||
del data["masks"]
|
||||
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def postprocess_small_regions(
|
||||
mask_data: MaskData, min_area: int, nms_thresh: float
|
||||
) -> MaskData:
|
||||
"""
|
||||
Removes small disconnected regions and holes in masks, then reruns
|
||||
box NMS to remove any new duplicates.
|
||||
|
||||
Edits mask_data in place.
|
||||
|
||||
Requires open-cv as a dependency.
|
||||
"""
|
||||
    if len(mask_data["rles"]) == 0:
        return mask_data

    # Filter small disconnected regions and holes
    new_masks = []
    scores = []
    for rle in mask_data["rles"]:
        mask = rle_to_mask(rle)

        mask, changed = remove_small_regions(mask, min_area, mode="holes")
        unchanged = not changed
        mask, changed = remove_small_regions(mask, min_area, mode="islands")
        unchanged = unchanged and not changed

        new_masks.append(mx.array(mask)[None])
        # Give score=0 to changed masks and score=1 to unchanged masks
        # so NMS will prefer ones that didn't need postprocessing
        scores.append(float(unchanged))
    scores = mx.array(scores)

    # Recalculate boxes and remove any new duplicates
    masks = mx.concatenate(new_masks, axis=0)
    boxes = batched_mask_to_box(masks)
    keep_by_nms = non_max_supression(
        boxes.astype(mx.float32),
        scores,
        iou_threshold=nms_thresh,
    )
    # Only recalculate RLEs for masks that have changed
    for i_mask, keep in enumerate(keep_by_nms):
        if not keep:
            continue
        if scores[i_mask] == 0.0:
            mask_mlx = masks[i_mask][None]
            mask_data["rles"][i_mask] = mask_to_rle_mlx(mask_mlx)[0]
            mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
    mask_data.filter(keep_by_nms)

    return mask_data


def box_area(boxes: mx.array) -> mx.array:
    """
    Computes the area of a set of bounding boxes, which are specified by their
    (x1, y1, x2, y2) coordinates.

    Args:
        boxes (mx.array[N, 4]): boxes for which the area will be computed. They
            are expected to be in (x1, y1, x2, y2) format with
            ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Returns:
        mx.array[N]: the area for each box
    """
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


def batched_iou(boxes_a: mx.array, boxes_b: mx.array) -> mx.array:
    """Compute IoU for batched boxes.

    Args:
        boxes_a (mx.array): [..., [x1, y1, x2, y2]] sized Mx4
        boxes_b (mx.array): [..., [x1, y1, x2, y2]] sized Nx4

    Returns:
        mx.array: MxN
    """
    area_a = box_area(boxes_a)  # M
    area_b = box_area(boxes_b)  # N

    top_left = mx.maximum(boxes_a[:, None, :2], boxes_b[:, :2])
    bottom_right = mx.minimum(boxes_a[:, None, 2:], boxes_b[:, 2:])

    area_inter = mx.prod(mx.clip(bottom_right - top_left, a_min=0, a_max=None), 2)

    return area_inter / (area_a[:, None] + area_b - area_inter)


def non_max_supression(
    boxes: mx.array, scores: mx.array, iou_threshold: float = 0.5
) -> mx.array:
    sort_index = mx.argsort(-scores)
    boxes = boxes[sort_index]

    n_boxes = boxes.shape[0]
    ious = batched_iou(boxes, boxes)
    ious -= mx.eye(n_boxes)

    ious = np.array(ious)
    keep = np.ones(n_boxes, dtype=np.bool_)
    for i, iou in enumerate(ious):
        if not keep[i]:
            continue

        condition = iou <= iou_threshold
        keep = keep & condition

    return sort_index[mx.array(np.where(keep)[0])]
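The greedy NMS above sorts boxes by descending score, then suppresses any remaining box whose IoU with a kept box exceeds the threshold. A minimal sketch of how it behaves, assuming the function above is in scope; the three boxes are made-up values for illustration:

```python
import mlx.core as mx

# Two heavily overlapping boxes and one disjoint box, in (x1, y1, x2, y2).
boxes = mx.array(
    [
        [0.0, 0.0, 10.0, 10.0],    # highest score, kept
        [1.0, 1.0, 10.0, 10.0],    # IoU ~0.81 with the first box, suppressed
        [20.0, 20.0, 30.0, 30.0],  # no overlap, kept
    ]
)
scores = mx.array([0.9, 0.8, 0.7])

keep = non_max_supression(boxes, scores, iou_threshold=0.5)
print(keep)  # indices of surviving boxes in descending-score order: [0, 2]
```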
35
segment_anything/segment_anything/common.py
Normal file
@ -0,0 +1,35 @@
from typing import Type

import mlx.core as mx
import mlx.nn as nn


class MLPBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def __call__(self, x: mx.array) -> mx.array:
        return self.lin2(self.act(self.lin1(x)))


class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = mx.ones(num_channels)
        self.bias = mx.zeros(num_channels)
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        u = x.mean(3, keepdims=True)
        s = ((x - u) ** 2).mean(3, keepdims=True)
        x = (x - u) / mx.sqrt(s + self.eps)
        x = self.weight * x + self.bias
        return x
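A quick shape check for the two helpers above, as a hedged sketch; convolutions and norms in this port are channels-last, so `LayerNorm2d` normalizes axis 3 of an NHWC input:

```python
import mlx.core as mx

mlp = MLPBlock(embedding_dim=256, mlp_dim=1024)
ln = LayerNorm2d(num_channels=256)

x = mx.random.normal((1, 64, 64, 256))  # NHWC
print(mlp(x).shape)  # (1, 64, 64, 256): the hidden width expands to 1024 internally
print(ln(x).shape)   # (1, 64, 64, 256): normalized over the channel axis
```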
422
segment_anything/segment_anything/image_encoder.py
Normal file
@ -0,0 +1,422 @@
from typing import Optional, Tuple, Type

import mlx.core as mx
import mlx.nn as nn

from .common import LayerNorm2d, MLPBlock


class ImageEncoderViT(nn.Module):
    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
        """
        super().__init__()
        self.img_size = img_size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = mx.zeros(
                [1, img_size // patch_size, img_size // patch_size, embed_dim]
            )
        else:
            self.pos_embed = None

        self.layers = []
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.layers.append(block)

        self.neck = Neck(embed_dim, out_chans)

    def __call__(self, x: mx.array) -> mx.array:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        for blk in self.layers:
            x = blk(x)

        x = self.neck(x)
        return x


class Neck(nn.Module):
    def __init__(self, embed_dim, out_chans):
        super().__init__()
        self.conv1 = nn.Conv2d(
            embed_dim,
            out_chans,
            kernel_size=1,
            bias=False,
        )
        self.layer_norm1 = LayerNorm2d(out_chans)
        self.conv2 = nn.Conv2d(
            out_chans,
            out_chans,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.layer_norm2 = LayerNorm2d(out_chans)

    def __call__(self, x):
        return self.layer_norm2(self.conv2(self.layer_norm1(self.conv1(x))))


class Block(nn.Module):
    """Transformer block with support for window attention and residual propagation."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.layer_norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
        )

        self.layer_norm2 = norm_layer(dim)
        self.mlp = MLPBlock(
            embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
        )

        self.window_size = window_size

    def __call__(self, x: mx.array) -> mx.array:
        shortcut = x
        x = self.layer_norm1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + x
        x = x + self.mlp(self.layer_norm2(x))

        return x
class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # initialize relative positional embeddings
            self.rel_pos_h = mx.zeros(shape=(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = mx.zeros(shape=(2 * input_size[1] - 1, head_dim))

    def __call__(self, x: mx.array) -> mx.array:
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = (
            self.qkv(x)
            .reshape(B, H * W, 3, self.num_heads, -1)
            .transpose(2, 0, 3, 1, 4)
        )

        # q, k, v with shape (B * nHead, H * W, C)
        qkv = qkv.reshape(3, B * self.num_heads, H * W, -1)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q * self.scale) @ k.transpose(0, 2, 1)

        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(
                attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
            )

        attn = mx.softmax(attn, axis=-1)
        x = (
            (attn @ v)
            .reshape(B, self.num_heads, H, W, -1)
            .transpose(0, 2, 3, 1, 4)
            .reshape(B, H, W, -1)
        )
        x = self.proj(x)

        return x
def window_partition(x: mx.array, window_size: int) -> Tuple[mx.array, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (mx.array): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        # Pad the height axis (1) by pad_h and the width axis (2) by pad_w.
        x = mx.pad(x, ((0, 0), (0, pad_h), (0, pad_w), (0, 0)))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.reshape(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
    return windows, (Hp, Wp)
def window_unpartition(
    windows: mx.array,
    window_size: int,
    pad_hw: Tuple[int, int],
    hw: Tuple[int, int],
) -> mx.array:
    """
    Reverse window partition into the original sequences, removing padding.
    Args:
        windows (mx.array): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.reshape(
        B, Hp // window_size, Wp // window_size, window_size, window_size, -1
    )
    x = x.transpose(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :]
    return x
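A round-trip sketch of the two helpers above; the input sizes are arbitrary and deliberately not multiples of the window size so the padding path is exercised:

```python
import mlx.core as mx

x = mx.random.normal((1, 50, 50, 8))  # H and W are not multiples of 14

windows, pad_hw = window_partition(x, window_size=14)
print(windows.shape, pad_hw)  # (16, 14, 14, 8), (56, 56): a 4x4 grid of padded windows

y = window_unpartition(windows, 14, pad_hw, (50, 50))
print(mx.allclose(x, y))  # True: unpartition crops the padding back off
```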
def get_rel_pos(q_size: int, k_size: int, rel_pos: mx.array) -> mx.array:
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (mx.array): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Linearly resize the embeddings along the length axis from L to
        # max_rel_dist; max_rel_dist is a scalar, so a single scale factor
        # over the (1, L, C) length axis is sufficient.
        rel_pos_resized = rel_pos[None]  # (1, L, C)
        scale_factor = max_rel_dist / rel_pos.shape[0]
        rel_pos_resized = nn.Upsample(scale_factor=scale_factor, mode="linear")(
            rel_pos_resized
        )
        rel_pos_resized = rel_pos_resized.reshape(max_rel_dist, -1)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = mx.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = mx.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.astype(mx.int64)]
def add_decomposed_rel_pos(
    attn,
    q,
    rel_pos_h: mx.array,
    rel_pos_w: mx.array,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> mx.array:
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
    Args:
        attn (mx.array): attention map.
        q (mx.array): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (mx.array): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (mx.array): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (mx.array): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)

    # TODO: replace with mx.einsum when it's ready
    # workaround for these einsum computations
    # rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    # rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
    rel_h = r_q @ Rh.transpose(0, 2, 1)
    rel_w = (r_q.transpose(0, 2, 1, 3) @ Rw.transpose(0, 2, 1)).transpose(0, 2, 1, 3)

    attn = (
        attn.reshape(B, q_h, q_w, k_h, k_w)
        + rel_h[:, :, :, :, None]
        + rel_w[:, :, :, None, :]
    ).reshape(B, q_h * q_w, k_h * k_w)

    return attn
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.projection = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def __call__(self, x: mx.array) -> mx.array:
        x = self.projection(x)
        return x
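To make the encoder's geometry concrete, a hedged shape sketch of the patchify step: with the defaults above, a 1024×1024 input yields a 64×64 grid of 768-dimensional patch tokens.

```python
import mlx.core as mx

patch_embed = PatchEmbed()  # kernel 16, stride 16, 3 -> 768 channels

image = mx.random.normal((1, 1024, 1024, 3))  # NHWC
tokens = patch_embed(image)
print(tokens.shape)  # (1, 64, 64, 768): one 768-dim token per 16x16 patch
```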
252
segment_anything/segment_anything/mask_decoder.py
Normal file
@ -0,0 +1,252 @@
import math
from typing import List, Tuple, Type, Union

import mlx.core as mx
import mlx.nn as nn

from .common import LayerNorm2d


class MaskDecoder(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Args:
            transformer_dim (int): the channel dimension of the transformer
            transformer (nn.Module): the transformer used to predict masks
            num_multimask_outputs (int): the number of masks to predict
                when disambiguating masks
            activation (nn.Module): the type of activation to use when
                upscaling masks
            iou_head_depth (int): the depth of the MLP used to predict
                mask quality
            iou_head_hidden_dim (int): the hidden dimension of the MLP
                used to predict mask quality
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.upscale_conv1 = ConvTranspose2d(
            transformer_dim,
            transformer_dim // 4,
            kernel_size=2,
            stride=2,
            padding=1,
        )
        self.upscale_layer_norm = LayerNorm2d(transformer_dim // 4)
        self.activation = activation()
        self.upscale_conv2 = ConvTranspose2d(
            transformer_dim // 4,
            transformer_dim // 8,
            kernel_size=2,
            stride=2,
            padding=1,
        )
        self.output_hypernetworks_mlps = [
            MLP(transformer_dim, transformer_dim, transformer_dim // 8, 1)
            for _ in range(self.num_mask_tokens)
        ]

        self.iou_prediction_head = MLP(
            transformer_dim,
            iou_head_hidden_dim,
            self.num_mask_tokens,
            iou_head_depth - 2,
        )

    def __call__(
        self,
        image_embeddings: mx.array,
        image_pe: mx.array,
        sparse_prompt_embeddings: mx.array,
        dense_prompt_embeddings: mx.array,
        multimask_output: bool,
    ) -> Tuple[mx.array, mx.array]:
        """
        Predict masks given image and prompt embeddings.

        Args:
            image_embeddings (mx.array): the embeddings from the image encoder
            image_pe (mx.array): positional encoding
            sparse_prompt_embeddings (mx.array): the embeddings of the points and boxes
            dense_prompt_embeddings (mx.array): the embeddings of the mask inputs
            multimask_output (bool): Whether to return multiple masks or a single
                mask.

        Returns:
            mx.array: batched predicted masks
            mx.array: batched predictions of mask quality
        """
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, :, :, mask_slice]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: mx.array,
        image_pe: mx.array,
        sparse_prompt_embeddings: mx.array,
        dense_prompt_embeddings: mx.array,
    ) -> Tuple[mx.array, mx.array]:
        """Predicts masks. See '__call__' for more details."""
        # Concatenate output tokens
        output_tokens = mx.concatenate(
            [self.iou_token.weight, self.mask_tokens.weight], axis=0
        )
        output_tokens = mx.broadcast_to(
            output_tokens[None],
            [
                sparse_prompt_embeddings.shape[0],
                output_tokens.shape[0],
                output_tokens.shape[1],
            ],
        )
        tokens = mx.concatenate((output_tokens, sparse_prompt_embeddings), axis=1)

        # Expand per-image data in batch direction to be per-mask
        src = mx.repeat(image_embeddings, repeats=tokens.shape[0], axis=0)
        src = src + dense_prompt_embeddings
        b, h, w, c = src.shape

        # Run the transformer
        hs, src = self.transformer(src, image_pe, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.reshape(b, h, w, c)
        src = self.upscale_conv1(src)
        src = self.upscale_layer_norm(src)
        src = self.activation(src)
        src = self.upscale_conv2(src)
        upscaled_embedding = self.activation(src)
        hyper_in_list: List[mx.array] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(
                self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
            )
        hyper_in = mx.stack(hyper_in_list, axis=1)
        b, h, w, c = upscaled_embedding.shape

        masks = (
            (hyper_in @ upscaled_embedding.reshape(b, h * w, c).transpose(0, 2, 1))
            .transpose(0, 2, 1)
            .reshape(b, h, w, -1)
        )

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred
class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.proj_in = nn.Linear(input_dim, hidden_dim)
        self.layers = [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)]
        self.proj_out = nn.Linear(hidden_dim, output_dim)
        self.sigmoid_output = sigmoid_output

    def __call__(self, x):
        x = nn.relu(self.proj_in(x))
        for layer in self.layers:
            x = nn.relu(layer(x))
        x = self.proj_out(x)
        if self.sigmoid_output:
            x = mx.sigmoid(x)
        return x
# TODO: Naive implementation. Replace when mlx.nn supports conv_transpose.
class ConvTranspose2d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, tuple],
        stride: Union[int, tuple] = 1,
        padding: Union[int, tuple] = 0,
        dilation: Union[int, tuple] = 1,
        bias: bool = True,
    ):
        super().__init__()

        kernel_size, stride, padding = map(
            lambda x: (x, x) if isinstance(x, int) else x,
            (kernel_size, stride, padding),
        )
        scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(out_channels, *kernel_size, in_channels),
        )
        if bias:
            self.bias = mx.zeros((out_channels,))

        self.padding = padding
        self.stride = stride
        self.dilation = dilation

    def _extra_repr(self):
        return (
            f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
            f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
            f"padding={self.padding}, dilation={self.dilation}, "
            f"bias={'bias' in self}"
        )

    def __call__(self, x):
        # A transposed convolution expressed as a regular convolution with a
        # flipped kernel, dilated input (input_dilation=stride), and adjusted
        # padding.
        y = mx.conv_general(
            x,
            self.weight,
            stride=1,
            padding=self.padding,
            kernel_dilation=self.dilation,
            input_dilation=self.stride,
            flip=True,
        )
        if "bias" in self:
            y = y + self.bias
        return y
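A small sanity check on the transposed convolution above: under this flip/input-dilation formulation, `kernel_size=2, stride=2, padding=1` doubles the spatial size, which is how the decoder upscales mask embeddings. A hedged sketch, not a test from the repo:

```python
import mlx.core as mx

up = ConvTranspose2d(
    in_channels=256, out_channels=64, kernel_size=2, stride=2, padding=1
)

src = mx.random.normal((1, 64, 64, 256))  # NHWC
out = up(src)
print(out.shape)  # (1, 128, 128, 64): each upscale stage doubles H and W
```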
170
segment_anything/segment_anything/predictor.py
Normal file
@ -0,0 +1,170 @@
from typing import Optional, Tuple

import mlx.core as mx
import numpy as np

from .sam import Sam
from .utils.transforms import ResizeLongestSide


class SamPredictor:
    def __init__(
        self,
        sam_model: Sam,
    ) -> None:
        """
        Uses SAM to calculate the image embedding for an image, and then
        allows repeated, efficient mask prediction given prompts.

        Args:
            sam_model (Sam): The model to use for mask prediction.
        """
        super().__init__()
        self.model = sam_model
        self.transform = ResizeLongestSide(sam_model.vision_encoder.img_size)
        self.reset_image()

    def set_image(
        self,
        image: np.ndarray,
        image_format: str = "RGB",
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method.

        Args:
            image (np.ndarray): The image for calculating masks. Expects an
                image in HWC uint8 format, with pixel values in [0, 255].
            image_format (str): The color format of the image, in ['RGB', 'BGR'].
        """
        self.reset_image()
        assert image_format in [
            "RGB",
            "BGR",
        ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
        if image_format != self.model.image_format:
            image = image[..., ::-1]

        # Transform the image to the form expected by the model
        input_image = self.transform.apply_image(image)
        input_image = mx.array(input_image)[None, :, :, :]

        self.original_size = image.shape[:2]
        self.input_size = input_image.shape[1:3]
        input_image = self.model.preprocess(input_image)
        self.features = self.model.vision_encoder(input_image)
        self.is_image_set = True
    def predict(
        self,
        point_coords: Optional[mx.array],
        point_labels: Optional[mx.array],
        box: Optional[mx.array] = None,
        mask_input: Optional[mx.array] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[mx.array, mx.array, mx.array]:
        """
        Predict masks for the given input prompts, using the currently set image.
        Input prompts are batched mlx arrays; point and box prompts are given in
        the original image frame and are transformed to the input frame
        internally using ResizeLongestSide.

        Args:
            point_coords (mx.array or None): A BxNx2 array of point prompts to the
                model. Each point is in (X,Y) in pixels.
            point_labels (mx.array or None): A BxN array of labels for the
                point prompts. 1 indicates a foreground point and 0 indicates a
                background point.
            box (mx.array or None): A size 4 array giving a box prompt to the
                model, in XYXY format.
            mask_input (mx.array): A low resolution mask input to the model, typically
                coming from a previous prediction iteration. Has form BxHxWx1, where
                for SAM, H=W=256. Masks returned by a previous iteration of the
                predict method do not need further transformation.
            multimask_output (bool): If true, the model will return three masks.
                For ambiguous input prompts (such as a single click), this will often
                produce better masks than a single prediction. If only a single
                mask is needed, the model's predicted quality score can be used
                to select the best mask. For non-ambiguous prompts, such as multiple
                input prompts, multimask_output=False can give better results.
            return_logits (bool): If true, returns un-thresholded mask logits
                instead of a binary mask.

        Returns:
            (mx.array): The output masks in BxHxWxC format, where C is the
                number of masks, and (H, W) is the original image size.
            (mx.array): An array of shape BxC containing the model's
                predictions for the quality of each mask.
            (mx.array): An array of shape BxHxWxC, where C is the number
                of masks and H=W=256. These low res logits can be passed to
                a subsequent iteration as mask input.
        """
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        # Transform input prompts
        points = None
        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
            point_coords = self.transform.apply_coords(point_coords, self.original_size)
            points = (point_coords, point_labels)
        if box is not None:
            box = self.transform.apply_boxes(box, self.original_size)

        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=box,
            masks=mask_input,
            pe_layer=self.model.shared_image_embedding,
        )

        # Predict masks
        low_res_masks, iou_predictions = self.model.mask_decoder(
            image_embeddings=self.features,
            image_pe=self.model.shared_image_embedding(
                self.model.prompt_encoder.image_embedding_size
            ),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        # Upscale the masks to the original image resolution
        masks = self.model.postprocess_masks(
            low_res_masks, self.input_size, self.original_size
        )

        if not return_logits:
            masks = masks > self.model.mask_threshold

        return masks, iou_predictions, low_res_masks
    def get_image_embedding(self) -> mx.array:
        """
        Returns the image embeddings for the currently set image, with
        shape 1xHxWxC, where C is the embedding dimension and (H, W) are
        the embedding spatial dimensions of SAM (typically C=256, H=W=64).
        """
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) to generate an embedding."
            )
        assert (
            self.features is not None
        ), "Features must exist if an image has been set."
        return self.features

    def reset_image(self) -> None:
        """Resets the currently set image."""
        self.is_image_set = False
        self.features = None
        self.orig_h = None
        self.orig_w = None
        self.input_h = None
        self.input_w = None
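Putting the predictor together end to end; a hedged sketch that assumes weights were already converted with `convert.py` into `sam-vit-base/`, that the script runs from the example directory so the inner `segment_anything` package is importable, and that `example.jpg` is a placeholder image path:

```python
import mlx.core as mx
import numpy as np
from PIL import Image

from segment_anything.predictor import SamPredictor
from segment_anything.sam import load

sam = load("sam-vit-base")
predictor = SamPredictor(sam)

# HWC uint8 image; "example.jpg" is a placeholder path.
image = np.array(Image.open("example.jpg").convert("RGB"))
predictor.set_image(image)

# One foreground click at pixel (500, 375) in the original image frame.
point_coords = mx.array([[[500.0, 375.0]]])  # B x N x 2
point_labels = mx.array([[1]])               # 1 = foreground

masks, iou_predictions, low_res_logits = predictor.predict(
    point_coords, point_labels, multimask_output=True
)
print(masks.shape)  # B x H x W x 3 binary masks at the original resolution
```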
229
segment_anything/segment_anything/prompt_encoder.py
Normal file
@ -0,0 +1,229 @@
from typing import Optional, Tuple, Type

import mlx.core as mx
import mlx.nn as nn

from .common import LayerNorm2d


class PromptEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Args:
            embed_dim (int): The prompts' embedding dimension
            image_embedding_size (tuple(int, int)): The spatial size of the
                image embedding, as (H, W).
            input_image_size (tuple(int, int)): The padded size of the image as input
                to the image encoder, as (H, W).
            mask_in_chans (int): The number of hidden channels used for
                encoding input masks.
            activation (nn.Module): The activation to use when encoding
                input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size

        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
        self.point_embed = [
            nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)
        ]
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        self.mask_input_size = (
            4 * image_embedding_size[0],
            4 * image_embedding_size[1],
        )
        self.mask_embed = MaskEmbed(embed_dim, mask_in_chans, activation)
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def _embed_points(
        self,
        points: mx.array,
        labels: mx.array,
        pad: bool,
        pe_layer: nn.Module,
    ) -> mx.array:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = mx.zeros((points.shape[0], 1, 2))
            padding_label = -mx.ones((labels.shape[0], 1))
            points = mx.concatenate([points, padding_point], axis=1)
            labels = mx.concatenate([labels, padding_label], axis=1)
        point_embedding = pe_layer.forward_with_coords(points, self.input_image_size)
        point_embedding = mx.where(
            labels[..., None] == -1,
            self.not_a_point_embed.weight[:, None],
            point_embedding,
        )
        point_embedding = mx.where(
            labels[..., None] == 0,
            point_embedding + self.point_embed[0].weight[:, None],
            point_embedding,
        )
        point_embedding = mx.where(
            labels[..., None] == 1,
            point_embedding + self.point_embed[1].weight[:, None],
            point_embedding,
        )
        return point_embedding

    def _embed_boxes(self, boxes: mx.array, pe_layer: nn.Module) -> mx.array:
        """Embeds box prompts."""
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = pe_layer.forward_with_coords(coords, self.input_image_size)
        corner_embedding[:, 0, :] += self.point_embed[2].weight
        corner_embedding[:, 1, :] += self.point_embed[3].weight
        return corner_embedding

    def _embed_masks(self, masks: mx.array) -> mx.array:
        """Embeds mask inputs."""
        mask_embedding = self.mask_embed(masks)
        return mask_embedding

    def _get_batch_size(
        self,
        points: Optional[Tuple[mx.array, mx.array]],
        boxes: Optional[mx.array],
        masks: Optional[mx.array],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input prompts.
        """
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def __call__(
        self,
        points: Optional[Tuple[mx.array, mx.array]],
        boxes: Optional[mx.array],
        masks: Optional[mx.array],
        pe_layer: nn.Module,
    ) -> Tuple[mx.array, mx.array]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Args:
            points (tuple(mx.array, mx.array) or none): point coordinates
                and labels to embed
            boxes (mx.array or none): boxes to embed
            masks (mx.array or none): masks to embed
            pe_layer (PositionEmbeddingRandom): shared position embedding
                layer

        Returns:
            mx.array: sparse embeddings for the points and boxes, with shape
                BxNx(embed_dim), where N is determined by the number of input points
                and boxes.
            mx.array: dense embeddings for the masks, in the shape
                Bx(embed_H)x(embed_W)x(embed_dim)
        """
        bs = self._get_batch_size(points, boxes, masks)
        sparse_embeddings = mx.zeros((bs, 0, self.embed_dim))
        if points is not None:
            coords, labels = points
            point_embeddings = self._embed_points(
                coords, labels, pad=(boxes is None), pe_layer=pe_layer
            )
            sparse_embeddings = mx.concatenate(
                [sparse_embeddings, point_embeddings], axis=1
            )
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes, pe_layer=pe_layer)
            sparse_embeddings = mx.concatenate(
                [sparse_embeddings, box_embeddings], axis=1
            )

        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            dense_embeddings = mx.broadcast_to(
                self.no_mask_embed.weight,
                shape=(
                    bs,
                    self.image_embedding_size[0],
                    self.image_embedding_size[1],
                    self.embed_dim,
                ),
            )

        return sparse_embeddings, dense_embeddings
class MaskEmbed(nn.Module):
    def __init__(self, embed_dim, mask_in_chans, activation):
        super().__init__()
        self.conv1 = nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2)
        self.layer_norm1 = LayerNorm2d(mask_in_chans // 4)
        self.conv2 = nn.Conv2d(
            mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2
        )
        self.layer_norm2 = LayerNorm2d(mask_in_chans)
        self.conv3 = nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1)
        self.activation = activation()

    def __call__(self, x):
        x = self.activation(self.layer_norm1(self.conv1(x)))
        x = self.activation(self.layer_norm2(self.conv2(x)))
        return self.conv3(x)


class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.positional_embedding = scale * mx.random.normal((2, num_pos_feats))

    def _pe_encoding(self, coords: mx.array) -> mx.array:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_embedding
        coords = 2 * mx.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return mx.concatenate([mx.sin(coords), mx.cos(coords)], axis=-1)

    def __call__(self, size: Tuple[int, int]) -> mx.array:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        grid = mx.ones((h, w), dtype=mx.float32)
        y_embed = grid.cumsum(axis=0) - 0.5
        x_embed = grid.cumsum(axis=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w

        pe = self._pe_encoding(mx.stack([x_embed, y_embed], axis=-1))
        return pe  # HWC

    def forward_with_coords(
        self, coords_input: mx.array, image_size: Tuple[int, int]
    ) -> mx.array:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input * 1  # copy so the caller's array is not mutated
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.astype(mx.float32))  # B x N x C
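A brief sketch of the random Fourier position encoding above: `num_pos_feats` frequencies produce `2 * num_pos_feats` output channels (sin and cos), which is why `Sam` constructs the shared embedding with `embed_dim // 2` features to land on `embed_dim` channels.

```python
import mlx.core as mx

pe_layer = PositionEmbeddingRandom(num_pos_feats=128)

# Dense grid encoding for a 64x64 embedding map.
pe = pe_layer((64, 64))
print(pe.shape)  # (64, 64, 256)

# Per-point encoding for unnormalized pixel coordinates.
coords = mx.array([[[512.0, 256.0]]])  # B x N x 2 in (x, y)
print(pe_layer.forward_with_coords(coords, (1024, 1024)).shape)  # (1, 1, 256)
```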
240
segment_anything/segment_anything/sam.py
Normal file
@ -0,0 +1,240 @@
import json
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Tuple

import mlx.core as mx
import mlx.nn as nn

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PositionEmbeddingRandom, PromptEncoder
from .transformer import TwoWayTransformer


class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        vision_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Args:
            vision_encoder (ImageEncoderViT): The backbone used to encode the
                image into image embeddings that allow for efficient mask prediction.
            prompt_encoder (PromptEncoder): Encodes various types of input prompts.
            mask_decoder (MaskDecoder): Predicts masks from the image embeddings
                and encoded prompts.
            pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
            pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.vision_encoder = vision_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self._pixel_mean = mx.array(pixel_mean).reshape(1, 1, -1)
        self._pixel_std = mx.array(pixel_std).reshape(1, 1, -1)
        self.shared_image_embedding = PositionEmbeddingRandom(
            prompt_encoder.embed_dim // 2
        )

    def __call__(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, mx.array]]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly.

        Args:
            batched_input (list(dict)): A list over input images, each a
                dictionary with the following keys. A prompt key can be
                excluded if it is not present.
                'image': The image as an mlx array in HxWx3 format,
                    already transformed for input to the model.
                'original_size': (tuple(int, int)) The original size of
                    the image before transformation, as (H, W).
                'point_coords': (mx.array) Batched point prompts for
                    this image, with shape BxNx2. Already transformed to the
                    input frame of the model.
                'point_labels': (mx.array) Batched labels for point prompts,
                    with shape BxN.
                'boxes': (mx.array) Batched box inputs, with shape Bx4.
                    Already transformed to the input frame of the model.
                'mask_inputs': (mx.array) Batched mask inputs to the model,
                    in the form BxHxWx1.
            multimask_output (bool): Whether the model should predict multiple
                disambiguating masks, or return a single mask.

        Returns:
            (list(dict)): A list over input images, where each element is
                a dictionary with the following keys.
                'masks': (mx.array) Batched binary mask predictions,
                    with shape BxHxWxC, where B is the number of input prompts,
                    C is determined by multimask_output, and (H, W) is the
                    original size of the image.
                'iou_predictions': (mx.array) The model's predictions
                    of mask quality, in shape BxC.
                'low_res_logits': (mx.array) Low resolution logits with
                    shape BxHxWxC, where H=W=256. Can be passed as mask input
                    to subsequent iterations of prediction.
        """
        input_images = mx.stack(
            [self.preprocess(x["image"]) for x in batched_input], axis=0
        )
        image_embeddings = self.vision_encoder(input_images)

        outputs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
                pe_layer=self.shared_image_embedding,
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding[None],
                image_pe=self.shared_image_embedding(
                    self.prompt_encoder.image_embedding_size
                ),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )

            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-3:-1],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs
    def postprocess_masks(
        self,
        masks: mx.array,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> mx.array:
        """
        Remove padding and upscale masks to the original image size.

        Args:
            masks (mx.array): Batched masks from the mask_decoder,
                in BxHxWxC format.
            input_size (tuple(int, int)): The size of the image input to the
                model, in (H, W) format. Used to remove padding.
            original_size (tuple(int, int)): The original size of the image
                before resizing for input to the model, in (H, W) format.

        Returns:
            (mx.array): Batched masks in BxHxWxC format, where (H, W)
                is given by original_size.
        """
        scale_factor = (
            self.vision_encoder.img_size / masks.shape[1],
            self.vision_encoder.img_size / masks.shape[2],
        )
        masks = nn.Upsample(
            scale_factor=scale_factor, mode="linear", align_corners=False
        )(masks)
        masks = masks[:, : input_size[0], : input_size[1]]
        scale_factor = (
            original_size[0] / masks.shape[1],
            original_size[1] / masks.shape[2],
        )
        masks = nn.Upsample(
            scale_factor=scale_factor, mode="linear", align_corners=False
        )(masks)
        return masks

    def preprocess(self, x: mx.array) -> mx.array:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self._pixel_mean) / self._pixel_std

        # Pad
        h, w = x.shape[-3:-1]
        padh = self.vision_encoder.img_size - h
        padw = self.vision_encoder.img_size - w

        if x.ndim == 3:
            pad_width = [(0, padh), (0, padw), (0, 0)]
        elif x.ndim == 4:
            pad_width = [(0, 0), (0, padh), (0, padw), (0, 0)]
        else:
            raise ValueError("x.ndim can only be 3 or 4.")

        x = mx.pad(x, pad_width)
        return x
def load(model_path):
    model_path = Path(model_path)
    with open(model_path / "config.json", "r") as fid:
        config = json.load(fid)
    encoder_embed_dim = config["vision_config"]["hidden_size"]
    encoder_depth = config["vision_config"]["num_hidden_layers"]
    encoder_num_heads = config["vision_config"]["num_attention_heads"]
    encoder_global_attn_indexes = config["vision_config"]["global_attn_indexes"]
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        vision_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
    )
    sam.load_weights(str(model_path / "model.safetensors"), strict=True)
    return sam
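For reference, a hedged sketch of calling `load` directly after the README's `convert.py` step has produced a `sam-vit-base/` directory with `config.json` and `model.safetensors`:

```python
import mlx.core as mx
from segment_anything.sam import load

sam = load("sam-vit-base")

# Preprocess normalizes colors and pads a resized image up to the
# encoder's square input size (1024 x 1024).
x = mx.zeros((684, 1024, 3))
print(sam.preprocess(x).shape)  # (1024, 1024, 3)
```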
235
segment_anything/segment_anything/transformer.py
Normal file
@ -0,0 +1,235 @@
import math
from typing import Tuple, Type

import mlx.core as mx
import mlx.nn as nn

from .common import MLPBlock


class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
            depth (int): number of layers in the transformer
            embedding_dim (int): the channel dimension for the input embeddings
            num_heads (int): the number of heads for multihead attention. Must
                divide embedding_dim
            mlp_dim (int): the channel dimension internal to the MLP block
            activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = []

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.layer_norm_final_attn = nn.LayerNorm(embedding_dim)

    def __call__(
        self,
        image_embedding: mx.array,
        image_pe: mx.array,
        point_embedding: mx.array,
    ) -> Tuple[mx.array, mx.array]:
        """
        Args:
            image_embedding (mx.array): image to attend to. Should be shape
                B x h x w x embedding_dim for any h and w.
            image_pe (mx.array): the positional encoding to add to the image. Must
                have the same shape as image_embedding.
            point_embedding (mx.array): the embedding to add to the query points.
                Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
            mx.array: the processed point_embedding
            mx.array: the processed image_embedding
        """
        # BxHxWxC -> BxHWxC == B x N_image_tokens x C
        bs, h, w, c = image_embedding.shape
        image_embedding = image_embedding.reshape(bs, h * w, c)
        image_pe = image_pe.reshape(h * w, c)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding
        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.layer_norm_final_attn(queries)

        return queries, keys
class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Args:
            embedding_dim (int): the channel dimension of the embeddings
            num_heads (int): the number of heads in the attention layers
            mlp_dim (int): the hidden dimension of the mlp block
            activation (nn.Module): the activation of the mlp block
            skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.layer_norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.layer_norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.layer_norm3 = nn.LayerNorm(embedding_dim)

        self.layer_norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def __call__(
        self, queries: mx.array, keys: mx.array, query_pe: mx.array, key_pe: mx.array
    ) -> Tuple[mx.array, mx.array]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.layer_norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.layer_norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.layer_norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.layer_norm4(keys)

        return queries, keys
class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert (
            self.internal_dim % num_heads == 0
        ), "num_heads must divide embedding_dim // downsample_rate."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: mx.array, num_heads: int) -> mx.array:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(0, 2, 1, 3)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: mx.array) -> mx.array:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(0, 2, 1, 3)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def __call__(self, q: mx.array, k: mx.array, v: mx.array) -> mx.array:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention
        _, _, _, c_per_head = q.shape
        attn = q @ k.transpose(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = mx.softmax(attn, axis=-1)

        # Get output
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
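To see the two-way attention shapes end to end, a sketch with the decoder's default configuration; all inputs are random and for illustration only:

```python
import mlx.core as mx

transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)

image_embedding = mx.random.normal((1, 64, 64, 256))  # B x h x w x C
image_pe = mx.random.normal((64, 64, 256))            # same spatial shape
point_embedding = mx.random.normal((1, 7, 256))       # B x N_points x C

queries, keys = transformer(image_embedding, image_pe, point_embedding)
print(queries.shape)  # (1, 7, 256): refined point/token queries
print(keys.shape)     # (1, 4096, 256): refined flattened image tokens
```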
0
segment_anything/segment_anything/utils/__init__.py
Normal file
348
segment_anything/segment_anything/utils/amg.py
Normal file
@ -0,0 +1,348 @@
import math
from copy import deepcopy
from itertools import product
from typing import Any, Dict, Generator, ItemsView, List, Tuple

import mlx.core as mx
import numpy as np


class MaskData:
    """
    A structure for storing masks and their related data in batched format.
    Implements basic filtering and concatenation.
    """

    def __init__(self, **kwargs) -> None:
        for v in kwargs.values():
            assert isinstance(
                v, (list, np.ndarray, mx.array)
            ), "MaskData only supports list, numpy arrays, and mlx arrays."
        self._stats = dict(**kwargs)

    def __setitem__(self, key: str, item: Any) -> None:
        assert isinstance(
            item, (list, np.ndarray, mx.array)
        ), "MaskData only supports list, numpy arrays, and mlx arrays."
        self._stats[key] = item

    def __delitem__(self, key: str) -> None:
        del self._stats[key]

    def __getitem__(self, key: str) -> Any:
        return self._stats[key]

    def items(self) -> ItemsView[str, Any]:
        return self._stats.items()

    def filter(self, keep: mx.array) -> None:
        if keep.dtype == mx.bool_:
            keep = mx.array(np.where(keep)[0])
        for k, v in self._stats.items():
            if v is None:
                self._stats[k] = None
            elif isinstance(v, mx.array):
                self._stats[k] = v[keep]
            elif isinstance(v, np.ndarray):
                self._stats[k] = v[keep]
            elif isinstance(v, list):
                self._stats[k] = [v[i] for i in keep.tolist()]
            else:
                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")

    def cat(self, new_stats: "MaskData") -> None:
        for k, v in new_stats.items():
            if k not in self._stats or self._stats[k] is None:
                self._stats[k] = deepcopy(v)
            elif isinstance(v, mx.array):
                self._stats[k] = mx.concatenate([self._stats[k], v], axis=0)
            elif isinstance(v, np.ndarray):
                self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
            elif isinstance(v, list):
                self._stats[k] = self._stats[k] + deepcopy(v)
            else:
                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")

    def to_numpy(self) -> None:
        for k, v in self._stats.items():
            if isinstance(v, mx.array):
                self._stats[k] = np.array(v)


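# Illustrative usage (values assumed): all fields are filtered in lockstep,
# so row i of every stat refers to the same candidate mask.
data = MaskData(
    iou_preds=mx.array([0.9, 0.4, 0.8]),
    boxes=mx.array([[0, 0, 4, 4], [1, 1, 2, 2], [0, 2, 3, 5]]),
)
data.filter(data["iou_preds"] > 0.5)  # keeps rows 0 and 2 in every field
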
def is_box_near_crop_edge(
    boxes: mx.array, crop_box: List[int], orig_box: List[int], atol: float = 20.0
) -> mx.array:
    """Filter masks at the edge of a crop, but not at the edge of the original image."""
    crop_box_mlx = mx.array(crop_box, dtype=mx.float32)
    orig_box_mlx = mx.array(orig_box, dtype=mx.float32)
    boxes = uncrop_boxes_xyxy(boxes, crop_box).astype(mx.float32)
    near_crop_edge = mx.isclose(boxes, crop_box_mlx[None, :], atol=atol, rtol=0)
    near_image_edge = mx.isclose(boxes, orig_box_mlx[None, :], atol=atol, rtol=0)
    near_crop_edge = mx.logical_and(near_crop_edge, ~near_image_edge)
    return mx.any(near_crop_edge, axis=1)


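# Illustrative check (values assumed): a box flush with the crop's left edge,
# after uncropping via uncrop_boxes_xyxy below, but far from the image border
# is flagged:
#   is_box_near_crop_edge(mx.array([[0, 50, 80, 120]]),
#                         crop_box=[100, 0, 300, 200],
#                         orig_box=[0, 0, 1024, 768])  # -> array([True])
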
def box_xyxy_to_xywh(box_xyxy: mx.array) -> mx.array:
    box_xywh = deepcopy(box_xyxy)
    box_xywh[2] = box_xywh[2] - box_xywh[0]
    box_xywh[3] = box_xywh[3] - box_xywh[1]
    return box_xywh


def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
    assert len(args) > 0 and all(
        len(a) == len(args[0]) for a in args
    ), "Batched iteration must have inputs of all the same size."
    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
    for b in range(n_batches):
        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]


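# Illustrative usage (sizes assumed): aligned slices across all inputs, with a
# final short batch when the length is not a multiple of batch_size.
points, labels = mx.arange(10).reshape(10, 1), mx.ones(10)
batch_sizes = [len(p) for p, _ in batch_iterator(4, points, labels)]
assert batch_sizes == [4, 4, 2]
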
def mask_to_rle_mlx(tensor: mx.array) -> List[Dict[str, Any]]:
    """
    Encodes masks to an uncompressed RLE, in the format expected by
    pycoco tools.
    """
    # Put in fortran order and flatten h,w
    b, h, w = tensor.shape
    tensor = mx.transpose(tensor, axes=(0, 2, 1)).flatten(1)

    # Compute change indices
    diff = mx.bitwise_xor(tensor[:, 1:], tensor[:, :-1])
    # TODO: do this in mlx directly once a nonzero op is available
    change_indices = np.stack(np.array(diff).nonzero(), axis=1)

    # Encode run length
    out = []
    for i in range(b):
        cur_idxs = change_indices[change_indices[:, 0] == i, 1]
        cur_idxs = mx.array(cur_idxs)
        cur_idxs = mx.concatenate(
            [
                mx.array([0], dtype=cur_idxs.dtype),
                cur_idxs + 1,
                mx.array([h * w], dtype=cur_idxs.dtype),
            ]
        )
        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
        counts = [] if tensor[i, 0] == 0 else [0]
        counts.extend(btw_idxs.tolist())
        out.append({"size": [h, w], "counts": counts})
    return out


def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
    """Compute a binary mask from an uncompressed RLE."""
    h, w = rle["size"]
    mask = np.empty(h * w, dtype=bool)
    idx = 0
    parity = False
    for count in rle["counts"]:
        mask[idx : idx + count] = parity
        idx += count
        parity ^= True
    mask = mask.reshape(w, h)
    return mask.transpose()  # Put in C order


def area_from_rle(rle: Dict[str, Any]) -> int:
    return sum(rle["counts"][1::2])


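# Illustrative round trip (shapes assumed): encode a batch of binary masks to
# uncompressed RLE, then decode one back and check its area.
masks = mx.zeros((1, 4, 4), dtype=mx.bool_)
masks[0, 1:3, 1:3] = True
rles = mask_to_rle_mlx(masks)
assert rle_to_mask(rles[0]).sum() == area_from_rle(rles[0]) == 4
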
def calculate_stability_score(
    masks: mx.array, mask_threshold: float, threshold_offset: float
) -> mx.array:
    """
    Computes the stability score for a batch of masks. The stability
    score is the IoU between the binary masks obtained by thresholding
    the predicted mask logits at high and low values.
    """
    # One mask is always contained inside the other.
    # Save memory by preventing unnecessary cast to mx.int64
    intersections = (
        (masks > (mask_threshold + threshold_offset))
        .astype(mx.int16)
        .sum(-1)
        .astype(mx.int32)
        .sum(-1)
    )
    unions = (
        (masks > (mask_threshold - threshold_offset))
        .astype(mx.int16)
        .sum(-1)
        .astype(mx.int32)
        .sum(-1)
    )
    return intersections / unions


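# Illustrative check (logits assumed): when every logit is far from the
# threshold, the high- and low-threshold masks agree and the score is 1.0.
logits = mx.array([[[9.0, 9.0], [-9.0, 9.0]]])  # 1 x 2 x 2 mask logits
score = calculate_stability_score(logits, mask_threshold=0.0, threshold_offset=1.0)
assert score.item() == 1.0
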
def build_point_grid(n_per_side: int) -> np.ndarray:
    """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
    offset = 1 / (2 * n_per_side)
    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
    points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
    return points


def build_all_layer_point_grids(
    n_per_side: int, n_layers: int, scale_per_layer: int
) -> List[mx.array]:
    """Generates point grids for all crop layers."""
    points_by_layer = []
    for i in range(n_layers + 1):
        n_points = int(n_per_side / (scale_per_layer**i))
        points_by_layer.append(mx.array(build_point_grid(n_points)))
    return points_by_layer


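# Illustrative usage (parameters assumed): the point count halves per side at
# each crop layer, e.g. 32x32 points on the full image, 16x16 on layer-1 crops.
grids = build_all_layer_point_grids(n_per_side=32, n_layers=1, scale_per_layer=2)
assert tuple(grids[0].shape) == (1024, 2) and tuple(grids[1].shape) == (256, 2)
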
def generate_crop_boxes(
    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
    """
    Generates a list of crop boxes of different sizes. Each layer
    has (2**i)**2 boxes for the ith layer.
    """
    crop_boxes, layer_idxs = [], []
    im_h, im_w = im_size
    short_side = min(im_h, im_w)

    # Original image
    crop_boxes.append([0, 0, im_w, im_h])
    layer_idxs.append(0)

    def crop_len(orig_len, n_crops, overlap):
        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))

    for i_layer in range(n_layers):
        n_crops_per_side = 2 ** (i_layer + 1)
        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))

        crop_w = crop_len(im_w, n_crops_per_side, overlap)
        crop_h = crop_len(im_h, n_crops_per_side, overlap)

        crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
        crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]

        # Crops in XYXY format
        for x0, y0 in product(crop_box_x0, crop_box_y0):
            box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
            crop_boxes.append(box)
            layer_idxs.append(i_layer + 1)

    return crop_boxes, layer_idxs


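# Illustrative usage (image size assumed): box 0 is always the full image,
# then each layer i adds (2**i)**2 overlapping crops.
crop_boxes, layer_idxs = generate_crop_boxes((400, 600), n_layers=1, overlap_ratio=0.25)
assert crop_boxes[0] == [0, 0, 600, 400] and layer_idxs == [0, 1, 1, 1, 1]
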
def uncrop_boxes_xyxy(boxes: mx.array, crop_box: List[int]) -> mx.array:
    x0, y0, _, _ = crop_box
    offset = mx.array([[x0, y0, x0, y0]])
    # Check if boxes has a channel dimension
    if len(boxes.shape) == 3:
        offset = mx.expand_dims(offset, 1)
    return boxes + offset


def uncrop_points(points: mx.array, crop_box: List[int]) -> mx.array:
    x0, y0, _, _ = crop_box
    offset = mx.array([[x0, y0]])
    # Check if points has a channel dimension
    if len(points.shape) == 3:
        offset = mx.expand_dims(offset, 1)
    return points + offset


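# Illustrative usage (values assumed): uncropping shifts crop-relative
# coordinates back into original-image space.
crop_box = [100, 50, 400, 300]
uncropped = uncrop_boxes_xyxy(mx.array([[10, 10, 60, 80]]), crop_box)
assert uncropped.tolist() == [[110, 60, 160, 130]]
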
def uncrop_masks(
    masks: mx.array, crop_box: List[int], orig_h: int, orig_w: int
) -> mx.array:
    x0, y0, x1, y1 = crop_box
    if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
        return masks
    # Coordinate transform masks
    pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
    pad = [(0, 0), (y0, pad_y - y0), (x0, pad_x - x0)]
    return mx.pad(masks, pad, 0)


def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask. Returns the
    mask and an indicator of whether the mask has been modified.
    """
    import cv2  # type: ignore

    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if len(small_regions) == 0:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep largest
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True


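# Illustrative usage ("holes" mode, values assumed; requires opencv-python):
# a 1-pixel hole inside a solid 5x5 blob is below area_thresh and gets filled.
blob = np.ones((5, 5), dtype=bool)
blob[2, 2] = False
filled, changed = remove_small_regions(blob, area_thresh=4, mode="holes")
assert changed and filled.all()
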
def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
    from pycocotools import mask as mask_utils  # type: ignore

    h, w = uncompressed_rle["size"]
    rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
    rle["counts"] = rle["counts"].decode("utf-8")  # Necessary to serialize with json
    return rle


def batched_mask_to_box(masks: mx.array) -> mx.array:
    """
    Calculates boxes in XYXY format around masks. Returns [0,0,0,0] for
    an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
    """
    # mx.max below raises an error on empty inputs, just skip in this case
    if np.prod(masks.shape) == 0:
        return mx.zeros((*masks.shape[:-2], 4))

    # Normalize shape to CxHxW
    shape = masks.shape
    h, w = shape[-2:]
    if len(shape) > 2:
        masks = masks.flatten(0, -3)
    else:
        masks = masks[None]

    # Get top and bottom edges
    in_height = mx.max(masks, axis=-1)
    in_height_coords = in_height * mx.arange(h)[None, :]
    bottom_edges = mx.max(in_height_coords, axis=-1)
    in_height_coords = in_height_coords + h * (~in_height)
    top_edges = mx.min(in_height_coords, axis=-1)

    # Get left and right edges
    in_width = mx.max(masks, axis=-2)
    in_width_coords = in_width * mx.arange(w)[None, :]
    right_edges = mx.max(in_width_coords, axis=-1)
    in_width_coords = in_width_coords + w * (~in_width)
    left_edges = mx.min(in_width_coords, axis=-1)

    # If the mask is empty the right edge will be to the left of the left edge.
    # Replace these boxes with [0, 0, 0, 0]
    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
    out = mx.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1)
    out = out * (~empty_filter)[..., None]

    # Return to original shape
    if len(shape) > 2:
        out = out.reshape(*shape[:-2], 4)
    else:
        out = out[0]

    return out
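# Illustrative usage (shapes assumed): XYXY boxes around the True pixels of
# each mask; an all-False mask maps to [0, 0, 0, 0].
m = mx.zeros((2, 8, 8), dtype=mx.bool_)
m[0, 2:5, 3:7] = True
assert batched_mask_to_box(m).tolist() == [[3, 2, 6, 4], [0, 0, 0, 0]]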
65
segment_anything/segment_anything/utils/transforms.py
Normal file
@ -0,0 +1,65 @@
from typing import Tuple

import mlx.core as mx
import numpy as np
from PIL import Image


class ResizeLongestSide:
    """
    Resizes images to the longest side 'target_length', as well as provides
    methods for resizing coordinates and boxes. Provides methods for
    transforming both numpy arrays and batched mlx arrays.
    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """
        Expects a numpy array with shape HxWxC in uint8 format.
        """
        target_size = self.get_preprocess_shape(
            image.shape[0], image.shape[1], self.target_length
        )
        return np.array(
            Image.fromarray(image).resize(
                target_size[::-1], resample=Image.Resampling.BILINEAR
            )
        )

    def apply_coords(
        self, coords: mx.array, original_size: Tuple[int, ...]
    ) -> mx.array:
        """
        Expects an mlx array with length 2 in the last dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        return coords * mx.array([new_w / old_w, new_h / old_h])

    def apply_boxes(self, boxes: mx.array, original_size: Tuple[int, ...]) -> mx.array:
        """
        Expects an mlx array with shape ...x4. Requires the original image
        size in (H, W) format.
        """
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    @staticmethod
    def get_preprocess_shape(
        oldh: int, oldw: int, long_side_length: int
    ) -> Tuple[int, int]:
        """
        Compute the output size given input size and target long side length.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)
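# Illustrative usage (sizes assumed): a 480x640 image scales by 1024 / 640.
transform = ResizeLongestSide(1024)
assert transform.get_preprocess_shape(480, 640, 1024) == (768, 1024)
coords = transform.apply_coords(mx.array([[320.0, 240.0]]), (480, 640))
assert coords.tolist() == [[512.0, 384.0]]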