mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
parent
c2dd81a8aa
commit
f5f65ef48c
@ -93,7 +93,6 @@ NO_GPU(Scan)
|
|||||||
NO_GPU(Scatter)
|
NO_GPU(Scatter)
|
||||||
NO_GPU(ScatterAxis)
|
NO_GPU(ScatterAxis)
|
||||||
NO_GPU(Select)
|
NO_GPU(Select)
|
||||||
NO_GPU(SliceUpdate)
|
|
||||||
NO_GPU_MULTI(SVD)
|
NO_GPU_MULTI(SVD)
|
||||||
NO_GPU(Inverse)
|
NO_GPU(Inverse)
|
||||||
NO_GPU(Cholesky)
|
NO_GPU(Cholesky)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/backend/common/slicing.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/backend/gpu/slicing.h"
|
#include "mlx/backend/gpu/slicing.h"
|
||||||
@ -170,6 +171,41 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
slice_gpu(in, out, start_indices_, strides_, stream());
|
slice_gpu(in, out, start_indices_, strides_, stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
if (out.size() == 0) {
|
||||||
|
out.set_data(nullptr);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& in = inputs[0];
|
||||||
|
auto& upd = inputs[1];
|
||||||
|
|
||||||
|
if (upd.size() == 0) {
|
||||||
|
out.copy_shared_buffer(in);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||||
|
? CopyType::Vector
|
||||||
|
: CopyType::General;
|
||||||
|
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||||
|
auto [data_offset, out_strides] =
|
||||||
|
prepare_slice(out, start_indices_, strides_);
|
||||||
|
|
||||||
|
// Do copy
|
||||||
|
copy_gpu_inplace(
|
||||||
|
/* const array& src = */ upd,
|
||||||
|
/* array& dst = */ out,
|
||||||
|
/* const Shape& data_shape = */ upd.shape(),
|
||||||
|
/* const Strides& i_strides = */ upd.strides(),
|
||||||
|
/* const Strides& o_strides = */ out_strides,
|
||||||
|
/* int64_t i_offset = */ 0,
|
||||||
|
/* int64_t o_offset = */ data_offset,
|
||||||
|
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
||||||
|
/* const Stream& s = */ stream());
|
||||||
|
}
|
||||||
|
|
||||||
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
MLX_PROFILER_RANGE("Squeeze::eval_gpu");
|
MLX_PROFILER_RANGE("Squeeze::eval_gpu");
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
|
@ -322,41 +322,6 @@ void DynamicSliceUpdate::eval_gpu(
|
|||||||
/* const std::optional<array>& dynamic_o_offset = */ out_offset);
|
/* const std::optional<array>& dynamic_o_offset = */ out_offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|
||||||
assert(inputs.size() == 2);
|
|
||||||
if (out.size() == 0) {
|
|
||||||
out.set_data(nullptr);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto& in = inputs[0];
|
|
||||||
auto& upd = inputs[1];
|
|
||||||
|
|
||||||
if (upd.size() == 0) {
|
|
||||||
out.copy_shared_buffer(in);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
|
||||||
? CopyType::Vector
|
|
||||||
: CopyType::General;
|
|
||||||
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
|
||||||
auto [data_offset, out_strides] =
|
|
||||||
prepare_slice(out, start_indices_, strides_);
|
|
||||||
|
|
||||||
// Do copy
|
|
||||||
copy_gpu_inplace(
|
|
||||||
/* const array& src = */ upd,
|
|
||||||
/* array& dst = */ out,
|
|
||||||
/* const Shape& data_shape = */ upd.shape(),
|
|
||||||
/* const Strides& i_strides = */ upd.strides(),
|
|
||||||
/* const Strides& o_strides = */ out_strides,
|
|
||||||
/* int64_t i_offset = */ 0,
|
|
||||||
/* int64_t o_offset = */ data_offset,
|
|
||||||
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
|
||||||
/* const Stream& s = */ stream());
|
|
||||||
}
|
|
||||||
|
|
||||||
void QRF::eval_gpu(
|
void QRF::eval_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
|
@ -225,6 +225,8 @@ struct MPIWrapper {
|
|||||||
return mpi_bfloat16_;
|
return mpi_bfloat16_;
|
||||||
case float64:
|
case float64:
|
||||||
return mpi_double_;
|
return mpi_double_;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Invalid type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user