mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Pickle + dtype fix for numpy conversion (#763)
* pickle + dtype fix for numpy conversion * fix getattribute on Module base * remove unused function * fix tests * add topk to ops * fix doc
This commit is contained in:
@@ -248,11 +248,9 @@ TEST_CASE("test random uniform") {
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
|
||||
if (is_available(float16)) {
|
||||
x = random::uniform({}, float16);
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), float16);
|
||||
}
|
||||
x = random::uniform({}, float16);
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), float16);
|
||||
|
||||
x = random::uniform({0});
|
||||
CHECK(array_equal(x, array({})).item<bool>());
|
||||
@@ -467,11 +465,9 @@ TEST_CASE("test random bernoulli") {
|
||||
CHECK_EQ(x.dtype(), bool_);
|
||||
|
||||
// Bernoulli parameter can have floating point type
|
||||
if (is_available(float16)) {
|
||||
x = random::bernoulli(array(0.5, float16));
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), bool_);
|
||||
}
|
||||
x = random::bernoulli(array(0.5, float16));
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), bool_);
|
||||
|
||||
CHECK_THROWS(random::bernoulli(array(1, int32)));
|
||||
|
||||
@@ -513,11 +509,9 @@ TEST_CASE("Test truncated normal") {
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
|
||||
if (is_available(float16)) {
|
||||
x = random::truncated_normal(array(-2.0), array(2.0), {}, float16);
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), float16);
|
||||
}
|
||||
x = random::truncated_normal(array(-2.0), array(2.0), {}, float16);
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.dtype(), float16);
|
||||
|
||||
// Requested shape
|
||||
x = random::truncated_normal(array(-2.0), array(2.0), {3, 4});
|
||||
|
Reference in New Issue
Block a user