diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index 8de4f92f9..5ffe0e10d 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -93,7 +93,6 @@ NO_GPU(Scan) NO_GPU(Scatter) NO_GPU(ScatterAxis) NO_GPU(Select) -NO_GPU(SliceUpdate) NO_GPU_MULTI(SVD) NO_GPU(Inverse) NO_GPU(Cholesky) diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp index 938923977..1adb85918 100644 --- a/mlx/backend/gpu/primitives.cpp +++ b/mlx/backend/gpu/primitives.cpp @@ -1,6 +1,7 @@ // Copyright © 2025 Apple Inc. #include "mlx/primitives.h" +#include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/slicing.h" @@ -170,6 +171,41 @@ void Slice::eval_gpu(const std::vector& inputs, array& out) { slice_gpu(in, out, start_indices_, strides_, stream()); } +void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + auto& upd = inputs[1]; + + if (upd.size() == 0) { + out.copy_shared_buffer(in); + return; + } + + auto ctype = in.flags().contiguous && in.size() == in.data_size() + ? CopyType::Vector + : CopyType::General; + copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); + auto [data_offset, out_strides] = + prepare_slice(out, start_indices_, strides_); + + // Do copy + copy_gpu_inplace( + /* const array& src = */ upd, + /* array& dst = */ out, + /* const Shape& data_shape = */ upd.shape(), + /* const Strides& i_strides = */ upd.strides(), + /* const Strides& o_strides = */ out_strides, + /* int64_t i_offset = */ 0, + /* int64_t o_offset = */ data_offset, + /* CopyType ctype = */ CopyType::GeneralGeneral, + /* const Stream& s = */ stream()); +} + void Squeeze::eval_gpu(const std::vector& inputs, array& out) { MLX_PROFILER_RANGE("Squeeze::eval_gpu"); eval(inputs, out); diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 705c3ea76..2ac543ad8 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -322,41 +322,6 @@ void DynamicSliceUpdate::eval_gpu( /* const std::optional& dynamic_o_offset = */ out_offset); } -void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 2); - if (out.size() == 0) { - out.set_data(nullptr); - return; - } - - auto& in = inputs[0]; - auto& upd = inputs[1]; - - if (upd.size() == 0) { - out.copy_shared_buffer(in); - return; - } - - auto ctype = in.flags().contiguous && in.size() == in.data_size() - ? CopyType::Vector - : CopyType::General; - copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); - auto [data_offset, out_strides] = - prepare_slice(out, start_indices_, strides_); - - // Do copy - copy_gpu_inplace( - /* const array& src = */ upd, - /* array& dst = */ out, - /* const Shape& data_shape = */ upd.shape(), - /* const Strides& i_strides = */ upd.strides(), - /* const Strides& o_strides = */ out_strides, - /* int64_t i_offset = */ 0, - /* int64_t o_offset = */ data_offset, - /* CopyType ctype = */ CopyType::GeneralGeneral, - /* const Stream& s = */ stream()); -} - void QRF::eval_gpu( const std::vector& inputs, std::vector& outputs) { diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index e80a1759f..6a440c319 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -225,6 +225,8 @@ struct MPIWrapper { return mpi_bfloat16_; case float64: return mpi_double_; + default: + throw std::runtime_error("Invalid type"); } }