From 794feb83df87eefa25b9e70fac7c6b2739687bec Mon Sep 17 00:00:00 2001 From: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com> Date: Thu, 21 Dec 2023 23:33:43 +0100 Subject: [PATCH] support arange for bfloat16 (#245) --- mlx/backend/metal/primitives.cpp | 3 ++- tests/creations_tests.cpp | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 578d4e382..94f1217a1 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -215,7 +215,8 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { arange_set_scalars(start_, start_ + step_, compute_encoder); break; case bfloat16: - throw std::runtime_error("[Arange::eval_gpu] Does not support bfloat16"); + arange_set_scalars(start_, start_ + step_, compute_encoder); + break; case complex64: throw std::runtime_error("[Arange::eval_gpu] Does not support complex64"); } diff --git a/tests/creations_tests.cpp b/tests/creations_tests.cpp index 9d5f787c5..edb40a9fe 100644 --- a/tests/creations_tests.cpp +++ b/tests/creations_tests.cpp @@ -18,6 +18,12 @@ TEST_CASE("test arange") { x = arange(10, float32); CHECK_EQ(x.dtype(), float32); + x = arange(10, float16); + CHECK_EQ(x.dtype(), float16); + + x = arange(10, bfloat16); + CHECK_EQ(x.dtype(), bfloat16); + x = arange(10.0, int32); CHECK_EQ(x.dtype(), int32); @@ -105,6 +111,15 @@ TEST_CASE("test arange") { x = arange(0.0, 5.0, 1.5, int32); CHECK(array_equal(x, array({0, 1, 2, 3})).item()); + + x = arange(0.0, 5.0, 1.0, float16); + CHECK(array_equal(x, array({0, 1, 2, 3, 4}, float16)).item()); + + x = arange(0.0, 5.0, 1.0, bfloat16); + CHECK(array_equal(x, array({0, 1, 2, 3, 4}, bfloat16)).item()); + + x = arange(0.0, 5.0, 1.5, bfloat16); + CHECK(array_equal(x, array({0., 1.5, 3., 4.5}, bfloat16)).item()); } }