Added Kronecker Product (#1728)

This commit is contained in:
Venkata Naga Aditya Datta Chivukula
2025-01-02 17:00:34 -07:00
committed by GitHub
parent 92ec632ad5
commit 491fa95b1f
4 changed files with 88 additions and 0 deletions

View File

@@ -1000,6 +1000,34 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertListEqual(mx.grad(func)(x).tolist(), expected)
def test_kron(self):
# Basic vector test
x = mx.array([1, 2])
y = mx.array([3, 4])
z = mx.kron(x, y)
self.assertEqual(z.tolist(), [3, 4, 6, 8])
# Basic matrix test
x = mx.array([[1, 2], [3, 4]])
y = mx.array([[0, 5], [6, 7]])
z = mx.kron(x, y)
self.assertEqual(
z.tolist(),
[[0, 5, 0, 10], [6, 7, 12, 14], [0, 15, 0, 20], [18, 21, 24, 28]],
)
# Test with different dimensions
x = mx.array([1, 2]) # (2,)
y = mx.array([[3, 4], [5, 6]]) # (2, 2)
z = mx.kron(x, y)
self.assertEqual(z.tolist(), [[3, 4, 6, 8], [5, 6, 10, 12]])
# Test with empty array
x = mx.array([])
y = mx.array([1, 2])
with self.assertRaises(ValueError):
mx.kron(x, y)
def test_take(self):
# Shape: 4 x 3 x 2
l = [