support arange for bfloat16 (#245)

This commit is contained in:
Daniel Strobusch
2023-12-21 23:33:43 +01:00
committed by GitHub
parent 2c7df6795e
commit 794feb83df
2 changed files with 17 additions and 1 deletions

View File

@@ -18,6 +18,12 @@ TEST_CASE("test arange") {
x = arange(10, float32);
CHECK_EQ(x.dtype(), float32);
x = arange(10, float16);
CHECK_EQ(x.dtype(), float16);
x = arange(10, bfloat16);
CHECK_EQ(x.dtype(), bfloat16);
x = arange(10.0, int32);
CHECK_EQ(x.dtype(), int32);
@@ -105,6 +111,15 @@ TEST_CASE("test arange") {
x = arange(0.0, 5.0, 1.5, int32);
CHECK(array_equal(x, array({0, 1, 2, 3})).item<bool>());
x = arange(0.0, 5.0, 1.0, float16);
CHECK(array_equal(x, array({0, 1, 2, 3, 4}, float16)).item<bool>());
x = arange(0.0, 5.0, 1.0, bfloat16);
CHECK(array_equal(x, array({0, 1, 2, 3, 4}, bfloat16)).item<bool>());
x = arange(0.0, 5.0, 1.5, bfloat16);
CHECK(array_equal(x, array({0., 1.5, 3., 4.5}, bfloat16)).item<bool>());
}
}