mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fp64 on the CPU (#1843)
* add fp64 data type * clean build * update docs * fix bug
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/types/limits.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -28,7 +29,7 @@ void softmax(const array& in, array& out) {
|
||||
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
|
||||
// Find the maximum
|
||||
current_in_ptr = in_ptr;
|
||||
Simd<AccT, N> vmaximum(-std::numeric_limits<float>::infinity());
|
||||
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
|
||||
size_t s = M;
|
||||
while (s >= N) {
|
||||
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
|
||||
@@ -163,6 +164,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
softmax<bfloat16_t, bfloat16_t>(in, out);
|
||||
}
|
||||
break;
|
||||
case float64:
|
||||
softmax<double, double>(in, out);
|
||||
break;
|
||||
case complex64:
|
||||
throw std::invalid_argument(
|
||||
"[Softmax] Not yet implemented for complex64");
|
||||
|
||||
Reference in New Issue
Block a user