implemented changes

This commit is contained in:
dc-dc-dc 2024-01-02 18:11:54 -05:00
parent 722938892c
commit 34108e780c
2 changed files with 6 additions and 23 deletions

View File

@ -2838,14 +2838,11 @@ array tensordot(
auto x = a; auto x = a;
auto y = b; auto y = b;
for (int i = 0; i < dims[0].size(); i++) { for (int i = 0; i < dims[0].size(); i++) {
size_t xs = x.shape(dims[0].at(i)); if (x.shape(dims[0].at(i)) == y.shape(dims[1].at(i))) {
size_t ys = y.shape(dims[1].at(i)); csize *= x.shape(dims[0].at(i));
if (ys == 1) {
x = sum(x, dims[0].at(i), true, s);
} else if (xs == 1) {
y = sum(y, dims[1].at(i), true, s);
} else { } else {
csize *= xs; throw std::invalid_argument(
"[tensordot] a and b must have the same shape on the contracted axes.");
} }
} }
@ -2861,11 +2858,8 @@ array tensordot(
} }
std::vector<int> t1; std::vector<int> t1;
t1.reserve(a.ndim());
std::vector<int> t2; std::vector<int> t2;
t2.reserve(b.ndim());
std::vector<int> rshape; std::vector<int> rshape;
rshape.reserve(a.ndim() + b.ndim() * dims[0].size());
int size1 = 1; int size1 = 1;
int size2 = 1; int size2 = 1;
for (int i = 0; i < a.ndim(); i++) { for (int i = 0; i < a.ndim(); i++) {

View File

@ -2288,19 +2288,8 @@ TEST_CASE("tensordot") {
CHECK(array_equal(z, expected).item<bool>()); CHECK(array_equal(z, expected).item<bool>());
x = reshape(arange(360.), {3, 4, 5, 6}); x = reshape(arange(360.), {3, 4, 5, 6});
y = reshape(arange(360.), {6, 4, 5, 3}); y = reshape(arange(360.), {6, 4, 5, 3});
z = tensordot(x, y, {{2, 1, 3}, {1, 2, 0}}); CHECK_THROWS_AS(
expected = array( tensordot(x, y, {{2, 1, 3}, {1, 2, 0}}), std::invalid_argument);
{1326270,
1333410,
1340550,
3896670,
3918210,
3939750,
6467070,
6503010,
6538950},
{3, 3});
CHECK(array_equal(z, expected).item<bool>());
x = reshape(arange(60.), {3, 4, 5}); x = reshape(arange(60.), {3, 4, 5});
y = reshape(arange(120.), {4, 5, 6}); y = reshape(arange(120.), {4, 5, 6});
z = tensordot(x, y, 2); z = tensordot(x, y, 2);