mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Add tile op (#438)
This commit is contained in:
		| @@ -24,9 +24,10 @@ class MLXTestCase(unittest.TestCase): | ||||
|     def tearDown(self): | ||||
|         mx.set_default_device(self.default) | ||||
|  | ||||
|     # Note if a tuple is passed into args, it will be considered a shape request and convert to a mx.random.normal with the shape matching the tuple | ||||
|     def assertCmpNumpy( | ||||
|         self, | ||||
|         shape: List[Union[Tuple[int], Any]], | ||||
|         args: List[Union[Tuple[int], Any]], | ||||
|         mx_fn: Callable[..., mx.array], | ||||
|         np_fn: Callable[..., np.array], | ||||
|         atol=1e-2, | ||||
| @@ -37,7 +38,7 @@ class MLXTestCase(unittest.TestCase): | ||||
|         assert dtype != mx.bfloat16, "numpy does not support bfloat16" | ||||
|         args = [ | ||||
|             mx.random.normal(s, dtype=dtype) if isinstance(s, Tuple) else s | ||||
|             for s in shape | ||||
|             for s in args | ||||
|         ] | ||||
|         mx_res = mx_fn(*args, **kwargs) | ||||
|         np_res = np_fn( | ||||
|   | ||||
| @@ -1634,6 +1634,23 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|                     np.allclose(np_out[0], mx_out[0]), msg=f"Shapes {s1} {s2}, Type {t}" | ||||
|                 ) | ||||
|  | ||||
|     def test_tile(self): | ||||
|         self.assertCmpNumpy([(2,), [2]], mx.tile, np.tile) | ||||
|         self.assertCmpNumpy([(2, 3, 4), [2]], mx.tile, np.tile) | ||||
|         self.assertCmpNumpy([(2, 3, 4), [2, 1]], mx.tile, np.tile) | ||||
|         self.assertCmpNumpy( | ||||
|             [ | ||||
|                 (2, 3, 4), | ||||
|                 [ | ||||
|                     2, | ||||
|                     2, | ||||
|                 ], | ||||
|             ], | ||||
|             mx.tile, | ||||
|             np.tile, | ||||
|         ) | ||||
|         self.assertCmpNumpy([(3,), [2, 2, 2]], mx.tile, np.tile) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Diogo
					Diogo