mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
support arange for bfloat16 (#245)
This commit is contained in:
parent
2c7df6795e
commit
794feb83df
@ -215,7 +215,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
arange_set_scalars<float>(start_, start_ + step_, compute_encoder);
|
||||
break;
|
||||
case bfloat16:
|
||||
throw std::runtime_error("[Arange::eval_gpu] Does not support bfloat16");
|
||||
arange_set_scalars<bfloat16_t>(start_, start_ + step_, compute_encoder);
|
||||
break;
|
||||
case complex64:
|
||||
throw std::runtime_error("[Arange::eval_gpu] Does not support complex64");
|
||||
}
|
||||
|
@ -18,6 +18,12 @@ TEST_CASE("test arange") {
|
||||
x = arange(10, float32);
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
|
||||
x = arange(10, float16);
|
||||
CHECK_EQ(x.dtype(), float16);
|
||||
|
||||
x = arange(10, bfloat16);
|
||||
CHECK_EQ(x.dtype(), bfloat16);
|
||||
|
||||
x = arange(10.0, int32);
|
||||
CHECK_EQ(x.dtype(), int32);
|
||||
|
||||
@ -105,6 +111,15 @@ TEST_CASE("test arange") {
|
||||
|
||||
x = arange(0.0, 5.0, 1.5, int32);
|
||||
CHECK(array_equal(x, array({0, 1, 2, 3})).item<bool>());
|
||||
|
||||
x = arange(0.0, 5.0, 1.0, float16);
|
||||
CHECK(array_equal(x, array({0, 1, 2, 3, 4}, float16)).item<bool>());
|
||||
|
||||
x = arange(0.0, 5.0, 1.0, bfloat16);
|
||||
CHECK(array_equal(x, array({0, 1, 2, 3, 4}, bfloat16)).item<bool>());
|
||||
|
||||
x = arange(0.0, 5.0, 1.5, bfloat16);
|
||||
CHECK(array_equal(x, array({0., 1.5, 3., 4.5}, bfloat16)).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user