nice tensordot for mlx c (#782)

This commit is contained in:
Awni Hannun
2024-03-04 09:51:02 -08:00
committed by GitHub
parent 6a665ea6ed
commit 5121f028d9
4 changed files with 32 additions and 33 deletions

View File

@@ -3190,42 +3190,41 @@ array dequantize(
array tensordot(
const array& a,
const array& b,
const int dims /* = 2 */,
const int axis /* = 2 */,
StreamOrDevice s /* = {} */
) {
if (dims < 0) {
if (axis < 0) {
throw std::invalid_argument(
"[tensordot] dims must be greater or equal to 0.");
"[tensordot] axis must be greater or equal to 0.");
}
if (dims > std::min(a.ndim(), b.ndim())) {
if (axis > std::min(a.ndim(), b.ndim())) {
throw std::invalid_argument(
"[tensordot] dims must be less than the number of dimensions of a and b.");
"[tensordot] axis must be less than the number of dimensions of a and b.");
}
std::vector<int> adims;
std::vector<int> bdims;
for (int i = 0; i < dims; i++) {
for (int i = 0; i < axis; i++) {
bdims.emplace_back(i);
adims.emplace_back(-dims + i);
adims.emplace_back(i - axis);
}
return tensordot(a, b, {adims, bdims}, s);
return tensordot(a, b, {adims}, {bdims}, s);
}
array tensordot(
const array& a,
const array& b,
const std::pair<std::vector<int>, std::vector<int>>& dims,
StreamOrDevice s /* = {} */
) {
if (dims.first.size() != dims.second.size()) {
throw std::invalid_argument(
"[tensordot] dims[0] and dims[1] must have the same number of dimensions.");
const std::vector<int>& axes_a,
const std::vector<int>& axes_b,
StreamOrDevice s /* = {} */) {
if (axes_a.size() != axes_b.size()) {
throw std::invalid_argument("[tensordot] axes must have the same size.");
}
int csize = 1;
auto x = a;
auto y = b;
for (int i = 0; i < dims.first.size(); i++) {
if (x.shape(dims.first.at(i)) == y.shape(dims.second.at(i))) {
csize *= x.shape(dims.first.at(i));
for (int i = 0; i < axes_a.size(); i++) {
if (x.shape(axes_a.at(i)) == y.shape(axes_b.at(i))) {
csize *= x.shape(axes_a.at(i));
} else {
throw std::invalid_argument(
"[tensordot] a and b must have the same shape on the contracted axes.");
@@ -3234,11 +3233,11 @@ array tensordot(
std::vector<bool> cdims1(x.ndim(), false);
std::vector<bool> cdims2(y.ndim(), false);
for (const auto n : dims.first) {
for (const auto n : axes_a) {
int n_ = (n < 0) ? n + x.ndim() : n;
cdims1[n_] = true;
}
for (const auto n : dims.second) {
for (const auto n : axes_b) {
int n_ = (n < 0) ? n + y.ndim() : n;
cdims2[n_] = true;
}
@@ -3255,10 +3254,10 @@ array tensordot(
rshape.emplace_back(a.shape(i));
}
}
for (const auto x : dims.first) {
for (const auto x : axes_a) {
t1.emplace_back(x);
}
for (const auto x : dims.second) {
for (const auto x : axes_b) {
t2.emplace_back(x);
}
for (int i = 0; i < b.ndim(); i++) {
@@ -3287,7 +3286,7 @@ array inner(const array& a, const array& b, StreamOrDevice s /* = {} */) {
"[inner] a and b must have the same last dimension.");
}
return tensordot(a, b, {{-1}, {-1}}, s);
return tensordot(a, b, {-1}, {-1}, s);
}
/** Compute D = beta * C + alpha * (A @ B) */

View File

@@ -1110,17 +1110,18 @@ array dequantize(
int bits = 4,
StreamOrDevice s = {});
/** TensorDot returns a contraction of a and b over multiple dimensions. */
/** Returns a contraction of a and b over multiple dimensions. */
array tensordot(
const array& a,
const array& b,
const int dims = 2,
const int axis = 2,
StreamOrDevice s = {});
array tensordot(
const array& a,
const array& b,
const std::pair<std::vector<int>, std::vector<int>>& dims,
const std::vector<int>& axes_a,
const std::vector<int>& axes_b,
StreamOrDevice s = {});
/** Compute the outer product of two vectors. */