mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-08 01:54:37 +08:00
Add tile op (#438)
This commit is contained in:
@@ -3394,4 +3394,30 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
result (array): The outer product.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"tile",
|
||||
[](const array& a, const IntOrVec& reps, StreamOrDevice s) {
|
||||
if (auto pv = std::get_if<int>(&reps); pv) {
|
||||
return tile(a, {*pv}, s);
|
||||
} else {
|
||||
return tile(a, std::get<std::vector<int>>(reps), s);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"reps"_a,
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
tile(a: array, reps: Union[int, List[int]], /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Construct an array by repeating ``a`` the number of times given by ``reps``.
|
||||
|
||||
Args:
|
||||
a (array): Input array
|
||||
reps (int or list(int)): The number of times to repeat ``a`` along each axis.
|
||||
|
||||
Returns:
|
||||
result (array): The tiled array.
|
||||
)pbdoc");
|
||||
}
|
||||
|
@@ -24,9 +24,10 @@ class MLXTestCase(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
mx.set_default_device(self.default)
|
||||
|
||||
# Note if a tuple is passed into args, it will be considered a shape request and convert to a mx.random.normal with the shape matching the tuple
|
||||
def assertCmpNumpy(
|
||||
self,
|
||||
shape: List[Union[Tuple[int], Any]],
|
||||
args: List[Union[Tuple[int], Any]],
|
||||
mx_fn: Callable[..., mx.array],
|
||||
np_fn: Callable[..., np.array],
|
||||
atol=1e-2,
|
||||
@@ -37,7 +38,7 @@ class MLXTestCase(unittest.TestCase):
|
||||
assert dtype != mx.bfloat16, "numpy does not support bfloat16"
|
||||
args = [
|
||||
mx.random.normal(s, dtype=dtype) if isinstance(s, Tuple) else s
|
||||
for s in shape
|
||||
for s in args
|
||||
]
|
||||
mx_res = mx_fn(*args, **kwargs)
|
||||
np_res = np_fn(
|
||||
|
@@ -1634,6 +1634,23 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
np.allclose(np_out[0], mx_out[0]), msg=f"Shapes {s1} {s2}, Type {t}"
|
||||
)
|
||||
|
||||
def test_tile(self):
|
||||
self.assertCmpNumpy([(2,), [2]], mx.tile, np.tile)
|
||||
self.assertCmpNumpy([(2, 3, 4), [2]], mx.tile, np.tile)
|
||||
self.assertCmpNumpy([(2, 3, 4), [2, 1]], mx.tile, np.tile)
|
||||
self.assertCmpNumpy(
|
||||
[
|
||||
(2, 3, 4),
|
||||
[
|
||||
2,
|
||||
2,
|
||||
],
|
||||
],
|
||||
mx.tile,
|
||||
np.tile,
|
||||
)
|
||||
self.assertCmpNumpy([(3,), [2, 2, 2]], mx.tile, np.tile)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user