diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index bb16d4b1f..18a8c5599 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -10,7 +10,7 @@ MLX was developed with contributions from the following individuals:
 - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
-- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer` and safetensor support
+- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
 - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 5f7f60b99..2112a2a4d 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -753,6 +753,36 @@ array repeat(const array& arr, int repeats, StreamOrDevice s) {
   return repeat(flatten(arr, s), repeats, 0, s);
 }
 
+array tile(
+    const array& arr,
+    std::vector<int> reps,
+    StreamOrDevice s /* = {} */) {
+  auto shape = arr.shape();
+  if (reps.size() < shape.size()) {
+    reps.insert(reps.begin(), shape.size() - reps.size(), 1);
+  }
+  if (reps.size() > shape.size()) {
+    shape.insert(shape.begin(), reps.size() - shape.size(), 1);
+  }
+
+  std::vector<int> expand_shape;
+  std::vector<int> broad_shape;
+  std::vector<int> final_shape;
+  for (int i = 0; i < shape.size(); i++) {
+    if (reps[i] != 1) {
+      expand_shape.push_back(1);
+      broad_shape.push_back(reps[i]);
+    }
+    expand_shape.push_back(shape[i]);
+    broad_shape.push_back(shape[i]);
+    final_shape.push_back(reps[i] * shape[i]);
+  }
+
+  auto x = reshape(arr, expand_shape, s);
+  x = broadcast_to(x, broad_shape, s);
+  return reshape(x, final_shape, s);
+}
+
 /** Pad an array with a constant value */
 array pad(
     const array& a,
diff --git a/mlx/ops.h b/mlx/ops.h
index f095a06dd..f0823ed5f 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -218,6 +218,8 @@ array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
 array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
 array repeat(const array& arr, int repeats, StreamOrDevice s = {});
 
+array tile(const array& arr, std::vector<int> reps, StreamOrDevice s = {});
+
 /** Permutes the dimensions according to the given axes. */
 array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
 inline array transpose(
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 9a47f0ed8..c9c29532c 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -3394,4 +3394,30 @@ void init_ops(py::module_& m) {
         Returns:
             result (array): The outer product.
       )pbdoc");
+  m.def(
+      "tile",
+      [](const array& a, const IntOrVec& reps, StreamOrDevice s) {
+        if (auto pv = std::get_if<int>(&reps); pv) {
+          return tile(a, {*pv}, s);
+        } else {
+          return tile(a, std::get<std::vector<int>>(reps), s);
+        }
+      },
+      "a"_a,
+      "reps"_a,
+      py::pos_only(),
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        tile(a: array, reps: Union[int, List[int]], /, *, stream: Union[None, Stream, Device] = None) -> array
+
+        Construct an array by repeating ``a`` the number of times given by ``reps``.
+
+        Args:
+            a (array): Input array.
+            reps (int or list(int)): The number of times to repeat ``a`` along each axis.
+
+        Returns:
+            result (array): The tiled array.
+      )pbdoc");
 }
diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py
index 8cccc61ec..414759b5e 100644
--- a/python/tests/mlx_tests.py
+++ b/python/tests/mlx_tests.py
@@ -24,9 +24,10 @@ class MLXTestCase(unittest.TestCase):
     def tearDown(self):
         mx.set_default_device(self.default)
 
+    # Note: a tuple in ``args`` is treated as a shape request and is replaced by mx.random.normal output of that shape
     def assertCmpNumpy(
         self,
-        shape: List[Union[Tuple[int], Any]],
+        args: List[Union[Tuple[int], Any]],
         mx_fn: Callable[..., mx.array],
         np_fn: Callable[..., np.array],
         atol=1e-2,
@@ -37,7 +38,7 @@ class MLXTestCase(unittest.TestCase):
         assert dtype != mx.bfloat16, "numpy does not support bfloat16"
         args = [
             mx.random.normal(s, dtype=dtype) if isinstance(s, Tuple) else s
-            for s in shape
+            for s in args
         ]
         mx_res = mx_fn(*args, **kwargs)
         np_res = np_fn(
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 3b889359c..84135427f 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1634,6 +1634,23 @@ class TestOps(mlx_tests.MLXTestCase):
             np.allclose(np_out[0], mx_out[0]), msg=f"Shapes {s1} {s2}, Type {t}"
         )
 
+    def test_tile(self):
+        self.assertCmpNumpy([(2,), [2]], mx.tile, np.tile)
+        self.assertCmpNumpy([(2, 3, 4), [2]], mx.tile, np.tile)
+        self.assertCmpNumpy([(2, 3, 4), [2, 1]], mx.tile, np.tile)
+        self.assertCmpNumpy(
+            [
+                (2, 3, 4),
+                [
+                    2,
+                    2,
+                ],
+            ],
+            mx.tile,
+            np.tile,
+        )
+        self.assertCmpNumpy([(3,), [2, 2, 2]], mx.tile, np.tile)
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index d8d76b94f..84c35ff2f 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -2343,6 +2343,32 @@ TEST_CASE("test repeat") {
   CHECK_THROWS_AS(repeat(data_3, -3, 0), std::invalid_argument);
 }
 
+TEST_CASE("tile") {
+  auto x = array({1, 2, 3}, {3});
+  auto y = tile(x, {2});
+  auto expected = array({1, 2, 3, 1, 2, 3}, {6});
+  CHECK(array_equal(y, expected).item<bool>());
+  x = array({1, 2, 3, 4}, {2, 2});
+  y = tile(x, {2});
+  expected = array({1, 2, 1, 2, 3, 4, 3, 4}, {2, 4});
+  CHECK(array_equal(y, expected).item<bool>());
+  x = array({1, 2, 3, 4}, {2, 2});
+  y = tile(x, {4, 1});
+  expected = array({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}, {8, 2});
+  CHECK(array_equal(y, expected).item<bool>());
+
+  x = array({1, 2, 3, 4}, {2, 2});
+  y = tile(x, {2, 2});
+  expected = array({1, 2, 1, 2, 3, 4, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4}, {4, 4});
+  CHECK(array_equal(y, expected).item<bool>());
+  x = array({1, 2, 3}, {3});
+  y = tile(x, {2, 2, 2});
+  expected = array(
+      {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3},
+      {2, 2, 6});
+  CHECK(array_equal(y, expected).item<bool>());
+}
+
 TEST_CASE("tensordot") {
   auto x = reshape(arange(60.), {3, 4, 5});
   auto y = reshape(arange(24.), {4, 3, 2});
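
Note on the implementation (a reviewer's sketch, not part of the diff): tile() builds the result out of existing primitives rather than a dedicated kernel. For every axis whose repeat count is not 1, it inserts a singleton axis in front, broadcasts that axis up to reps[i], and then collapses each (reps[i], shape[i]) pair into a single axis of length reps[i] * shape[i] with a final reshape. A minimal Python equivalent of the same algorithm, written against NumPy so it is self-contained; the helper name tile_sketch is hypothetical:

import numpy as np

def tile_sketch(a, reps):
    # Mirror the padding logic in the C++ tile() above: left-pad the
    # shorter of reps/shape with 1s until they have equal length.
    shape = list(a.shape)
    reps = list(reps)
    if len(reps) < len(shape):
        reps = [1] * (len(shape) - len(reps)) + reps
    if len(reps) > len(shape):
        shape = [1] * (len(reps) - len(shape)) + shape

    expand_shape, broad_shape, final_shape = [], [], []
    for rep, dim in zip(reps, shape):
        if rep != 1:
            expand_shape.append(1)   # singleton axis to be broadcast...
            broad_shape.append(rep)  # ...up to the repeat count
        expand_shape.append(dim)
        broad_shape.append(dim)
        final_shape.append(rep * dim)

    x = a.reshape(expand_shape)          # insert singleton axes
    x = np.broadcast_to(x, broad_shape)  # repeat without copying
    return x.reshape(final_shape)        # collapse (rep, dim) pairs

# Agrees with NumPy's reference implementation:
a = np.arange(6).reshape(2, 3)
assert np.array_equal(tile_sketch(a, [2, 2]), np.tile(a, [2, 2]))
assert np.array_equal(tile_sketch(a, [3]), np.tile(a, [3]))

Because broadcast_to only produces a strided view, the single materializing copy happens in the final reshape; the C++ version follows the same reshape -> broadcast_to -> reshape sequence, so the new op composes with MLX's existing graph transforms for free.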