Fp64 on the CPU (#1843)

* add fp64 data type

* clean build

* update docs

* fix bug
This commit is contained in:
Awni Hannun
2025-02-07 15:52:22 -08:00
committed by GitHub
parent 1a1b2108ec
commit 1c0c118f7c
32 changed files with 438 additions and 65 deletions

View File

@@ -42,6 +42,7 @@ instantiate_default_limit(int64_t);
// Invoke instantiate_float_limit once per floating-point/complex element type,
// now including double.
// NOTE(review): the macro body is outside this view; presumably it declares
// the Limits<T> specialization whose max/min constants are defined further
// down in this file - confirm against the macro definition.
instantiate_float_limit(float16_t);
instantiate_float_limit(bfloat16_t);
instantiate_float_limit(float);
instantiate_float_limit(double);
instantiate_float_limit(complex64_t);
template <>
@@ -59,6 +60,8 @@ const bfloat16_t Limits<bfloat16_t>::min =
// Out-of-class definitions of the Limits<T>::max / Limits<T>::min constants.
// float16_t is initialized from float's infinity (converted to float16_t on
// initialization): max is +infinity, min is -infinity.
const float16_t Limits<float16_t>::max = std::numeric_limits<float>::infinity();
const float16_t Limits<float16_t>::min =
-std::numeric_limits<float>::infinity();
// double uses its own infinity: max is +infinity, min is -infinity.
const double Limits<double>::max = std::numeric_limits<double>::infinity();
const double Limits<double>::min = -std::numeric_limits<double>::infinity();
const complex64_t Limits<complex64_t>::max =
std::numeric_limits<float>::infinity();
const complex64_t Limits<complex64_t>::min =
@@ -460,6 +463,7 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
break;
case uint64:
case int64:
case float64:
case complex64:
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
break;
@@ -495,6 +499,9 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case float32:
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
break;
@@ -537,6 +544,9 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case float32:
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
break;