Add tile op (#438)

This commit is contained in:
Diogo
2024-01-13 02:03:16 -05:00
committed by GitHub
parent 1b71487e1f
commit 2e29d0815b
7 changed files with 105 additions and 3 deletions

View File

@@ -24,9 +24,10 @@ class MLXTestCase(unittest.TestCase):
def tearDown(self):
mx.set_default_device(self.default)
# Note if a tuple is passed into args, it will be considered a shape request and convert to a mx.random.normal with the shape matching the tuple
def assertCmpNumpy(
self,
shape: List[Union[Tuple[int], Any]],
args: List[Union[Tuple[int], Any]],
mx_fn: Callable[..., mx.array],
np_fn: Callable[..., np.array],
atol=1e-2,
@@ -37,7 +38,7 @@ class MLXTestCase(unittest.TestCase):
assert dtype != mx.bfloat16, "numpy does not support bfloat16"
args = [
mx.random.normal(s, dtype=dtype) if isinstance(s, Tuple) else s
for s in shape
for s in args
]
mx_res = mx_fn(*args, **kwargs)
np_res = np_fn(