mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fp64 on the CPU (#1843)
* add fp64 data type * clean build * update docs * fix bug
This commit is contained in:
@@ -312,6 +312,8 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return argsort<int64_t>(in, out, axis_);
|
||||
case float32:
|
||||
return argsort<float>(in, out, axis_);
|
||||
case float64:
|
||||
return argsort<double>(in, out, axis_);
|
||||
case float16:
|
||||
return argsort<float16_t>(in, out, axis_);
|
||||
case bfloat16:
|
||||
@@ -346,6 +348,8 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return sort<int64_t>(in, out, axis_);
|
||||
case float32:
|
||||
return sort<float>(in, out, axis_);
|
||||
case float64:
|
||||
return sort<double>(in, out, axis_);
|
||||
case float16:
|
||||
return sort<float16_t>(in, out, axis_);
|
||||
case bfloat16:
|
||||
@@ -380,6 +384,8 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return argpartition<int64_t>(in, out, axis_, kth_);
|
||||
case float32:
|
||||
return argpartition<float>(in, out, axis_, kth_);
|
||||
case float64:
|
||||
return argpartition<double>(in, out, axis_, kth_);
|
||||
case float16:
|
||||
return argpartition<float16_t>(in, out, axis_, kth_);
|
||||
case bfloat16:
|
||||
@@ -414,6 +420,8 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return partition<int64_t>(in, out, axis_, kth_);
|
||||
case float32:
|
||||
return partition<float>(in, out, axis_, kth_);
|
||||
case float64:
|
||||
return partition<double>(in, out, axis_, kth_);
|
||||
case float16:
|
||||
return partition<float16_t>(in, out, axis_, kth_);
|
||||
case bfloat16:
|
||||
|
||||
Reference in New Issue
Block a user