Fp64 on the CPU (#1843)

* add fp64 data type

* clean build

* update docs

* fix bug
This commit is contained in:
Awni Hannun
2025-02-07 15:52:22 -08:00
committed by GitHub
parent 1a1b2108ec
commit 1c0c118f7c
32 changed files with 438 additions and 65 deletions

View File

@@ -6,6 +6,7 @@
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"
namespace mlx::core {
@@ -28,7 +29,7 @@ void softmax(const array& in, array& out) {
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-std::numeric_limits<float>::infinity());
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
@@ -163,6 +164,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
softmax<bfloat16_t, bfloat16_t>(in, out);
}
break;
case float64:
softmax<double, double>(in, out);
break;
case complex64:
throw std::invalid_argument(
"[Softmax] Not yet implemented for complex64");