mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	fix export to work with gather/scatter axis (#2263)
This commit is contained in:
		| @@ -286,6 +286,32 @@ class TestExportImport(mlx_tests.MLXTestCase): | ||||
|         with self.assertRaises(ValueError): | ||||
|             f2(mx.array(10), mx.array([5, 10, 20])) | ||||
|  | ||||
|     def test_export_scatter_gather(self): | ||||
|         path = os.path.join(self.test_dir, "fn.mlxfn") | ||||
|  | ||||
|         def fun(a, b): | ||||
|             return mx.take_along_axis(a, b, axis=0) | ||||
|  | ||||
|         x = mx.random.uniform(shape=(4, 4)) | ||||
|         y = mx.array([[0, 1, 2, 3], [1, 2, 0, 3]]) | ||||
|         mx.export_function(path, fun, (x, y)) | ||||
|         imported_fun = mx.import_function(path) | ||||
|         expected = fun(x, y) | ||||
|         out = imported_fun(x, y)[0] | ||||
|         self.assertTrue(mx.array_equal(expected, out)) | ||||
|  | ||||
|         def fun(a, b, c): | ||||
|             return mx.put_along_axis(a, b, c, axis=0) | ||||
|  | ||||
|         x = mx.random.uniform(shape=(4, 4)) | ||||
|         y = mx.array([[0, 1, 2, 3], [1, 2, 0, 3]]) | ||||
|         z = mx.random.uniform(shape=(2, 4)) | ||||
|         mx.export_function(path, fun, (x, y, z)) | ||||
|         imported_fun = mx.import_function(path) | ||||
|         expected = fun(x, y, z) | ||||
|         out = imported_fun(x, y, z)[0] | ||||
|         self.assertTrue(mx.array_equal(expected, out)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun