mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 17:28:10 +08:00
Added Kronecker Product (#1728)
This commit is contained in:

committed by
GitHub

parent
92ec632ad5
commit
491fa95b1f
28
mlx/ops.cpp
28
mlx/ops.cpp
@@ -2759,6 +2759,34 @@ array gather(
|
||||
inputs);
|
||||
}
|
||||
|
||||
array kron(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
if (a.size() == 0 || b.size() == 0) {
|
||||
throw std::invalid_argument("[kron] Input arrays cannot be empty.");
|
||||
}
|
||||
|
||||
int ndim = std::max(a.ndim(), b.ndim());
|
||||
std::vector<int> a_shape(2 * ndim, 1);
|
||||
std::vector<int> b_shape(2 * ndim, 1);
|
||||
std::vector<int> out_shape(ndim, 1);
|
||||
|
||||
for (int i = ndim - 1, j = a.ndim() - 1; j >= 0; j--, i--) {
|
||||
a_shape[2 * i] = a.shape(j);
|
||||
out_shape[i] *= a.shape(j);
|
||||
}
|
||||
for (int i = ndim - 1, j = b.ndim() - 1; j >= 0; j--, i--) {
|
||||
b_shape[2 * i + 1] = b.shape(j);
|
||||
out_shape[i] *= b.shape(j);
|
||||
}
|
||||
|
||||
return reshape(
|
||||
multiply(
|
||||
reshape(a, std::move(a_shape), s),
|
||||
reshape(b, std::move(b_shape), s),
|
||||
s),
|
||||
std::move(out_shape),
|
||||
s);
|
||||
}
|
||||
|
||||
array take(
|
||||
const array& a,
|
||||
const array& indices,
|
||||
|
@@ -917,6 +917,9 @@ inline array gather(
|
||||
return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
|
||||
}
|
||||
|
||||
/** Returns Kronecker Producct given two input arrays. */
|
||||
array kron(const array& a, const array& b, StreamOrDevice s = {});
|
||||
|
||||
/** Take array slices at the given indices of the specified axis. */
|
||||
array take(
|
||||
const array& a,
|
||||
|
Reference in New Issue
Block a user