mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 07:18:29 +08:00
improvements to scatter / gather (#1541)
This commit is contained in:
@@ -25,7 +25,7 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
|
||||
|
||||
|
||||
def _nearest_indices(N, scale, dim, ndims):
|
||||
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.int32)
|
||||
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32)
|
||||
|
||||
|
||||
def _linear_indices(N, scale, align_corners, dim, ndims):
|
||||
@@ -37,8 +37,8 @@ def _linear_indices(N, scale, align_corners, dim, ndims):
|
||||
weight = mx.expand_dims(weight, -1)
|
||||
|
||||
return (
|
||||
(indices_l.astype(mx.int32), 1 - weight),
|
||||
(indices_r.astype(mx.int32), weight),
|
||||
(indices_l.astype(mx.uint32), 1 - weight),
|
||||
(indices_r.astype(mx.uint32), weight),
|
||||
)
|
||||
|
||||
|
||||
@@ -73,10 +73,10 @@ def _cubic_indices(N, scale, align_corners, dim, ndims):
|
||||
indices_r2 = mx.clip(indices_r2, a_min=0, a_max=N - 1)
|
||||
|
||||
return (
|
||||
(indices_l1.astype(mx.int32), weight_l1),
|
||||
(indices_r1.astype(mx.int32), weight_r1),
|
||||
(indices_l2.astype(mx.int32), weight_l2),
|
||||
(indices_r2.astype(mx.int32), weight_r2),
|
||||
(indices_l1.astype(mx.uint32), weight_l1),
|
||||
(indices_r1.astype(mx.uint32), weight_r1),
|
||||
(indices_l2.astype(mx.uint32), weight_l2),
|
||||
(indices_r2.astype(mx.uint32), weight_r2),
|
||||
)
|
||||
|
||||
|
||||
|
@@ -1089,12 +1089,14 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
a_mlx = mx.array(a_np)
|
||||
|
||||
if ax == None:
|
||||
idx_np = np.random.randint(low=0, high=a_np.size, size=(16,))
|
||||
idx_np = np.random.permutation(a_np.size)
|
||||
values_np = np.random.randint(low=0, high=100, size=(16,))
|
||||
else:
|
||||
shape = list(a_np.shape)
|
||||
shape[ax] = 2
|
||||
idx_np = np.random.randint(low=0, high=a_np.shape[ax], size=shape)
|
||||
idx_np = np.random.choice(a_np.shape[ax], replace=False, size=(2,))
|
||||
idx_np = np.expand_dims(idx_np, list(range(1, 2 - ax + 1)))
|
||||
idx_np = np.broadcast_to(idx_np, shape)
|
||||
values_np = np.random.randint(low=0, high=100, size=shape)
|
||||
|
||||
idx_np.astype(np.int32)
|
||||
|
Reference in New Issue
Block a user