mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	| @@ -120,6 +120,11 @@ array mlx_gather_nd( | ||||
|     if (py::isinstance<py::slice>(idx)) { | ||||
|       int start, end, stride; | ||||
|       get_slice_params(start, end, stride, idx, src.shape(i)); | ||||
|  | ||||
|       // Handle negative indices | ||||
|       start = (start < 0) ? start + src.shape(i) : start; | ||||
|       end = (end < 0) ? end + src.shape(i) : end; | ||||
|  | ||||
|       gather_indices.push_back(arange(start, end, stride, uint32)); | ||||
|       num_slices++; | ||||
|       is_slice[i] = true; | ||||
|   | ||||
| @@ -727,6 +727,11 @@ class TestArray(mlx_tests.MLXTestCase): | ||||
|             np.array_equal(a_np[idx_np, idx_np], np.array(a_mlx[idx_mlx, idx_mlx])) | ||||
|         ) | ||||
|  | ||||
|         # Slicing with negative indices and integer | ||||
|         a_np = np.arange(10).reshape(5, 2) | ||||
|         a_mlx = mx.array(a_np) | ||||
|         self.assertTrue(np.array_equal(a_np[2:-1, 0], np.array(a_mlx[2:-1, 0]))) | ||||
|  | ||||
|     def test_setitem(self): | ||||
|         a = mx.array(0) | ||||
|         a[None] = 1 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun