Make sliceUpdate general

This commit is contained in:
Awni Hannun 2025-06-12 14:20:33 -07:00
parent c2dd81a8aa
commit 9825c33b90
4 changed files with 37 additions and 36 deletions

View File

@ -93,7 +93,6 @@ NO_GPU(Scan)
NO_GPU(Scatter) NO_GPU(Scatter)
NO_GPU(ScatterAxis) NO_GPU(ScatterAxis)
NO_GPU(Select) NO_GPU(Select)
NO_GPU(SliceUpdate)
NO_GPU_MULTI(SVD) NO_GPU_MULTI(SVD)
NO_GPU(Inverse) NO_GPU(Inverse)
NO_GPU(Cholesky) NO_GPU(Cholesky)

View File

@ -170,6 +170,41 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
slice_gpu(in, out, start_indices_, strides_, stream()); slice_gpu(in, out, start_indices_, strides_, stream());
} }
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
if (upd.size() == 0) {
out.copy_shared_buffer(in);
return;
}
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
auto [data_offset, out_strides] =
prepare_slice(out, start_indices_, strides_);
// Do copy
copy_gpu_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const Shape& data_shape = */ upd.shape(),
/* const Strides& i_strides = */ upd.strides(),
/* const Strides& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral,
/* const Stream& s = */ stream());
}
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) { void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Squeeze::eval_gpu"); MLX_PROFILER_RANGE("Squeeze::eval_gpu");
eval(inputs, out); eval(inputs, out);

View File

@ -322,41 +322,6 @@ void DynamicSliceUpdate::eval_gpu(
/* const std::optional<array>& dynamic_o_offset = */ out_offset); /* const std::optional<array>& dynamic_o_offset = */ out_offset);
} }
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
if (upd.size() == 0) {
out.copy_shared_buffer(in);
return;
}
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
auto [data_offset, out_strides] =
prepare_slice(out, start_indices_, strides_);
// Do copy
copy_gpu_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const Shape& data_shape = */ upd.shape(),
/* const Strides& i_strides = */ upd.strides(),
/* const Strides& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral,
/* const Stream& s = */ stream());
}
void QRF::eval_gpu( void QRF::eval_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {

View File

@ -225,6 +225,8 @@ struct MPIWrapper {
return mpi_bfloat16_; return mpi_bfloat16_;
case float64: case float64:
return mpi_double_; return mpi_double_;
default:
throw std::runtime_error("Invalid type");
} }
} }