Pickle + dtype fix for numpy conversion (#763)

* pickle + dtype fix for numpy conversion

* fix getattribute on Module base

* remove unused function

* fix tests

* add topk to ops

* fix doc
This commit is contained in:
Awni Hannun
2024-03-02 06:09:29 -08:00
committed by GitHub
parent 8e281c76c3
commit bc06cb9ff6
7 changed files with 99 additions and 39 deletions

View File

@@ -248,11 +248,9 @@ TEST_CASE("test random uniform") {
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), float32);
if (is_available(float16)) {
x = random::uniform({}, float16);
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), float16);
}
x = random::uniform({}, float16);
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), float16);
x = random::uniform({0});
CHECK(array_equal(x, array({})).item<bool>());
@@ -467,11 +465,9 @@ TEST_CASE("test random bernoulli") {
CHECK_EQ(x.dtype(), bool_);
// Bernoulli parameter can have floating point type
if (is_available(float16)) {
x = random::bernoulli(array(0.5, float16));
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), bool_);
}
x = random::bernoulli(array(0.5, float16));
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), bool_);
CHECK_THROWS(random::bernoulli(array(1, int32)));
@@ -513,11 +509,9 @@ TEST_CASE("Test truncated normal") {
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), float32);
if (is_available(float16)) {
x = random::truncated_normal(array(-2.0), array(2.0), {}, float16);
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), float16);
}
x = random::truncated_normal(array(-2.0), array(2.0), {}, float16);
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), float16);
// Requested shape
x = random::truncated_normal(array(-2.0), array(2.0), {3, 4});