mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
implemented changes
This commit is contained in:
parent
722938892c
commit
34108e780c
14
mlx/ops.cpp
14
mlx/ops.cpp
@ -2838,14 +2838,11 @@ array tensordot(
|
|||||||
auto x = a;
|
auto x = a;
|
||||||
auto y = b;
|
auto y = b;
|
||||||
for (int i = 0; i < dims[0].size(); i++) {
|
for (int i = 0; i < dims[0].size(); i++) {
|
||||||
size_t xs = x.shape(dims[0].at(i));
|
if (x.shape(dims[0].at(i)) == y.shape(dims[1].at(i))) {
|
||||||
size_t ys = y.shape(dims[1].at(i));
|
csize *= x.shape(dims[0].at(i));
|
||||||
if (ys == 1) {
|
|
||||||
x = sum(x, dims[0].at(i), true, s);
|
|
||||||
} else if (xs == 1) {
|
|
||||||
y = sum(y, dims[1].at(i), true, s);
|
|
||||||
} else {
|
} else {
|
||||||
csize *= xs;
|
throw std::invalid_argument(
|
||||||
|
"[tensordot] a and b must have the same shape on the contracted axes.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2861,11 +2858,8 @@ array tensordot(
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> t1;
|
std::vector<int> t1;
|
||||||
t1.reserve(a.ndim());
|
|
||||||
std::vector<int> t2;
|
std::vector<int> t2;
|
||||||
t2.reserve(b.ndim());
|
|
||||||
std::vector<int> rshape;
|
std::vector<int> rshape;
|
||||||
rshape.reserve(a.ndim() + b.ndim() * dims[0].size());
|
|
||||||
int size1 = 1;
|
int size1 = 1;
|
||||||
int size2 = 1;
|
int size2 = 1;
|
||||||
for (int i = 0; i < a.ndim(); i++) {
|
for (int i = 0; i < a.ndim(); i++) {
|
||||||
|
@ -2288,19 +2288,8 @@ TEST_CASE("tensordot") {
|
|||||||
CHECK(array_equal(z, expected).item<bool>());
|
CHECK(array_equal(z, expected).item<bool>());
|
||||||
x = reshape(arange(360.), {3, 4, 5, 6});
|
x = reshape(arange(360.), {3, 4, 5, 6});
|
||||||
y = reshape(arange(360.), {6, 4, 5, 3});
|
y = reshape(arange(360.), {6, 4, 5, 3});
|
||||||
z = tensordot(x, y, {{2, 1, 3}, {1, 2, 0}});
|
CHECK_THROWS_AS(
|
||||||
expected = array(
|
tensordot(x, y, {{2, 1, 3}, {1, 2, 0}}), std::invalid_argument);
|
||||||
{1326270,
|
|
||||||
1333410,
|
|
||||||
1340550,
|
|
||||||
3896670,
|
|
||||||
3918210,
|
|
||||||
3939750,
|
|
||||||
6467070,
|
|
||||||
6503010,
|
|
||||||
6538950},
|
|
||||||
{3, 3});
|
|
||||||
CHECK(array_equal(z, expected).item<bool>());
|
|
||||||
x = reshape(arange(60.), {3, 4, 5});
|
x = reshape(arange(60.), {3, 4, 5});
|
||||||
y = reshape(arange(120.), {4, 5, 6});
|
y = reshape(arange(120.), {4, 5, 6});
|
||||||
z = tensordot(x, y, 2);
|
z = tensordot(x, y, 2);
|
||||||
|
Loading…
Reference in New Issue
Block a user