Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -617,7 +617,7 @@ TEST_CASE("test op vjps") {
axes = {0};
out = vjp(fun, array({}), array(3.0f)).second;
CHECK_EQ(out.size(), 0);
CHECK_EQ(out.shape(), std::vector<int>{0});
CHECK_EQ(out.shape(), Shape{0});
axes = {0};
out = vjp(fun, ones({2, 2, 2}), array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}))
@@ -725,9 +725,9 @@ TEST_CASE("test gather and take grads") {
}
TEST_CASE("test slice grads") {
std::vector<int> start = {5, 0, 0};
std::vector<int> stop = {7, 2, 4};
std::vector<int> strides = {1, 1, 1};
Shape start = {5, 0, 0};
Shape stop = {7, 2, 4};
Shape strides = {1, 1, 1};
auto fn = [&start, &stop, &strides](array input) {
return slice(input, start, stop, strides);
@@ -982,8 +982,8 @@ TEST_CASE("test comparison grads") {
TEST_CASE("test as_strided grads") {
auto x = ones({11});
std::vector<int> shape = {5, 5};
std::vector<size_t> strides = {1, 1};
Shape shape = {5, 5};
Strides strides = {1, 1};
size_t offset = 0;
auto fun = [&shape, &strides, &offset](array x) {