implemented changes

This commit is contained in:
dc-dc-dc 2024-01-02 18:11:54 -05:00
parent 722938892c
commit 34108e780c
2 changed files with 6 additions and 23 deletions

View File

@ -2838,14 +2838,11 @@ array tensordot(
auto x = a;
auto y = b;
for (int i = 0; i < dims[0].size(); i++) {
size_t xs = x.shape(dims[0].at(i));
size_t ys = y.shape(dims[1].at(i));
if (ys == 1) {
x = sum(x, dims[0].at(i), true, s);
} else if (xs == 1) {
y = sum(y, dims[1].at(i), true, s);
if (x.shape(dims[0].at(i)) == y.shape(dims[1].at(i))) {
csize *= x.shape(dims[0].at(i));
} else {
csize *= xs;
throw std::invalid_argument(
"[tensordot] a and b must have the same shape on the contracted axes.");
}
}
@ -2861,11 +2858,8 @@ array tensordot(
}
std::vector<int> t1;
t1.reserve(a.ndim());
std::vector<int> t2;
t2.reserve(b.ndim());
std::vector<int> rshape;
rshape.reserve(a.ndim() + b.ndim() * dims[0].size());
int size1 = 1;
int size2 = 1;
for (int i = 0; i < a.ndim(); i++) {

View File

@ -2288,19 +2288,8 @@ TEST_CASE("tensordot") {
CHECK(array_equal(z, expected).item<bool>());
x = reshape(arange(360.), {3, 4, 5, 6});
y = reshape(arange(360.), {6, 4, 5, 3});
z = tensordot(x, y, {{2, 1, 3}, {1, 2, 0}});
expected = array(
{1326270,
1333410,
1340550,
3896670,
3918210,
3939750,
6467070,
6503010,
6538950},
{3, 3});
CHECK(array_equal(z, expected).item<bool>());
CHECK_THROWS_AS(
tensordot(x, y, {{2, 1, 3}, {1, 2, 0}}), std::invalid_argument);
x = reshape(arange(60.), {3, 4, 5});
y = reshape(arange(120.), {4, 5, 6});
z = tensordot(x, y, 2);