Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -141,7 +141,7 @@ TEST_CASE("test random bits") {
{
auto key = array({0u, 0u, 1u, 1u}, {2, 2});
auto shape = std::vector<int>{3};
auto shape = Shape{3};
auto fn = [&shape](array k) { return random::bits(shape, k); };
auto expected = array(
@@ -264,7 +264,7 @@ TEST_CASE("test random uniform") {
// Check broadcasting
x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3});
CHECK_EQ(x.shape(), std::vector<int>{3, 3});
CHECK_EQ(x.shape(), Shape{3, 3});
CHECK_THROWS_AS(
random::uniform(zeros({3, 3}), 1.0, {1, 3}), std::invalid_argument);
CHECK_THROWS_AS(
@@ -332,11 +332,11 @@ TEST_CASE("test random uniform") {
return random::uniform(low, 1, {3}, float32, k);
};
auto out = vmap(fun, -1)(key, zeros({2, 3}));
CHECK_EQ(out.shape(), std::vector<int>{2, 3});
CHECK_EQ(out.shape(), Shape{2, 3});
key = zeros({2, 2}, uint32);
out = vmap(fun)(key, zeros({2, 3}));
CHECK_EQ(out.shape(), std::vector<int>{2, 3});
CHECK_EQ(out.shape(), Shape{2, 3});
}
// Check bounds are respected
@@ -425,7 +425,7 @@ TEST_CASE("test random multivariate_normal") {
auto mean = zeros({3});
auto cov = eye(3);
auto x = random::multivariate_normal(mean, cov, {1000}, float32);
CHECK_EQ(x.shape(), std::vector<int>({1000, 3}));
CHECK_EQ(x.shape(), Shape{1000, 3});
CHECK_EQ(x.dtype(), float32);
}
@@ -435,7 +435,7 @@ TEST_CASE("test random multivariate_normal") {
auto cov = array({1., -1, -.1, 1.});
cov = reshape(cov, {2, 2});
auto x = random::multivariate_normal(mean, cov, {1}, float32);
CHECK_EQ(x.shape(), std::vector<int>({1, 2}));
CHECK_EQ(x.shape(), Shape{1, 2});
CHECK_EQ(x.dtype(), float32);
}
@@ -457,7 +457,7 @@ TEST_CASE("test random multivariate_normal") {
auto mean = zeros({3});
auto cov = zeros({1, 2, 3, 3});
auto x = random::multivariate_normal(mean, cov, {1000, 2}, float32);
CHECK_EQ(x.shape(), std::vector<int>({1000, 2, 3}));
CHECK_EQ(x.shape(), Shape{1000, 2, 3});
}
{
auto mean = zeros({3});
@@ -537,7 +537,7 @@ TEST_CASE("test random bernoulli") {
// Return array with correct shape
x = random::bernoulli(0.5, {3, 3});
CHECK_EQ(x.shape(), std::vector<int>({3, 3}));
CHECK_EQ(x.shape(), Shape{3, 3});
// Try with p = {}
x = random::bernoulli(array({}));
@@ -547,7 +547,7 @@ TEST_CASE("test random bernoulli") {
auto p = array({0.1, 0.2, 0.3});
p = reshape(p, {1, 3});
x = random::bernoulli(p, {4, 3});
CHECK_EQ(x.shape(), std::vector<int>({4, 3}));
CHECK_EQ(x.shape(), Shape{4, 3});
CHECK_THROWS_AS(random::bernoulli(array({}), {3, 3}), std::invalid_argument);
@@ -572,7 +572,7 @@ TEST_CASE("Test truncated normal") {
// Requested shape
x = random::truncated_normal(array(-2.0), array(2.0), {3, 4});
CHECK_EQ(x.shape(), std::vector<int>({3, 4}));
CHECK_EQ(x.shape(), Shape{3, 4});
// Empty array
x = random::truncated_normal(array({}), array({}));
@@ -584,7 +584,7 @@ TEST_CASE("Test truncated normal") {
x = random::truncated_normal(lower, higher);
// All in bounds
CHECK_EQ(x.shape(), std::vector<int>({3, 2}));
CHECK_EQ(x.shape(), Shape{3, 2});
CHECK((all(x <= higher).item<bool>() && all(lower <= x).item<bool>()));
// high < low => all equal to low
@@ -615,17 +615,17 @@ TEST_CASE("test categorical") {
CHECK_THROWS(categorical(logits, 1, std::vector<int>{11}));
CHECK_THROWS(categorical(logits, 1, {10, 1}));
CHECK_EQ(categorical(logits, -1).shape(), std::vector<int>{10});
CHECK_EQ(categorical(logits, 0).shape(), std::vector<int>{20});
CHECK_EQ(categorical(logits, 1).shape(), std::vector<int>{10});
CHECK_EQ(categorical(logits, -1).shape(), Shape{10});
CHECK_EQ(categorical(logits, 0).shape(), Shape{20});
CHECK_EQ(categorical(logits, 1).shape(), Shape{10});
auto out = categorical(logits);
CHECK_EQ(out.shape(), std::vector<int>{10});
CHECK_EQ(out.shape(), Shape{10});
CHECK_EQ(out.dtype(), uint32);
CHECK(max(out).item<uint32_t>() < 20);
out = categorical(logits, 0, {5, 20});
CHECK_EQ(out.shape(), std::vector<int>{5, 20});
CHECK_EQ(out.shape(), Shape{5, 20});
CHECK(max(out).item<uint32_t>() < 10);
float inf = std::numeric_limits<float>::infinity();
@@ -636,9 +636,9 @@ TEST_CASE("test categorical") {
CHECK_EQ(categorical(logits).item<uint32_t>(), 1);
logits = zeros({5, 4, 3});
CHECK_EQ(categorical(logits, -1, 7).shape(), std::vector<int>{5, 4, 7});
CHECK_EQ(categorical(logits, -2, 7).shape(), std::vector<int>{5, 3, 7});
CHECK_EQ(categorical(logits, -3, 7).shape(), std::vector<int>{4, 3, 7});
CHECK_EQ(categorical(logits, -1, 7).shape(), Shape{5, 4, 7});
CHECK_EQ(categorical(logits, -2, 7).shape(), Shape{5, 3, 7});
CHECK_EQ(categorical(logits, -3, 7).shape(), Shape{4, 3, 7});
}
TEST_CASE("test laplace") {