mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 09:33:16 +08:00
support arange for bfloat16 (#245)
This commit is contained in:
@@ -18,6 +18,12 @@ TEST_CASE("test arange") {
|
||||
x = arange(10, float32);
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
|
||||
x = arange(10, float16);
|
||||
CHECK_EQ(x.dtype(), float16);
|
||||
|
||||
x = arange(10, bfloat16);
|
||||
CHECK_EQ(x.dtype(), bfloat16);
|
||||
|
||||
x = arange(10.0, int32);
|
||||
CHECK_EQ(x.dtype(), int32);
|
||||
|
||||
@@ -105,6 +111,15 @@ TEST_CASE("test arange") {
|
||||
|
||||
x = arange(0.0, 5.0, 1.5, int32);
|
||||
CHECK(array_equal(x, array({0, 1, 2, 3})).item<bool>());
|
||||
|
||||
x = arange(0.0, 5.0, 1.0, float16);
|
||||
CHECK(array_equal(x, array({0, 1, 2, 3, 4}, float16)).item<bool>());
|
||||
|
||||
x = arange(0.0, 5.0, 1.0, bfloat16);
|
||||
CHECK(array_equal(x, array({0, 1, 2, 3, 4}, bfloat16)).item<bool>());
|
||||
|
||||
x = arange(0.0, 5.0, 1.5, bfloat16);
|
||||
CHECK(array_equal(x, array({0., 1.5, 3., 4.5}, bfloat16)).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user