mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add Tensordot op (#344)
This commit is contained in:
90
mlx/ops.cpp
90
mlx/ops.cpp
@@ -2793,4 +2793,94 @@ array dequantize(
|
||||
return w_full;
|
||||
}
|
||||
|
||||
array tensordot(
|
||||
const array& a,
|
||||
const array& b,
|
||||
const int dims /* = 2 */,
|
||||
StreamOrDevice s /* = {} */
|
||||
) {
|
||||
if (dims < 0) {
|
||||
throw std::invalid_argument(
|
||||
"[tensordot] dims must be greater or equal to 0.");
|
||||
}
|
||||
if (dims > std::min(a.ndim(), b.ndim())) {
|
||||
throw std::invalid_argument(
|
||||
"[tensordot] dims must be less than the number of dimensions of a and b.");
|
||||
}
|
||||
std::vector<int> adims;
|
||||
std::vector<int> bdims;
|
||||
for (int i = 0; i < dims; i++) {
|
||||
bdims.emplace_back(i);
|
||||
adims.emplace_back(-dims + i);
|
||||
}
|
||||
return tensordot(a, b, {adims, bdims}, s);
|
||||
}
|
||||
|
||||
array tensordot(
|
||||
const array& a,
|
||||
const array& b,
|
||||
const std::pair<std::vector<int>, std::vector<int>>& dims,
|
||||
StreamOrDevice s /* = {} */
|
||||
) {
|
||||
if (dims.first.size() != dims.second.size()) {
|
||||
throw std::invalid_argument(
|
||||
"[tensordot] dims[0] and dims[1] must have the same number of dimensions.");
|
||||
}
|
||||
if (a.dtype() != b.dtype()) {
|
||||
throw std::invalid_argument(
|
||||
"[tensordot] a and b must have the same dtype.");
|
||||
}
|
||||
int csize = 1;
|
||||
auto x = a;
|
||||
auto y = b;
|
||||
for (int i = 0; i < dims.first.size(); i++) {
|
||||
if (x.shape(dims.first.at(i)) == y.shape(dims.second.at(i))) {
|
||||
csize *= x.shape(dims.first.at(i));
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[tensordot] a and b must have the same shape on the contracted axes.");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<bool> cdims1(x.ndim(), false);
|
||||
std::vector<bool> cdims2(y.ndim(), false);
|
||||
for (const auto n : dims.first) {
|
||||
int n_ = (n < 0) ? n + x.ndim() : n;
|
||||
cdims1[n_] = true;
|
||||
}
|
||||
for (const auto n : dims.second) {
|
||||
int n_ = (n < 0) ? n + y.ndim() : n;
|
||||
cdims2[n_] = true;
|
||||
}
|
||||
|
||||
std::vector<int> t1;
|
||||
std::vector<int> t2;
|
||||
std::vector<int> rshape;
|
||||
int size1 = 1;
|
||||
int size2 = 1;
|
||||
for (int i = 0; i < a.ndim(); i++) {
|
||||
if (!cdims1[i]) {
|
||||
t1.emplace_back(i);
|
||||
size1 *= a.shape(i);
|
||||
rshape.emplace_back(a.shape(i));
|
||||
}
|
||||
}
|
||||
for (const auto x : dims.first) {
|
||||
t1.emplace_back(x);
|
||||
}
|
||||
for (const auto x : dims.second) {
|
||||
t2.emplace_back(x);
|
||||
}
|
||||
for (int i = 0; i < b.ndim(); i++) {
|
||||
if (!cdims2[i]) {
|
||||
t2.emplace_back(i);
|
||||
size2 *= b.shape(i);
|
||||
rshape.emplace_back(b.shape(i));
|
||||
}
|
||||
}
|
||||
x = reshape(transpose(x, t1, s), {size1, csize}, s);
|
||||
y = reshape(transpose(y, t2, s), {csize, size2}, s);
|
||||
return reshape(matmul(x, y, s), rshape, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
13
mlx/ops.h
13
mlx/ops.h
@@ -1061,6 +1061,19 @@ array dequantize(
|
||||
int bits = 4,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** TensorDot returns a contraction of a and b over multiple dimensions. */
|
||||
array tensordot(
|
||||
const array& a,
|
||||
const array& b,
|
||||
const int dims = 2,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
array tensordot(
|
||||
const array& a,
|
||||
const array& b,
|
||||
const std::pair<std::vector<int>, std::vector<int>>& dims,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Load array map from .safetensors file format */
|
||||
std::unordered_map<std::string, array> load_safetensors(
|
||||
std::shared_ptr<io::Reader> in_stream,
|
||||
|
||||
Reference in New Issue
Block a user