support arange for bfloat16 (#245)

This commit is contained in:
Daniel Strobusch
2023-12-21 23:33:43 +01:00
committed by GitHub
parent 2c7df6795e
commit 794feb83df
2 changed files with 17 additions and 1 deletions

View File

@@ -215,7 +215,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
arange_set_scalars<float>(start_, start_ + step_, compute_encoder);
break;
case bfloat16:
throw std::runtime_error("[Arange::eval_gpu] Does not support bfloat16");
arange_set_scalars<bfloat16_t>(start_, start_ + step_, compute_encoder);
break;
case complex64:
throw std::runtime_error("[Arange::eval_gpu] Does not support complex64");
}