Fp64 on the CPU (#1843)

* add fp64 data type

* clean build

* update docs

* fix bug
Author: Awni Hannun
Date: 2025-02-07 15:52:22 -08:00
Committed by: GitHub
Parent: 1a1b2108ec
Commit: 1c0c118f7c
32 changed files with 438 additions and 65 deletions
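
Most of this diff follows a single pattern: each dtype dispatch in the CPU backend gains a case float64: arm that routes to the existing templates instantiated with native double. Below is a minimal sketch of that pattern; the names Dtype, process, and dispatch_by_dtype are hypothetical stand-ins, not MLX declarations.

#include <cstddef>
#include <stdexcept>

// Hypothetical names used only to illustrate the dispatch pattern this commit
// extends; these are not MLX declarations.
enum Dtype { float16, float32, float64, bfloat16 };

template <typename T>
void process(const void* in, void* out, std::size_t n) {
  const T* src = static_cast<const T*>(in);
  T* dst = static_cast<T*>(out);
  for (std::size_t i = 0; i < n; ++i) {
    dst[i] = src[i] + src[i]; // stand-in element-wise op
  }
}

void dispatch_by_dtype(Dtype dt, const void* in, void* out, std::size_t n) {
  switch (dt) {
    case float32:
      process<float>(in, out, n);
      break;
    case float64: // the arm this commit adds throughout the CPU backend
      process<double>(in, out, n);
      break;
    default:
      throw std::runtime_error("unsupported dtype");
  }
}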


@@ -151,6 +151,9 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
case bfloat16:
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
break;
case float64:
*out.data<double>() = static_cast<double>(numel);
break;
case complex64:
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
break;


@@ -62,6 +62,9 @@ void arange(
case float32:
arange<float>(start, start + step, out, out.size());
break;
case float64:
arange<double>(start, start + step, out, out.size());
break;
case bfloat16:
arange<bfloat16_t>(start, start + step, out, out.size());
break;
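
The CPU arange dispatcher gains a float64 case. Note that the template receives start and start + step and derives the step itself; the sketch below is a simplified version assuming a plain (start, step, n) signature, not the MLX function.

#include <cstddef>

// Simplified sketch of the fill loop behind arange<double>; assumes a plain
// (start, step, n) signature rather than the MLX template's arguments.
void arange_double(double start, double step, double* out, std::size_t n) {
  double v = start;
  for (std::size_t i = 0; i < n; ++i, v += step) {
    out[i] = v;
  }
}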


@@ -103,6 +103,9 @@ void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case bfloat16:
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
break;
case float64:
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
break;
case complex64:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
break;


@@ -51,6 +51,9 @@ void comparison_op(const array& a, const array& b, array& out, Op op) {
case float32:
binary_op<float, bool>(a, b, out, op);
break;
case float64:
binary_op<double, bool>(a, b, out, op);
break;
case bfloat16:
binary_op<bfloat16_t, bool>(a, b, out, op);
break;
@@ -114,6 +117,9 @@ void DivMod::eval_cpu(
case float32:
binary_op<float>(a, b, outputs, float_op);
break;
case float64:
binary_op<double>(a, b, outputs, float_op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, float_op);
break;
@@ -150,6 +156,9 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
case float32:
binary_op<float, bool>(a, b, out, detail::NaNEqual());
break;
case float64:
binary_op<double, bool>(a, b, out, detail::NaNEqual());
break;
case bfloat16:
binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
break;
@@ -189,20 +198,22 @@ void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
} else if (issubdtype(out.dtype(), inexact)) {
std::ostringstream err;
err << "[logaddexp] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
} else {
throw std::invalid_argument(
"[logaddexp] Cannot compute logaddexp for arrays with"
" non floating point type.");
switch (out.dtype()) {
case float16:
binary_op<float16_t>(a, b, out, detail::LogAddExp());
break;
case float32:
binary_op<float>(a, b, out, detail::LogAddExp());
break;
case float64:
binary_op<double>(a, b, out, detail::LogAddExp());
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
break;
default:
throw std::runtime_error(
"[LogAddExp::eval_cpu] Only supports non-complex floating point types.");
}
}
@@ -321,20 +332,22 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
const auto& a = inputs[0];
const auto& b = inputs[1];
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::ArcTan2());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::ArcTan2());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
} else if (issubdtype(out.dtype(), inexact)) {
std::ostringstream err;
err << "[arctan2] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
} else {
throw std::invalid_argument(
"[arctan2] Cannot compute inverse tangent for arrays"
" with non floating point type.");
switch (out.dtype()) {
case float16:
binary_op<float16_t>(a, b, out, detail::ArcTan2());
break;
case float32:
binary_op<float>(a, b, out, detail::ArcTan2());
break;
case float64:
binary_op<double>(a, b, out, detail::ArcTan2());
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
break;
default:
throw std::runtime_error(
"[ArcTan2::eval_cpu] Only supports non-complex floating point types.");
}
}
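
Both LogAddExp::eval_cpu and ArcTan2::eval_cpu move from if/else chains on out.dtype() to the same switch-based dispatch used elsewhere, picking up a float64 arm along the way. For reference, a minimal sketch of a numerically stable logaddexp evaluated in double precision; this is an illustration, not MLX's detail::LogAddExp functor.

#include <algorithm>
#include <cmath>

// Minimal sketch, not MLX's detail::LogAddExp: stable log(exp(a) + exp(b)).
double logaddexp(double a, double b) {
  double hi = std::max(a, b);
  double lo = std::min(a, b);
  if (std::isinf(hi)) {
    return hi; // both -inf, or at least one +inf
  }
  return hi + std::log1p(std::exp(lo - hi));
}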


@@ -358,6 +358,9 @@ void binary(const array& a, const array& b, array& out, Op op) {
case float32:
binary_op<float>(a, b, out, op);
break;
case float64:
binary_op<double>(a, b, out, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, op);
break;


@@ -205,6 +205,9 @@ void binary(
case float32:
binary_op<float>(a, b, outputs, op);
break;
case float64:
binary_op<double>(a, b, outputs, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, op);
break;


@@ -193,6 +193,9 @@ void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
case float32:
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float64:
copy<SrcT, double>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16:
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
@@ -242,6 +245,9 @@ inline void copy_inplace_dispatch(
case float32:
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float64:
copy<double>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
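
The copy<SrcT, DstT> dispatch now covers double both as a source and as a destination type. A simplified, contiguous-only sketch of the cast-and-copy it performs; the MLX kernel also handles strided layouts.

#include <cstddef>

// Simplified sketch of cast-and-copy over a contiguous buffer; not the MLX
// copy kernel.
template <typename SrcT, typename DstT>
void copy_contiguous(const SrcT* src, DstT* dst, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    dst[i] = static_cast<DstT>(src[i]);
  }
}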


@@ -41,4 +41,39 @@ void matmul<float>(
}
}
template <>
void matmul<double>(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_dgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
alpha, // alpha
a.data<double>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
lda,
b.data<double>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
ldb,
beta, // beta
out.data<double>() + M * N * i,
out.shape(-1) // ldc
);
}
}
} // namespace mlx::core
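
The new matmul<double> specialization mirrors the existing float one but calls cblas_dgemm, the double-precision BLAS GEMM, once per batch. Below is a standalone sketch of a single call; it assumes a CBLAS implementation (e.g. Accelerate or OpenBLAS) is linked, and the 2x2 data is purely illustrative.

#include <cblas.h>
#include <cstdio>

// Standalone sketch: C = alpha * A * B + beta * C for 2x2 row-major doubles.
int main() {
  double A[4] = {1.0, 2.0, 3.0, 4.0};
  double B[4] = {5.0, 6.0, 7.0, 8.0};
  double C[4] = {0.0, 0.0, 0.0, 0.0};
  cblas_dgemm(
      CblasRowMajor,
      CblasNoTrans, // transA
      CblasNoTrans, // transB
      2, // M
      2, // N
      2, // K
      1.0, // alpha
      A,
      2, // lda
      B,
      2, // ldb
      0.0, // beta
      C,
      2); // ldc
  std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);
  return 0;
}

For this data the result is [[19, 22], [43, 50]].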


@@ -148,6 +148,9 @@ void dispatch_gather(
case float32:
gather<float, IdxT>(src, inds, out, axes, size);
break;
case float64:
gather<double, IdxT>(src, inds, out, axes, size);
break;
case bfloat16:
gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
break;
@@ -288,6 +291,9 @@ void dispatch_gather_axis(
case float32:
gather_axis<float, IdxT>(src, inds, out, axis);
break;
case float64:
gather_axis<double, IdxT>(src, inds, out, axis);
break;
case bfloat16:
gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
break;
@@ -499,6 +505,9 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
case float32:
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
break;
case float64:
dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
break;
case bfloat16:
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
break;
@@ -661,6 +670,9 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
case float32:
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
break;
case float64:
dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
break;
case bfloat16:
dispatch_scatter_axis<bfloat16_t>(out, idx, updates, axis_, reduce_type_);
break;
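
Gather, scatter, and their axis variants each gain a double instantiation. As a 1-D illustration only (not the MLX kernel, which handles multiple axes and strided layouts), gathering doubles by integer index looks like this:

#include <cstdint>
#include <vector>

// 1-D illustration of gather over doubles; indices are assumed in range.
std::vector<double> gather_1d(const std::vector<double>& src,
                              const std::vector<int32_t>& inds) {
  std::vector<double> out;
  out.reserve(inds.size());
  for (int32_t i : inds) {
    out.push_back(src[i]);
  }
  return out;
}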


@@ -46,6 +46,9 @@ void matmul_general(
} else if (out.dtype() == bfloat16) {
matmul<bfloat16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
} else if (out.dtype() == float64) {
matmul<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
} else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
}


@@ -42,6 +42,7 @@ instantiate_default_limit(int64_t);
instantiate_float_limit(float16_t);
instantiate_float_limit(bfloat16_t);
instantiate_float_limit(float);
instantiate_float_limit(double);
instantiate_float_limit(complex64_t);
template <>
@@ -59,6 +60,8 @@ const bfloat16_t Limits<bfloat16_t>::min =
const float16_t Limits<float16_t>::max = std::numeric_limits<float>::infinity();
const float16_t Limits<float16_t>::min =
-std::numeric_limits<float>::infinity();
const double Limits<double>::max = std::numeric_limits<double>::infinity();
const double Limits<double>::min = -std::numeric_limits<double>::infinity();
const complex64_t Limits<complex64_t>::max =
std::numeric_limits<float>::infinity();
const complex64_t Limits<complex64_t>::min =
@@ -460,6 +463,7 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
break;
case uint64:
case int64:
case float64:
case complex64:
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
break;
@@ -495,6 +499,9 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case float32:
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
break;
@@ -537,6 +544,9 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case float32:
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
break;
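
Limits<double> defines +/-infinity as the identity values that min/max reductions start from; the sum/prod and min/max dispatchers gain double arms, while the and/or reduction reuses the existing 8-byte int64_t path. A minimal sketch of how the -infinity identity is used by a max reduction (not the MLX reduction code):

#include <algorithm>
#include <limits>
#include <vector>

// Minimal sketch, not the MLX implementation: a max reduction over doubles
// seeded with -infinity, the identity that Limits<double>::min provides.
double reduce_max(const std::vector<double>& xs) {
  double acc = -std::numeric_limits<double>::infinity();
  for (double x : xs) {
    acc = std::max(acc, x);
  }
  return acc;
}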


@@ -299,6 +299,10 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
scan_dispatch<float, float>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float64:
scan_dispatch<double, double>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case bfloat16:
scan_dispatch<bfloat16_t, bfloat16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
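
The scan dispatcher gains a <double, double> instantiation, so both the input and the accumulator are double. One of the scan variants, an inclusive cumulative sum, sketched in isolation (not the MLX scan kernel):

#include <cstddef>

// Inclusive cumulative sum over a contiguous double buffer; illustration only.
void cumsum_double(const double* in, double* out, std::size_t n) {
  double acc = 0.0;
  for (std::size_t i = 0; i < n; ++i) {
    acc += in[i];
    out[i] = acc;
  }
}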


@@ -51,6 +51,9 @@ void select_op(
case float32:
ternary_op<bool, float, float, float>(a, b, c, out, op);
break;
case float64:
ternary_op<bool, double, double, double>(a, b, c, out, op);
break;
case bfloat16:
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
break;


@@ -6,6 +6,7 @@
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"
namespace mlx::core {
@@ -28,7 +29,7 @@ void softmax(const array& in, array& out) {
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-std::numeric_limits<float>::infinity());
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
@@ -163,6 +164,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
softmax<bfloat16_t, bfloat16_t>(in, out);
}
break;
case float64:
softmax<double, double>(in, out);
break;
case complex64:
throw std::invalid_argument(
"[Softmax] Not yet implemented for complex64");


@@ -312,6 +312,8 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
return argsort<int64_t>(in, out, axis_);
case float32:
return argsort<float>(in, out, axis_);
case float64:
return argsort<double>(in, out, axis_);
case float16:
return argsort<float16_t>(in, out, axis_);
case bfloat16:
@@ -346,6 +348,8 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
return sort<int64_t>(in, out, axis_);
case float32:
return sort<float>(in, out, axis_);
case float64:
return sort<double>(in, out, axis_);
case float16:
return sort<float16_t>(in, out, axis_);
case bfloat16:
@@ -380,6 +384,8 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
return argpartition<int64_t>(in, out, axis_, kth_);
case float32:
return argpartition<float>(in, out, axis_, kth_);
case float64:
return argpartition<double>(in, out, axis_, kth_);
case float16:
return argpartition<float16_t>(in, out, axis_, kth_);
case bfloat16:
@@ -414,6 +420,8 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
return partition<int64_t>(in, out, axis_, kth_);
case float32:
return partition<float>(in, out, axis_, kth_);
case float64:
return partition<double>(in, out, axis_, kth_);
case float16:
return partition<float16_t>(in, out, axis_, kth_);
case bfloat16:
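
ArgSort, Sort, ArgPartition, and Partition each gain a float64 case. A 1-D sketch of the argsort idea, returning the permutation that orders a vector of doubles (an illustration, not the MLX kernel, which sorts along an arbitrary axis):

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// 1-D argsort over doubles: sort indices by the values they point at.
std::vector<uint32_t> argsort_1d(const std::vector<double>& v) {
  std::vector<uint32_t> idx(v.size());
  std::iota(idx.begin(), idx.end(), 0u);
  std::stable_sort(idx.begin(), idx.end(), [&](uint32_t a, uint32_t b) {
    return v[a] < v[b];
  });
  return idx;
}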


@@ -34,6 +34,9 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
case float32:
unary_op<float>(in, out, op);
break;
case float64:
unary_op<double>(in, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, op);
break;
@@ -120,6 +123,9 @@ void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
case float16:
unary_op<float16_t>(in, out, detail::Erf());
break;
case float64:
unary_op<double>(in, out, detail::Erf());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::Erf());
break;
@@ -140,6 +146,9 @@ void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
case float16:
unary_op<float16_t>(in, out, detail::ErfInv());
break;
case float64:
unary_op<double>(in, out, detail::ErfInv());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::ErfInv());
break;


@@ -104,6 +104,9 @@ void unary(const array& a, array& out, Op op) {
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
@@ -125,6 +128,9 @@ void unary_fp(const array& a, array& out, Op op) {
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;


@@ -151,8 +151,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
case bfloat16:
arange_set_scalars<bfloat16_t>(start_, start_ + step_, compute_encoder);
break;
case complex64:
throw std::runtime_error("[Arange::eval_gpu] Does not support complex64");
default:
throw std::runtime_error("[Arange::eval_gpu] Does not support type.");
}
compute_encoder.set_output_array(out, 2);


@@ -42,6 +42,9 @@ std::string type_to_name(const Dtype& t) {
case float32:
tname = "float32";
break;
case float64:
tname = "double";
break;
case bfloat16:
tname = "bfloat16";
break;