support arange for bfloat16

This commit is contained in:
Daniel Strobusch
2023-12-21 19:57:07 +01:00
parent b3916cbf2b
commit a4be1ee231
2 changed files with 17 additions and 1 deletions

View File

@@ -215,7 +215,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
arange_set_scalars<float>(start_, start_ + step_, compute_encoder);
break;
case bfloat16:
throw std::runtime_error("[Arange::eval_gpu] Does not support bfloat16");
arange_set_scalars<bfloat16_t>(start_, start_ + step_, compute_encoder);
break;
case complex64:
throw std::runtime_error("[Arange::eval_gpu] Does not support complex64");
}

View File

@@ -18,6 +18,12 @@ TEST_CASE("test arange") {
x = arange(10, float32);
CHECK_EQ(x.dtype(), float32);
x = arange(10, float16);
CHECK_EQ(x.dtype(), float16);
x = arange(10, bfloat16);
CHECK_EQ(x.dtype(), bfloat16);
x = arange(10.0, int32);
CHECK_EQ(x.dtype(), int32);
@@ -105,6 +111,15 @@ TEST_CASE("test arange") {
x = arange(0.0, 5.0, 1.5, int32);
CHECK(array_equal(x, array({0, 1, 2, 3})).item<bool>());
x = arange(0.0, 5.0, 1.0, float16);
CHECK(array_equal(x, array({0, 1, 2, 3, 4}, float16)).item<bool>());
x = arange(0.0, 5.0, 1.0, bfloat16);
CHECK(array_equal(x, array({0, 1, 2, 3, 4}, bfloat16)).item<bool>());
x = arange(0.0, 5.0, 1.5, bfloat16);
CHECK(array_equal(x, array({0., 1.5, 3., 4.5}, bfloat16)).item<bool>());
}
}