Segment Anything Model (#552)

* add segment anything model

* add readme

* reorg file structure

* update

* lint

* minor updates

* ack

* fix weight loading

* simplify

* fix to run notebooks

* amg in mlx

* remove torch dependency

* nit in README

* return indices in nms

* simplify

* bugfix / simplify

* fix bug'

* simplify

* fix notebook and remove output

* couple more nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Shiyu
2024-06-03 07:45:51 +08:00
committed by GitHub
parent 89b0b75250
commit 8353bbbf93
22 changed files with 3667 additions and 0 deletions

View File

@@ -0,0 +1,257 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Automatically generating object masks with SAM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook walks through how to automatically segment objects in an image. It is modified from [original SAM GitHub repo](https://github.com/facebookresearch/segment-anything/).\n",
"\n",
"Since SAM can efficiently process prompts, masks for the entire image can be generated by sampling a large number of prompts over an image. This method was used to generate the dataset SA-1B. \n",
"\n",
"The class `SamAutomaticMaskGenerator` implements this. It samples single-point input prompts in a grid over the image, from each of which SAM then predicts multiple masks. The masks are filtered for quality and deduplicated using non-max suppression. Additional options allow for further improvement of mask quality and quantity, such as running prediction on multiple crops of the image or postprocessing masks to remove small disconnected regions and holes."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set-up"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import cv2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def show_anns(anns):\n",
" if len(anns) == 0:\n",
" return\n",
" sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)\n",
" ax = plt.gca()\n",
" ax.set_autoscale_on(False)\n",
"\n",
" img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))\n",
" img[:,:,3] = 0\n",
" for ann in sorted_anns:\n",
" m = ann['segmentation']\n",
" color_mask = np.concatenate([np.random.random(3), [0.35]])\n",
" img[m] = color_mask\n",
" ax.imshow(img)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image = cv2.imread('images/dog.jpg')\n",
"image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(20,20))\n",
"plt.imshow(image)\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Automatic mask generation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To run automatic mask generation, provide a SAM model to the `SamAutomaticMaskGenerator` class. Set the path below to the SAM checkpoint."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append(\"..\")\n",
"from segment_anything import SamAutomaticMaskGenerator\n",
"from segment_anything.sam import load\n",
"\n",
"sam_checkpoint = \"../sam-vit-base\"\n",
"sam = load(sam_checkpoint)\n",
"\n",
"mask_generator = SamAutomaticMaskGenerator(sam)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To generate masks, run `generate` on an image."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"masks = mask_generator.generate(image)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Mask generation returns a list over masks. Each item is a dictionary with keys:\n",
"* `segmentation` : the mask\n",
"* `area` : the area of the mask in pixels\n",
"* `bbox` : the boundary box of the mask in XYWH format\n",
"* `predicted_iou` : the model's own prediction for the quality of the mask\n",
"* `point_coords` : the sampled input point that generated this mask\n",
"* `stability_score` : an additional measure of mask quality\n",
"* `crop_box` : the crop of the image used to generate this mask in XYWH format"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(len(masks))\n",
"print(masks[0].keys())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Show all the masks overlayed on the image."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(20,20))\n",
"plt.imshow(image)\n",
"show_anns(masks)\n",
"plt.axis('off')\n",
"plt.show() "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Automatic mask generation options"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are several tunable parameters in automatic mask generation that control how densely points are sampled and what the thresholds are for removing low quality or duplicate masks. Generation can be automatically run on crops of the image to get better results for smaller objects. Post-processing can remove stray pixels and holes. Here is an example configuration that samples more masks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mask_generator_2 = SamAutomaticMaskGenerator(\n",
" model=sam,\n",
" points_per_side=32,\n",
" pred_iou_thresh=0.86,\n",
" stability_score_thresh=0.92,\n",
" crop_n_layers=1,\n",
" crop_n_points_downscale_factor=2,\n",
" min_mask_region_area=100, # Requires open-cv to run post-processing\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"masks2 = mask_generator_2.generate(image)\n",
"len(masks2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(20,20))\n",
"plt.imshow(image)\n",
"show_anns(masks2)\n",
"plt.axis('off')\n",
"plt.show() "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 98 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 164 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 265 KiB

View File

@@ -0,0 +1,629 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Segmenting from Prompts\n",
"\n",
"This notebook walks through predicting object segmentations from a provided prompt. It uses the `Predictor` class. It is modified from [original SAM GitHub repo](https://github.com/facebookresearch/segment-anything/).\n",
"\n",
"### Setup\n",
"Necessary imports and helper functions for displaying points, boxes, and masks."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"import matplotlib.pyplot as plt\n",
"import mlx.core as mx\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def show_mask(mask, ax, random_color=False):\n",
" if random_color:\n",
" color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n",
" else:\n",
" color = np.array([30/255, 144/255, 255/255, 0.6])\n",
" h, w = mask.shape[:2]\n",
" mask_image = np.array(mask).reshape(h, w, 1) * color.reshape(1, 1, -1)\n",
" ax.imshow(mask_image)\n",
" \n",
"def show_points(coords, labels, ax, marker_size=375):\n",
" pos_points = np.array(coords)[labels==1]\n",
" neg_points = np.array(coords)[labels==0]\n",
" ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
" ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) \n",
" \n",
"def show_box(box, ax):\n",
" box = box.tolist()\n",
" x0, y0 = box[0], box[1]\n",
" w, h = box[2] - box[0], box[3] - box[1]\n",
" ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image = cv2.imread('images/truck.jpg')\n",
"image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(10,10))\n",
"plt.imshow(image)\n",
"plt.axis('on')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Selecting objects with SAM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append(\"..\")\n",
"from segment_anything.sam import load\n",
"from segment_anything.predictor import SamPredictor\n",
"\n",
"sam_checkpoint = \"../sam-vit-base\"\n",
"sam = load(sam_checkpoint)\n",
"predictor = SamPredictor(sam)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Process the image to produce an image embedding by calling `SamPredictor.set_image`. `SamPredictor` remembers this embedding and will use it for subsequent mask prediction."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"predictor.set_image(image)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To select the truck, choose a point on it. Points are input to the model in (x,y) format and come with labels 1 (foreground point) or 0 (background point). Multiple points can be input; here we use only one. The chosen point will be shown as a star on the image."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"input_point = mx.array([[500, 375]])\n",
"input_label = mx.array([1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(10,10))\n",
"plt.imshow(image)\n",
"show_points(input_point, input_label, plt.gca())\n",
"plt.axis('on')\n",
"plt.show() "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Predict with `SamPredictor.predict`. The model returns masks, quality predictions for those masks, and low resolution mask logits that can be passed to the next iteration of prediction."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"masks, scores, logits = predictor.predict(\n",
" point_coords=input_point[None],\n",
" point_labels=input_label[None],\n",
" multimask_output=True,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `multimask_output=True` (the default setting), SAM outputs 3 masks, where `scores` gives the model's own estimation of the quality of these masks. This setting is intended for ambiguous input prompts, and helps the model disambiguate different objects consistent with the prompt. When `False`, it will return a single mask. For ambiguous prompts such as a single point, it is recommended to use `multimask_output=True` even if only a single mask is desired; the best single mask can be chosen by picking the one with the highest score returned in `scores`. This will often result in a better mask."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for i in range(masks.shape[-1]):\n",
" mask = masks[..., i]\n",
" score = scores[..., i].item()\n",
" plt.figure(figsize=(10,10))\n",
" plt.imshow(image)\n",
" show_mask(mask[0], plt.gca())\n",
" show_points(input_point, input_label, plt.gca())\n",
" plt.title(f\"Mask {i+1}, Score: {score:.3f}\", fontsize=18)\n",
" plt.axis('off')\n",
" plt.show() "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Specifying a specific object with additional points"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The single input point is ambiguous, and the model has returned multiple objects consistent with it. To obtain a single object, multiple points can be provided. If available, a mask from a previous iteration can also be supplied to the model to aid in prediction. When specifying a single object with multiple prompts, a single mask can be requested by setting `multimask_output=False`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"input_point = mx.array([[500, 375], [1125, 625]])\n",
"input_label = mx.array([1, 1])\n",
"mask_input = logits[..., mx.argmax(scores)] # Choose the model's best mask"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"masks, _, _ = predictor.predict(\n",
" point_coords=input_point[None],\n",
" point_labels=input_label[None],\n",
" mask_input=mask_input[..., None],\n",
" multimask_output=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(10,10))\n",
"plt.imshow(image)\n",
"show_mask(masks[0], plt.gca())\n",
"show_points(input_point, input_label, plt.gca())\n",
"plt.axis('off')\n",
"plt.show() "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To exclude the car and specify just the window, a background point (with label 0, here shown in red) can be supplied."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"input_point = mx.array([[500, 375], [1125, 625]])\n",
"input_label = mx.array([1, 0])\n",
"mask_input = logits[..., mx.argmax(scores)] # Choose the model's best mask"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"masks, _, _ = predictor.predict(\n",
" point_coords=input_point[None],\n",
" point_labels=input_label[None],\n",
" mask_input=mask_input[..., None],\n",
" multimask_output=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(image)\n",
"show_mask(masks[0], plt.gca())\n",
"show_points(input_point, input_label, plt.gca())\n",
"plt.axis('off')\n",
"plt.show() "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Specifying a specific object with a box"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model can also take a box as input, provided in xyxy format."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"input_box = mx.array([425, 600, 700, 875])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"masks, _, _ = predictor.predict(\n",
" point_coords=None,\n",
" point_labels=None,\n",
" box=input_box[None, :],\n",
" multimask_output=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(image)\n",
"show_mask(masks[0, ..., 0], plt.gca())\n",
"show_box(input_box, plt.gca())\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Combining points and boxes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Points and boxes may be combined, just by including both types of prompts to the predictor. Here this can be used to select just the trucks's tire, instead of the entire wheel."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"input_box = mx.array([425, 600, 700, 875])\n",
"input_point = mx.array([[575, 750]])\n",
"input_label = mx.array([0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"masks, _, _ = predictor.predict(\n",
" point_coords=input_point[None],\n",
" point_labels=input_label[None],\n",
" box=input_box,\n",
" multimask_output=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(image)\n",
"show_mask(masks[0, ..., 0], plt.gca())\n",
"show_box(input_box, plt.gca())\n",
"show_points(input_point, input_label, plt.gca())\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Batched prompt inputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`SamPredictor` can take multiple input prompts for the same image."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"input_boxes = mx.array([\n",
" [75, 275, 1725, 850],\n",
" [425, 600, 700, 875],\n",
" [1375, 550, 1650, 800],\n",
" [1240, 675, 1400, 750],\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"masks, _, _ = predictor.predict(\n",
" point_coords=None,\n",
" point_labels=None,\n",
" box=input_boxes,\n",
" multimask_output=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(image)\n",
"for mask in masks:\n",
" show_mask(mask, plt.gca(), random_color=True)\n",
"for box in input_boxes:\n",
" show_box(box, plt.gca())\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## End-to-end batched inference"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If all prompts are available in advance, it is possible to run SAM directly in an end-to-end fashion. This also allows batching over images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image1 = image # truck.jpg from above\n",
"image1_boxes = mx.array([\n",
" [75, 275, 1725, 850],\n",
" [425, 600, 700, 875],\n",
" [1375, 550, 1650, 800],\n",
" [1240, 675, 1400, 750],\n",
"])\n",
"\n",
"image2 = cv2.imread('images/groceries.jpg')\n",
"image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)\n",
"image2_boxes = mx.array([\n",
" [450, 170, 520, 350],\n",
" [350, 190, 450, 350],\n",
" [500, 170, 580, 350],\n",
" [580, 170, 640, 350],\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both images and prompts are input as mlx array that are already transformed to the correct frame. Inputs are packaged as a list over images, which each element is a dict that takes the following keys:\n",
"* `image`: The input image as a mlx array in HWC format.\n",
"* `original_size`: The size of the image before transforming for input to SAM, in (H, W) format.\n",
"* `point_coords`: Batched coordinates of point prompts.\n",
"* `point_labels`: Batched labels of point prompts.\n",
"* `boxes`: Batched input boxes.\n",
"* `mask_inputs`: Batched input masks.\n",
"\n",
"If a prompt is not present, the key can be excluded."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from segment_anything.utils.transforms import ResizeLongestSide\n",
"resize_transform = ResizeLongestSide(sam.vision_encoder.img_size)\n",
"\n",
"def prepare_image(image, transform, device):\n",
" image = transform.apply_image(image)\n",
" image = mx.array(image)\n",
" return image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batched_input = [\n",
" {\n",
" 'image': prepare_image(image1, resize_transform, sam),\n",
" 'boxes': resize_transform.apply_boxes(image1_boxes, image1.shape[:2]),\n",
" 'original_size': image1.shape[:2]\n",
" },\n",
" {\n",
" 'image': prepare_image(image2, resize_transform, sam),\n",
" 'boxes': resize_transform.apply_boxes(image2_boxes, image2.shape[:2]),\n",
" 'original_size': image2.shape[:2]\n",
" }\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batched_output = sam(batched_input, multimask_output=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The output is a list over results for each input image, where list elements are dictionaries with the following keys:\n",
"* `masks`: A batched mlx array of predicted binary masks, the size of the original image.\n",
"* `iou_predictions`: The model's prediction of the quality for each mask.\n",
"* `low_res_logits`: Low res logits for each mask, which can be passed back to the model as mask input on a later iteration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(1, 2, figsize=(20, 20))\n",
"\n",
"ax[0].imshow(image1)\n",
"for mask in batched_output[0]['masks']:\n",
" show_mask(np.array(mask), ax[0], random_color=True)\n",
"for box in image1_boxes:\n",
" show_box(np.array(box), ax[0])\n",
"ax[0].axis('off')\n",
"\n",
"ax[1].imshow(image2)\n",
"for mask in batched_output[1]['masks']:\n",
" show_mask(np.array(mask), ax[1], random_color=True)\n",
"for box in image2_boxes:\n",
" show_box(np.array(box), ax[1])\n",
"ax[1].axis('off')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
}
},
"nbformat": 4,
"nbformat_minor": 2
}