support arange for bfloat16 (#245)

This commit is contained in:
Daniel Strobusch 2023-12-21 23:33:43 +01:00 committed by GitHub
parent 2c7df6795e
commit 794feb83df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 1 deletions

View File

@ -215,7 +215,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
arange_set_scalars<float>(start_, start_ + step_, compute_encoder); arange_set_scalars<float>(start_, start_ + step_, compute_encoder);
break; break;
case bfloat16: case bfloat16:
throw std::runtime_error("[Arange::eval_gpu] Does not support bfloat16"); arange_set_scalars<bfloat16_t>(start_, start_ + step_, compute_encoder);
break;
case complex64: case complex64:
throw std::runtime_error("[Arange::eval_gpu] Does not support complex64"); throw std::runtime_error("[Arange::eval_gpu] Does not support complex64");
} }

View File

@ -18,6 +18,12 @@ TEST_CASE("test arange") {
x = arange(10, float32); x = arange(10, float32);
CHECK_EQ(x.dtype(), float32); CHECK_EQ(x.dtype(), float32);
x = arange(10, float16);
CHECK_EQ(x.dtype(), float16);
x = arange(10, bfloat16);
CHECK_EQ(x.dtype(), bfloat16);
x = arange(10.0, int32); x = arange(10.0, int32);
CHECK_EQ(x.dtype(), int32); CHECK_EQ(x.dtype(), int32);
@ -105,6 +111,15 @@ TEST_CASE("test arange") {
x = arange(0.0, 5.0, 1.5, int32); x = arange(0.0, 5.0, 1.5, int32);
CHECK(array_equal(x, array({0, 1, 2, 3})).item<bool>()); CHECK(array_equal(x, array({0, 1, 2, 3})).item<bool>());
x = arange(0.0, 5.0, 1.0, float16);
CHECK(array_equal(x, array({0, 1, 2, 3, 4}, float16)).item<bool>());
x = arange(0.0, 5.0, 1.0, bfloat16);
CHECK(array_equal(x, array({0, 1, 2, 3, 4}, bfloat16)).item<bool>());
x = arange(0.0, 5.0, 1.5, bfloat16);
CHECK(array_equal(x, array({0., 1.5, 3., 4.5}, bfloat16)).item<bool>());
} }
} }