mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
Added Kronecker Product (#1728)
This commit is contained in:
parent
92ec632ad5
commit
491fa95b1f
28
mlx/ops.cpp
28
mlx/ops.cpp
@ -2759,6 +2759,34 @@ array gather(
|
|||||||
inputs);
|
inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array kron(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
|
if (a.size() == 0 || b.size() == 0) {
|
||||||
|
throw std::invalid_argument("[kron] Input arrays cannot be empty.");
|
||||||
|
}
|
||||||
|
|
||||||
|
int ndim = std::max(a.ndim(), b.ndim());
|
||||||
|
std::vector<int> a_shape(2 * ndim, 1);
|
||||||
|
std::vector<int> b_shape(2 * ndim, 1);
|
||||||
|
std::vector<int> out_shape(ndim, 1);
|
||||||
|
|
||||||
|
for (int i = ndim - 1, j = a.ndim() - 1; j >= 0; j--, i--) {
|
||||||
|
a_shape[2 * i] = a.shape(j);
|
||||||
|
out_shape[i] *= a.shape(j);
|
||||||
|
}
|
||||||
|
for (int i = ndim - 1, j = b.ndim() - 1; j >= 0; j--, i--) {
|
||||||
|
b_shape[2 * i + 1] = b.shape(j);
|
||||||
|
out_shape[i] *= b.shape(j);
|
||||||
|
}
|
||||||
|
|
||||||
|
return reshape(
|
||||||
|
multiply(
|
||||||
|
reshape(a, std::move(a_shape), s),
|
||||||
|
reshape(b, std::move(b_shape), s),
|
||||||
|
s),
|
||||||
|
std::move(out_shape),
|
||||||
|
s);
|
||||||
|
}
|
||||||
|
|
||||||
array take(
|
array take(
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& indices,
|
const array& indices,
|
||||||
|
@ -917,6 +917,9 @@ inline array gather(
|
|||||||
return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
|
return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Returns Kronecker Producct given two input arrays. */
|
||||||
|
array kron(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Take array slices at the given indices of the specified axis. */
|
/** Take array slices at the given indices of the specified axis. */
|
||||||
array take(
|
array take(
|
||||||
const array& a,
|
const array& a,
|
||||||
|
@ -1458,6 +1458,35 @@ void init_ops(nb::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The range of values.
|
array: The range of values.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"kron",
|
||||||
|
&kron,
|
||||||
|
nb::arg("a"),
|
||||||
|
nb::arg("b"),
|
||||||
|
nb::kw_only(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def kron(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Compute the Kronecker product of two arrays `a` and `b`.
|
||||||
|
Args:
|
||||||
|
a (array): The first input array
|
||||||
|
b (array): The second input array
|
||||||
|
stream (Union[None, Stream, Device], optional): Optional stream or device for execution.
|
||||||
|
Default is `None`.
|
||||||
|
Returns:
|
||||||
|
array: The Kronecker product of `a` and `b`.
|
||||||
|
Examples:
|
||||||
|
>>> import mlx
|
||||||
|
>>> a = mlx.array([[1, 2], [3, 4]])
|
||||||
|
>>> b = mlx.array([[0, 5], [6, 7]])
|
||||||
|
>>> result = mlx.kron(a, b)
|
||||||
|
>>> print(result)
|
||||||
|
[[ 0 5 0 10]
|
||||||
|
[ 6 7 12 14]
|
||||||
|
[ 0 15 0 20]
|
||||||
|
[18 21 24 28]]
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"take",
|
"take",
|
||||||
[](const mx::array& a,
|
[](const mx::array& a,
|
||||||
|
@ -1000,6 +1000,34 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertListEqual(mx.grad(func)(x).tolist(), expected)
|
self.assertListEqual(mx.grad(func)(x).tolist(), expected)
|
||||||
|
|
||||||
|
def test_kron(self):
|
||||||
|
# Basic vector test
|
||||||
|
x = mx.array([1, 2])
|
||||||
|
y = mx.array([3, 4])
|
||||||
|
z = mx.kron(x, y)
|
||||||
|
self.assertEqual(z.tolist(), [3, 4, 6, 8])
|
||||||
|
|
||||||
|
# Basic matrix test
|
||||||
|
x = mx.array([[1, 2], [3, 4]])
|
||||||
|
y = mx.array([[0, 5], [6, 7]])
|
||||||
|
z = mx.kron(x, y)
|
||||||
|
self.assertEqual(
|
||||||
|
z.tolist(),
|
||||||
|
[[0, 5, 0, 10], [6, 7, 12, 14], [0, 15, 0, 20], [18, 21, 24, 28]],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with different dimensions
|
||||||
|
x = mx.array([1, 2]) # (2,)
|
||||||
|
y = mx.array([[3, 4], [5, 6]]) # (2, 2)
|
||||||
|
z = mx.kron(x, y)
|
||||||
|
self.assertEqual(z.tolist(), [[3, 4, 6, 8], [5, 6, 10, 12]])
|
||||||
|
|
||||||
|
# Test with empty array
|
||||||
|
x = mx.array([])
|
||||||
|
y = mx.array([1, 2])
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.kron(x, y)
|
||||||
|
|
||||||
def test_take(self):
|
def test_take(self):
|
||||||
# Shape: 4 x 3 x 2
|
# Shape: 4 x 3 x 2
|
||||||
l = [
|
l = [
|
||||||
|
Loading…
Reference in New Issue
Block a user