Add inner / outer op (#348)

* inner / outer impl

* python tests

* ops list and ack

* updated descriptions

* use test helper

* removed dtype check and flatten outer to 1-D

* updated docs

* just use the reshape to flatten
This commit is contained in:
Diogo
2024-01-07 12:01:09 -05:00
committed by GitHub
parent 6ea6b4258d
commit 449b43762e
7 changed files with 140 additions and 5 deletions

View File

@@ -2848,10 +2848,6 @@ array tensordot(
throw std::invalid_argument(
"[tensordot] dims[0] and dims[1] must have the same number of dimensions.");
}
if (a.dtype() != b.dtype()) {
throw std::invalid_argument(
"[tensordot] a and b must have the same dtype.");
}
int csize = 1;
auto x = a;
auto y = b;
@@ -2905,4 +2901,21 @@ array tensordot(
return reshape(matmul(x, y, s), rshape, s);
}
array outer(const array& a, const array& b, StreamOrDevice s /* = {} */) {
return multiply(
reshape(a, {static_cast<int>(a.size()), 1}, s), flatten(b, s), s);
}
array inner(const array& a, const array& b, StreamOrDevice s /* = {} */) {
if (a.ndim() == 0 || b.ndim() == 0) {
return multiply(a, b, s);
}
if (a.shape(-1) != b.shape(-1)) {
throw std::invalid_argument(
"[inner] a and b must have the same last dimension.");
}
return tensordot(a, b, {{-1}, {-1}}, s);
}
} // namespace mlx::core

View File

@@ -1075,6 +1075,12 @@ array tensordot(
const std::pair<std::vector<int>, std::vector<int>>& dims,
StreamOrDevice s = {});
/** Compute the outer product of two vectors. */
array outer(const array& a, const array& b, StreamOrDevice s = {});
/** Compute the inner product of two vectors. */
array inner(const array& a, const array& b, StreamOrDevice s = {});
/** Load array map from .safetensors file format */
std::unordered_map<std::string, array> load_safetensors(
std::shared_ptr<io::Reader> in_stream,