mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-03 22:34:43 +08:00
Use int64 stride everywhere (#1671)
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
This commit is contained in:
@@ -141,7 +141,7 @@ TEST_CASE("test random bits") {
|
||||
|
||||
{
|
||||
auto key = array({0u, 0u, 1u, 1u}, {2, 2});
|
||||
auto shape = std::vector<int>{3};
|
||||
auto shape = Shape{3};
|
||||
auto fn = [&shape](array k) { return random::bits(shape, k); };
|
||||
|
||||
auto expected = array(
|
||||
@@ -264,7 +264,7 @@ TEST_CASE("test random uniform") {
|
||||
|
||||
// Check broadcasting
|
||||
x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3});
|
||||
CHECK_EQ(x.shape(), std::vector<int>{3, 3});
|
||||
CHECK_EQ(x.shape(), Shape{3, 3});
|
||||
CHECK_THROWS_AS(
|
||||
random::uniform(zeros({3, 3}), 1.0, {1, 3}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(
|
||||
@@ -332,11 +332,11 @@ TEST_CASE("test random uniform") {
|
||||
return random::uniform(low, 1, {3}, float32, k);
|
||||
};
|
||||
auto out = vmap(fun, -1)(key, zeros({2, 3}));
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3});
|
||||
CHECK_EQ(out.shape(), Shape{2, 3});
|
||||
|
||||
key = zeros({2, 2}, uint32);
|
||||
out = vmap(fun)(key, zeros({2, 3}));
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3});
|
||||
CHECK_EQ(out.shape(), Shape{2, 3});
|
||||
}
|
||||
|
||||
// Check bounds are respected
|
||||
@@ -425,7 +425,7 @@ TEST_CASE("test random multivariate_normal") {
|
||||
auto mean = zeros({3});
|
||||
auto cov = eye(3);
|
||||
auto x = random::multivariate_normal(mean, cov, {1000}, float32);
|
||||
CHECK_EQ(x.shape(), std::vector<int>({1000, 3}));
|
||||
CHECK_EQ(x.shape(), Shape{1000, 3});
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
}
|
||||
|
||||
@@ -435,7 +435,7 @@ TEST_CASE("test random multivariate_normal") {
|
||||
auto cov = array({1., -1, -.1, 1.});
|
||||
cov = reshape(cov, {2, 2});
|
||||
auto x = random::multivariate_normal(mean, cov, {1}, float32);
|
||||
CHECK_EQ(x.shape(), std::vector<int>({1, 2}));
|
||||
CHECK_EQ(x.shape(), Shape{1, 2});
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
}
|
||||
|
||||
@@ -457,7 +457,7 @@ TEST_CASE("test random multivariate_normal") {
|
||||
auto mean = zeros({3});
|
||||
auto cov = zeros({1, 2, 3, 3});
|
||||
auto x = random::multivariate_normal(mean, cov, {1000, 2}, float32);
|
||||
CHECK_EQ(x.shape(), std::vector<int>({1000, 2, 3}));
|
||||
CHECK_EQ(x.shape(), Shape{1000, 2, 3});
|
||||
}
|
||||
{
|
||||
auto mean = zeros({3});
|
||||
@@ -537,7 +537,7 @@ TEST_CASE("test random bernoulli") {
|
||||
|
||||
// Return array with correct shape
|
||||
x = random::bernoulli(0.5, {3, 3});
|
||||
CHECK_EQ(x.shape(), std::vector<int>({3, 3}));
|
||||
CHECK_EQ(x.shape(), Shape{3, 3});
|
||||
|
||||
// Try with p = {}
|
||||
x = random::bernoulli(array({}));
|
||||
@@ -547,7 +547,7 @@ TEST_CASE("test random bernoulli") {
|
||||
auto p = array({0.1, 0.2, 0.3});
|
||||
p = reshape(p, {1, 3});
|
||||
x = random::bernoulli(p, {4, 3});
|
||||
CHECK_EQ(x.shape(), std::vector<int>({4, 3}));
|
||||
CHECK_EQ(x.shape(), Shape{4, 3});
|
||||
|
||||
CHECK_THROWS_AS(random::bernoulli(array({}), {3, 3}), std::invalid_argument);
|
||||
|
||||
@@ -572,7 +572,7 @@ TEST_CASE("Test truncated normal") {
|
||||
|
||||
// Requested shape
|
||||
x = random::truncated_normal(array(-2.0), array(2.0), {3, 4});
|
||||
CHECK_EQ(x.shape(), std::vector<int>({3, 4}));
|
||||
CHECK_EQ(x.shape(), Shape{3, 4});
|
||||
|
||||
// Empty array
|
||||
x = random::truncated_normal(array({}), array({}));
|
||||
@@ -584,7 +584,7 @@ TEST_CASE("Test truncated normal") {
|
||||
x = random::truncated_normal(lower, higher);
|
||||
|
||||
// All in bounds
|
||||
CHECK_EQ(x.shape(), std::vector<int>({3, 2}));
|
||||
CHECK_EQ(x.shape(), Shape{3, 2});
|
||||
CHECK((all(x <= higher).item<bool>() && all(lower <= x).item<bool>()));
|
||||
|
||||
// high < low => all equal to low
|
||||
@@ -615,17 +615,17 @@ TEST_CASE("test categorical") {
|
||||
CHECK_THROWS(categorical(logits, 1, std::vector<int>{11}));
|
||||
CHECK_THROWS(categorical(logits, 1, {10, 1}));
|
||||
|
||||
CHECK_EQ(categorical(logits, -1).shape(), std::vector<int>{10});
|
||||
CHECK_EQ(categorical(logits, 0).shape(), std::vector<int>{20});
|
||||
CHECK_EQ(categorical(logits, 1).shape(), std::vector<int>{10});
|
||||
CHECK_EQ(categorical(logits, -1).shape(), Shape{10});
|
||||
CHECK_EQ(categorical(logits, 0).shape(), Shape{20});
|
||||
CHECK_EQ(categorical(logits, 1).shape(), Shape{10});
|
||||
|
||||
auto out = categorical(logits);
|
||||
CHECK_EQ(out.shape(), std::vector<int>{10});
|
||||
CHECK_EQ(out.shape(), Shape{10});
|
||||
CHECK_EQ(out.dtype(), uint32);
|
||||
CHECK(max(out).item<uint32_t>() < 20);
|
||||
|
||||
out = categorical(logits, 0, {5, 20});
|
||||
CHECK_EQ(out.shape(), std::vector<int>{5, 20});
|
||||
CHECK_EQ(out.shape(), Shape{5, 20});
|
||||
CHECK(max(out).item<uint32_t>() < 10);
|
||||
|
||||
float inf = std::numeric_limits<float>::infinity();
|
||||
@@ -636,9 +636,9 @@ TEST_CASE("test categorical") {
|
||||
CHECK_EQ(categorical(logits).item<uint32_t>(), 1);
|
||||
|
||||
logits = zeros({5, 4, 3});
|
||||
CHECK_EQ(categorical(logits, -1, 7).shape(), std::vector<int>{5, 4, 7});
|
||||
CHECK_EQ(categorical(logits, -2, 7).shape(), std::vector<int>{5, 3, 7});
|
||||
CHECK_EQ(categorical(logits, -3, 7).shape(), std::vector<int>{4, 3, 7});
|
||||
CHECK_EQ(categorical(logits, -1, 7).shape(), Shape{5, 4, 7});
|
||||
CHECK_EQ(categorical(logits, -2, 7).shape(), Shape{5, 3, 7});
|
||||
CHECK_EQ(categorical(logits, -3, 7).shape(), Shape{4, 3, 7});
|
||||
}
|
||||
|
||||
TEST_CASE("test laplace") {
|
||||
|
Reference in New Issue
Block a user