mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
Add inner / outer op (#348)
* inner / outer impl * python tests * ops list and ack * updated descriptions * use test helper * removed dtype check and flatten outer to 1-D * updated docs * just use the reshape to flatten
This commit is contained in:
21
mlx/ops.cpp
21
mlx/ops.cpp
@@ -2848,10 +2848,6 @@ array tensordot(
|
||||
throw std::invalid_argument(
|
||||
"[tensordot] dims[0] and dims[1] must have the same number of dimensions.");
|
||||
}
|
||||
if (a.dtype() != b.dtype()) {
|
||||
throw std::invalid_argument(
|
||||
"[tensordot] a and b must have the same dtype.");
|
||||
}
|
||||
int csize = 1;
|
||||
auto x = a;
|
||||
auto y = b;
|
||||
@@ -2905,4 +2901,21 @@ array tensordot(
|
||||
return reshape(matmul(x, y, s), rshape, s);
|
||||
}
|
||||
|
||||
array outer(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
return multiply(
|
||||
reshape(a, {static_cast<int>(a.size()), 1}, s), flatten(b, s), s);
|
||||
}
|
||||
|
||||
array inner(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
if (a.ndim() == 0 || b.ndim() == 0) {
|
||||
return multiply(a, b, s);
|
||||
}
|
||||
if (a.shape(-1) != b.shape(-1)) {
|
||||
throw std::invalid_argument(
|
||||
"[inner] a and b must have the same last dimension.");
|
||||
}
|
||||
|
||||
return tensordot(a, b, {{-1}, {-1}}, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1075,6 +1075,12 @@ array tensordot(
|
||||
const std::pair<std::vector<int>, std::vector<int>>& dims,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Compute the outer product of two vectors. */
|
||||
array outer(const array& a, const array& b, StreamOrDevice s = {});
|
||||
|
||||
/** Compute the inner product of two vectors. */
|
||||
array inner(const array& a, const array& b, StreamOrDevice s = {});
|
||||
|
||||
/** Load array map from .safetensors file format */
|
||||
std::unordered_map<std::string, array> load_safetensors(
|
||||
std::shared_ptr<io::Reader> in_stream,
|
||||
|
Reference in New Issue
Block a user