mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 06:44:40 +08:00
Add tile op (#438)
This commit is contained in:
@@ -24,9 +24,10 @@ class MLXTestCase(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
mx.set_default_device(self.default)
|
||||
|
||||
# Note if a tuple is passed into args, it will be considered a shape request and convert to a mx.random.normal with the shape matching the tuple
|
||||
def assertCmpNumpy(
|
||||
self,
|
||||
shape: List[Union[Tuple[int], Any]],
|
||||
args: List[Union[Tuple[int], Any]],
|
||||
mx_fn: Callable[..., mx.array],
|
||||
np_fn: Callable[..., np.array],
|
||||
atol=1e-2,
|
||||
@@ -37,7 +38,7 @@ class MLXTestCase(unittest.TestCase):
|
||||
assert dtype != mx.bfloat16, "numpy does not support bfloat16"
|
||||
args = [
|
||||
mx.random.normal(s, dtype=dtype) if isinstance(s, Tuple) else s
|
||||
for s in shape
|
||||
for s in args
|
||||
]
|
||||
mx_res = mx_fn(*args, **kwargs)
|
||||
np_res = np_fn(
|
||||
|
Reference in New Issue
Block a user