improvements to scatter / gather (#1541)

This commit is contained in:
Awni Hannun
2024-10-30 19:30:54 -07:00
committed by GitHub
parent 960e3f0f05
commit 4f72c66911
9 changed files with 195 additions and 248 deletions

View File

@@ -25,7 +25,7 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
def _nearest_indices(N, scale, dim, ndims):
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.int32)
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32)
def _linear_indices(N, scale, align_corners, dim, ndims):
@@ -37,8 +37,8 @@ def _linear_indices(N, scale, align_corners, dim, ndims):
weight = mx.expand_dims(weight, -1)
return (
(indices_l.astype(mx.int32), 1 - weight),
(indices_r.astype(mx.int32), weight),
(indices_l.astype(mx.uint32), 1 - weight),
(indices_r.astype(mx.uint32), weight),
)
@@ -73,10 +73,10 @@ def _cubic_indices(N, scale, align_corners, dim, ndims):
indices_r2 = mx.clip(indices_r2, a_min=0, a_max=N - 1)
return (
(indices_l1.astype(mx.int32), weight_l1),
(indices_r1.astype(mx.int32), weight_r1),
(indices_l2.astype(mx.int32), weight_l2),
(indices_r2.astype(mx.int32), weight_r2),
(indices_l1.astype(mx.uint32), weight_l1),
(indices_r1.astype(mx.uint32), weight_r1),
(indices_l2.astype(mx.uint32), weight_l2),
(indices_r2.astype(mx.uint32), weight_r2),
)