mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 16:51:15 +08:00
nice tensordot for mlx c (#782)
This commit is contained in:
parent
6a665ea6ed
commit
5121f028d9
43
mlx/ops.cpp
43
mlx/ops.cpp
@ -3190,42 +3190,41 @@ array dequantize(
|
|||||||
array tensordot(
|
array tensordot(
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const int dims /* = 2 */,
|
const int axis /* = 2 */,
|
||||||
StreamOrDevice s /* = {} */
|
StreamOrDevice s /* = {} */
|
||||||
) {
|
) {
|
||||||
if (dims < 0) {
|
if (axis < 0) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[tensordot] dims must be greater or equal to 0.");
|
"[tensordot] axis must be greater or equal to 0.");
|
||||||
}
|
}
|
||||||
if (dims > std::min(a.ndim(), b.ndim())) {
|
if (axis > std::min(a.ndim(), b.ndim())) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[tensordot] dims must be less than the number of dimensions of a and b.");
|
"[tensordot] axis must be less than the number of dimensions of a and b.");
|
||||||
}
|
}
|
||||||
std::vector<int> adims;
|
std::vector<int> adims;
|
||||||
std::vector<int> bdims;
|
std::vector<int> bdims;
|
||||||
for (int i = 0; i < dims; i++) {
|
for (int i = 0; i < axis; i++) {
|
||||||
bdims.emplace_back(i);
|
bdims.emplace_back(i);
|
||||||
adims.emplace_back(-dims + i);
|
adims.emplace_back(i - axis);
|
||||||
}
|
}
|
||||||
return tensordot(a, b, {adims, bdims}, s);
|
return tensordot(a, b, {adims}, {bdims}, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array tensordot(
|
array tensordot(
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const std::pair<std::vector<int>, std::vector<int>>& dims,
|
const std::vector<int>& axes_a,
|
||||||
StreamOrDevice s /* = {} */
|
const std::vector<int>& axes_b,
|
||||||
) {
|
StreamOrDevice s /* = {} */) {
|
||||||
if (dims.first.size() != dims.second.size()) {
|
if (axes_a.size() != axes_b.size()) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument("[tensordot] axes must have the same size.");
|
||||||
"[tensordot] dims[0] and dims[1] must have the same number of dimensions.");
|
|
||||||
}
|
}
|
||||||
int csize = 1;
|
int csize = 1;
|
||||||
auto x = a;
|
auto x = a;
|
||||||
auto y = b;
|
auto y = b;
|
||||||
for (int i = 0; i < dims.first.size(); i++) {
|
for (int i = 0; i < axes_a.size(); i++) {
|
||||||
if (x.shape(dims.first.at(i)) == y.shape(dims.second.at(i))) {
|
if (x.shape(axes_a.at(i)) == y.shape(axes_b.at(i))) {
|
||||||
csize *= x.shape(dims.first.at(i));
|
csize *= x.shape(axes_a.at(i));
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[tensordot] a and b must have the same shape on the contracted axes.");
|
"[tensordot] a and b must have the same shape on the contracted axes.");
|
||||||
@ -3234,11 +3233,11 @@ array tensordot(
|
|||||||
|
|
||||||
std::vector<bool> cdims1(x.ndim(), false);
|
std::vector<bool> cdims1(x.ndim(), false);
|
||||||
std::vector<bool> cdims2(y.ndim(), false);
|
std::vector<bool> cdims2(y.ndim(), false);
|
||||||
for (const auto n : dims.first) {
|
for (const auto n : axes_a) {
|
||||||
int n_ = (n < 0) ? n + x.ndim() : n;
|
int n_ = (n < 0) ? n + x.ndim() : n;
|
||||||
cdims1[n_] = true;
|
cdims1[n_] = true;
|
||||||
}
|
}
|
||||||
for (const auto n : dims.second) {
|
for (const auto n : axes_b) {
|
||||||
int n_ = (n < 0) ? n + y.ndim() : n;
|
int n_ = (n < 0) ? n + y.ndim() : n;
|
||||||
cdims2[n_] = true;
|
cdims2[n_] = true;
|
||||||
}
|
}
|
||||||
@ -3255,10 +3254,10 @@ array tensordot(
|
|||||||
rshape.emplace_back(a.shape(i));
|
rshape.emplace_back(a.shape(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (const auto x : dims.first) {
|
for (const auto x : axes_a) {
|
||||||
t1.emplace_back(x);
|
t1.emplace_back(x);
|
||||||
}
|
}
|
||||||
for (const auto x : dims.second) {
|
for (const auto x : axes_b) {
|
||||||
t2.emplace_back(x);
|
t2.emplace_back(x);
|
||||||
}
|
}
|
||||||
for (int i = 0; i < b.ndim(); i++) {
|
for (int i = 0; i < b.ndim(); i++) {
|
||||||
@ -3287,7 +3286,7 @@ array inner(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|||||||
"[inner] a and b must have the same last dimension.");
|
"[inner] a and b must have the same last dimension.");
|
||||||
}
|
}
|
||||||
|
|
||||||
return tensordot(a, b, {{-1}, {-1}}, s);
|
return tensordot(a, b, {-1}, {-1}, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Compute D = beta * C + alpha * (A @ B) */
|
/** Compute D = beta * C + alpha * (A @ B) */
|
||||||
|
@ -1110,17 +1110,18 @@ array dequantize(
|
|||||||
int bits = 4,
|
int bits = 4,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** TensorDot returns a contraction of a and b over multiple dimensions. */
|
/** Returns a contraction of a and b over multiple dimensions. */
|
||||||
array tensordot(
|
array tensordot(
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const int dims = 2,
|
const int axis = 2,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
array tensordot(
|
array tensordot(
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const std::pair<std::vector<int>, std::vector<int>>& dims,
|
const std::vector<int>& axes_a,
|
||||||
|
const std::vector<int>& axes_b,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Compute the outer product of two vectors. */
|
/** Compute the outer product of two vectors. */
|
||||||
|
@ -3555,17 +3555,17 @@ void init_ops(py::module_& m) {
|
|||||||
"tensordot",
|
"tensordot",
|
||||||
[](const array& a,
|
[](const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
const std::variant<int, std::vector<std::vector<int>>>& dims,
|
const std::variant<int, std::vector<std::vector<int>>>& axes,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
if (auto pv = std::get_if<int>(&dims); pv) {
|
if (auto pv = std::get_if<int>(&axes); pv) {
|
||||||
return tensordot(a, b, *pv, s);
|
return tensordot(a, b, *pv, s);
|
||||||
} else {
|
} else {
|
||||||
auto x = std::get<std::vector<std::vector<int>>>(dims);
|
auto& x = std::get<std::vector<std::vector<int>>>(axes);
|
||||||
if (x.size() != 2) {
|
if (x.size() != 2) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[tensordot] dims must be a list of two lists.");
|
"[tensordot] axes must be a list of two lists.");
|
||||||
}
|
}
|
||||||
return tensordot(a, b, {x[0], x[1]}, s);
|
return tensordot(a, b, x[0], x[1], s);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
|
@ -2554,14 +2554,13 @@ TEST_CASE("tile") {
|
|||||||
TEST_CASE("tensordot") {
|
TEST_CASE("tensordot") {
|
||||||
auto x = reshape(arange(60.), {3, 4, 5});
|
auto x = reshape(arange(60.), {3, 4, 5});
|
||||||
auto y = reshape(arange(24.), {4, 3, 2});
|
auto y = reshape(arange(24.), {4, 3, 2});
|
||||||
auto z = tensordot(x, y, {{1, 0}, {0, 1}});
|
auto z = tensordot(x, y, {1, 0}, {0, 1});
|
||||||
auto expected = array(
|
auto expected = array(
|
||||||
{4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306}, {5, 2});
|
{4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306}, {5, 2});
|
||||||
CHECK(array_equal(z, expected).item<bool>());
|
CHECK(array_equal(z, expected).item<bool>());
|
||||||
x = reshape(arange(360.), {3, 4, 5, 6});
|
x = reshape(arange(360.), {3, 4, 5, 6});
|
||||||
y = reshape(arange(360.), {6, 4, 5, 3});
|
y = reshape(arange(360.), {6, 4, 5, 3});
|
||||||
CHECK_THROWS_AS(
|
CHECK_THROWS_AS(tensordot(x, y, {2, 1, 3}, {1, 2, 0}), std::invalid_argument);
|
||||||
tensordot(x, y, {{2, 1, 3}, {1, 2, 0}}), std::invalid_argument);
|
|
||||||
x = reshape(arange(60.), {3, 4, 5});
|
x = reshape(arange(60.), {3, 4, 5});
|
||||||
y = reshape(arange(120.), {4, 5, 6});
|
y = reshape(arange(120.), {4, 5, 6});
|
||||||
z = tensordot(x, y, 2);
|
z = tensordot(x, y, 2);
|
||||||
|
Loading…
Reference in New Issue
Block a user