mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	ExpandDims primitive (#1687)
				
					
				
			* add squeeze primitive * simplify squeeze, use in gather * fix * fix * fix * fix * fix no cpu * use squeeze in matmul and friends * expand dims primitive * comment
This commit is contained in:
		| @@ -144,23 +144,23 @@ array mlx_gather_nd( | ||||
|     int slice_index = 0; | ||||
|     for (int i = 0; i < gather_indices.size(); i++) { | ||||
|       if (is_slice[i]) { | ||||
|         std::vector<int> index_shape(max_dims + num_slices, 1); | ||||
|         Shape index_shape(max_dims + num_slices, 1); | ||||
|         index_shape[max_dims + slice_index] = gather_indices[i].shape(0); | ||||
|         gather_indices[i] = reshape(gather_indices[i], index_shape); | ||||
|         gather_indices[i] = reshape(gather_indices[i], std::move(index_shape)); | ||||
|         slice_index++; | ||||
|       } else { | ||||
|         std::vector<int> index_shape = gather_indices[i].shape(); | ||||
|         auto index_shape = gather_indices[i].shape(); | ||||
|         index_shape.insert(index_shape.end(), num_slices, 1); | ||||
|         gather_indices[i] = reshape(gather_indices[i], index_shape); | ||||
|         gather_indices[i] = reshape(gather_indices[i], std::move(index_shape)); | ||||
|       } | ||||
|     } | ||||
|   } else { | ||||
|     // reshape them so that the int/array indices are last | ||||
|     for (int i = 0; i < gather_indices.size(); i++) { | ||||
|       if (i < num_slices) { | ||||
|         std::vector<int> index_shape(max_dims + num_slices, 1); | ||||
|         Shape index_shape(max_dims + num_slices, 1); | ||||
|         index_shape[i] = gather_indices[i].shape(0); | ||||
|         gather_indices[i] = reshape(gather_indices[i], index_shape); | ||||
|         gather_indices[i] = reshape(gather_indices[i], std::move(index_shape)); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| @@ -172,19 +172,11 @@ array mlx_gather_nd( | ||||
|   std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1); | ||||
|   src = gather(src, gather_indices, axes, slice_sizes); | ||||
|  | ||||
|   // Squeeze the dims | ||||
|   std::vector<int> out_shape; | ||||
|   out_shape.insert( | ||||
|       out_shape.end(), | ||||
|       src.shape().begin(), | ||||
|       src.shape().begin() + max_dims + num_slices); | ||||
|   out_shape.insert( | ||||
|       out_shape.end(), | ||||
|       src.shape().begin() + max_dims + num_slices + indices.size(), | ||||
|       src.shape().end()); | ||||
|   src = reshape(src, out_shape); | ||||
|  | ||||
|   return src; | ||||
|   // Squeeze the array index dims | ||||
|   for (auto& ax : axes) { | ||||
|     ax += max_dims + num_slices; | ||||
|   } | ||||
|   return squeeze(src, axes); | ||||
| } | ||||
|  | ||||
| auto mlx_expand_ellipsis( | ||||
|   | ||||
| @@ -392,27 +392,6 @@ class TestCompile(mlx_tests.MLXTestCase): | ||||
|         out = fun(x, y=y, z=z) | ||||
|         self.assertEqual(out.item(), 6) | ||||
|  | ||||
|     def test_shapeless_compile(self): | ||||
|         y = 1 | ||||
|  | ||||
|         @partial(mx.compile, shapeless=True) | ||||
|         def fun(x): | ||||
|             return x + y | ||||
|  | ||||
|         x = mx.array([1, 2]) | ||||
|         self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3]))) | ||||
|  | ||||
|         # The function is not recompiled, so the change | ||||
|         # to y should not be reflected in the output | ||||
|         y = 2 | ||||
|         x = mx.array([1, 2, 3]) | ||||
|         self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4]))) | ||||
|  | ||||
|         # Type change recompiles | ||||
|         x = mx.array([1.0, 2.0, 3.0]) | ||||
|         self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0]))) | ||||
|         fun(x, y=y, z=z) | ||||
|  | ||||
|     def test_shapeless_compile(self): | ||||
|         y = 1 | ||||
|  | ||||
| @@ -477,6 +456,12 @@ class TestCompile(mlx_tests.MLXTestCase): | ||||
|         mx.eval(cfun(x1)) | ||||
|         self.assertTrue(mx.array_equal(fun(x2), cfun(x2))) | ||||
|  | ||||
|         def fun(x): | ||||
|             return x * x.sum(-1, keepdims=False) | ||||
|  | ||||
|         cfun = mx.compile(fun, shapeless=True) | ||||
|         self.assertTrue(mx.array_equal(fun(x2), cfun(x2))) | ||||
|  | ||||
|     def test_compile_with_constant(self): | ||||
|         # Test float | ||||
|         @partial(mx.compile) | ||||
| @@ -809,6 +794,13 @@ class TestCompile(mlx_tests.MLXTestCase): | ||||
|         out = fun(*inputs) | ||||
|         self.assertTrue(mx.allclose(out, mx.full((2, 2), 20))) | ||||
|  | ||||
|     def test_shapeless_compile_matmul(self): | ||||
|         a = mx.array([0.0, 1.0, 2.0]) | ||||
|         b = mx.array([0.0, 1.0, 2.0]) | ||||
|  | ||||
|         fun = mx.compile(lambda a, b: a @ b, shapeless=True) | ||||
|         self.assertTrue(mx.allclose(fun(a, b), a @ b)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun