From 491fa95b1f9ec52b24ca4aebde85aa963de73047 Mon Sep 17 00:00:00 2001 From: Venkata Naga Aditya Datta Chivukula <51185970+cvnad1@users.noreply.github.com> Date: Thu, 2 Jan 2025 17:00:34 -0700 Subject: [PATCH] Added Kronecker Product (#1728) --- mlx/ops.cpp | 28 ++++++++++++++++++++++++++++ mlx/ops.h | 3 +++ python/src/ops.cpp | 29 +++++++++++++++++++++++++++++ python/tests/test_ops.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 88 insertions(+) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 9386d59b9..99b6a721e 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2759,6 +2759,34 @@ array gather( inputs); } +array kron(const array& a, const array& b, StreamOrDevice s /* = {} */) { + if (a.size() == 0 || b.size() == 0) { + throw std::invalid_argument("[kron] Input arrays cannot be empty."); + } + + int ndim = std::max(a.ndim(), b.ndim()); + std::vector a_shape(2 * ndim, 1); + std::vector b_shape(2 * ndim, 1); + std::vector out_shape(ndim, 1); + + for (int i = ndim - 1, j = a.ndim() - 1; j >= 0; j--, i--) { + a_shape[2 * i] = a.shape(j); + out_shape[i] *= a.shape(j); + } + for (int i = ndim - 1, j = b.ndim() - 1; j >= 0; j--, i--) { + b_shape[2 * i + 1] = b.shape(j); + out_shape[i] *= b.shape(j); + } + + return reshape( + multiply( + reshape(a, std::move(a_shape), s), + reshape(b, std::move(b_shape), s), + s), + std::move(out_shape), + s); +} + array take( const array& a, const array& indices, diff --git a/mlx/ops.h b/mlx/ops.h index d6e456c88..0f92ff372 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -917,6 +917,9 @@ inline array gather( return gather(a, {indices}, std::vector{axis}, slice_sizes, s); } +/** Returns Kronecker Producct given two input arrays. */ +array kron(const array& a, const array& b, StreamOrDevice s = {}); + /** Take array slices at the given indices of the specified axis. */ array take( const array& a, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index ce696b7dc..5ea979670 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1458,6 +1458,35 @@ void init_ops(nb::module_& m) { Returns: array: The range of values. )pbdoc"); + m.def( + "kron", + &kron, + nb::arg("a"), + nb::arg("b"), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def kron(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Compute the Kronecker product of two arrays `a` and `b`. + Args: + a (array): The first input array + b (array): The second input array + stream (Union[None, Stream, Device], optional): Optional stream or device for execution. + Default is `None`. + Returns: + array: The Kronecker product of `a` and `b`. + Examples: + >>> import mlx + >>> a = mlx.array([[1, 2], [3, 4]]) + >>> b = mlx.array([[0, 5], [6, 7]]) + >>> result = mlx.kron(a, b) + >>> print(result) + [[ 0 5 0 10] + [ 6 7 12 14] + [ 0 15 0 20] + [18 21 24 28]] + )pbdoc"); m.def( "take", [](const mx::array& a, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index d76a8143e..6d01ee5a2 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1000,6 +1000,34 @@ class TestOps(mlx_tests.MLXTestCase): self.assertListEqual(mx.grad(func)(x).tolist(), expected) + def test_kron(self): + # Basic vector test + x = mx.array([1, 2]) + y = mx.array([3, 4]) + z = mx.kron(x, y) + self.assertEqual(z.tolist(), [3, 4, 6, 8]) + + # Basic matrix test + x = mx.array([[1, 2], [3, 4]]) + y = mx.array([[0, 5], [6, 7]]) + z = mx.kron(x, y) + self.assertEqual( + z.tolist(), + [[0, 5, 0, 10], [6, 7, 12, 14], [0, 15, 0, 20], [18, 21, 24, 28]], + ) + + # Test with different dimensions + x = mx.array([1, 2]) # (2,) + y = mx.array([[3, 4], [5, 6]]) # (2, 2) + z = mx.kron(x, y) + self.assertEqual(z.tolist(), [[3, 4, 6, 8], [5, 6, 10, 12]]) + + # Test with empty array + x = mx.array([]) + y = mx.array([1, 2]) + with self.assertRaises(ValueError): + mx.kron(x, y) + def test_take(self): # Shape: 4 x 3 x 2 l = [