mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-06 16:51:24 +08:00
Added Kronecker Product (#1728)
This commit is contained in:

committed by
GitHub

parent
92ec632ad5
commit
491fa95b1f
@@ -1000,6 +1000,34 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertListEqual(mx.grad(func)(x).tolist(), expected)
|
||||
|
||||
def test_kron(self):
|
||||
# Basic vector test
|
||||
x = mx.array([1, 2])
|
||||
y = mx.array([3, 4])
|
||||
z = mx.kron(x, y)
|
||||
self.assertEqual(z.tolist(), [3, 4, 6, 8])
|
||||
|
||||
# Basic matrix test
|
||||
x = mx.array([[1, 2], [3, 4]])
|
||||
y = mx.array([[0, 5], [6, 7]])
|
||||
z = mx.kron(x, y)
|
||||
self.assertEqual(
|
||||
z.tolist(),
|
||||
[[0, 5, 0, 10], [6, 7, 12, 14], [0, 15, 0, 20], [18, 21, 24, 28]],
|
||||
)
|
||||
|
||||
# Test with different dimensions
|
||||
x = mx.array([1, 2]) # (2,)
|
||||
y = mx.array([[3, 4], [5, 6]]) # (2, 2)
|
||||
z = mx.kron(x, y)
|
||||
self.assertEqual(z.tolist(), [[3, 4, 6, 8], [5, 6, 10, 12]])
|
||||
|
||||
# Test with empty array
|
||||
x = mx.array([])
|
||||
y = mx.array([1, 2])
|
||||
with self.assertRaises(ValueError):
|
||||
mx.kron(x, y)
|
||||
|
||||
def test_take(self):
|
||||
# Shape: 4 x 3 x 2
|
||||
l = [
|
||||
|
Reference in New Issue
Block a user