mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-10 13:07:28 +08:00
Segment Anything Model (#552)
* add segment anything model * add readme * reorg file structure * update * lint * minor updates * ack * fix weight loading * simplify * fix to run notebooks * amg in mlx * remove torch dependency * nit in README * return indices in nms * simplify * bugfix / simplify * fix bug' * simplify * fix notebook and remove output * couple more nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
35
segment_anything/segment_anything/common.py
Normal file
35
segment_anything/segment_anything/common.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import Type
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class MLPBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
mlp_dim: int,
|
||||
act: Type[nn.Module] = nn.GELU,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
|
||||
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
|
||||
self.act = act()
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.lin2(self.act(self.lin1(x)))
|
||||
|
||||
|
||||
class LayerNorm2d(nn.Module):
|
||||
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.weight = mx.ones(num_channels)
|
||||
self.bias = mx.zeros(num_channels)
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
u = x.mean(3, keepdims=True)
|
||||
s = ((x - u) ** 2).mean(3, keepdims=True)
|
||||
x = (x - u) / mx.sqrt(s + self.eps)
|
||||
x = self.weight * x + self.bias
|
||||
return x
|
Reference in New Issue
Block a user