mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fp64 on the CPU (#1843)
* add fp64 data type * clean build * update docs * fix bug
This commit is contained in:
@@ -148,6 +148,9 @@ void dispatch_gather(
|
||||
case float32:
|
||||
gather<float, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case float64:
|
||||
gather<double, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case bfloat16:
|
||||
gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
@@ -288,6 +291,9 @@ void dispatch_gather_axis(
|
||||
case float32:
|
||||
gather_axis<float, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case float64:
|
||||
gather_axis<double, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
case bfloat16:
|
||||
gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
|
||||
break;
|
||||
@@ -499,6 +505,9 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
case float32:
|
||||
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case float64:
|
||||
dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case bfloat16:
|
||||
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
@@ -661,6 +670,9 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
case float32:
|
||||
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case float64:
|
||||
dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
case bfloat16:
|
||||
dispatch_scatter_axis<bfloat16_t>(out, idx, updates, axis_, reduce_type_);
|
||||
break;
|
||||
|
||||
Reference in New Issue
Block a user