mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 18:26:41 +08:00
[CUDA] Implement DynamicSlice/DynamicSliceUpdate (#2533)
* Move DynamicSlice to gpu/primitives * Implement compute_dynamic_offset in CUDA
This commit is contained in:
parent
2ca75bb529
commit
4822c3dbe9
@ -15,8 +15,8 @@ void copy_gpu_inplace(
|
|||||||
int64_t offset_out,
|
int64_t offset_out,
|
||||||
CopyType ctype,
|
CopyType ctype,
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
const std::optional<array>& dynamic_offset_in,
|
std::optional<array> dynamic_offset_in,
|
||||||
const std::optional<array>& dynamic_offset_out) {
|
std::optional<array> dynamic_offset_out) {
|
||||||
if (out.size() == 0) {
|
if (out.size() == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -44,6 +44,16 @@ void copy_gpu_inplace(
|
|||||||
strides_vec[0]);
|
strides_vec[0]);
|
||||||
} else {
|
} else {
|
||||||
if (dynamic_offset_in || dynamic_offset_out) {
|
if (dynamic_offset_in || dynamic_offset_out) {
|
||||||
|
if (!dynamic_offset_in) {
|
||||||
|
dynamic_offset_in = array(0, int64);
|
||||||
|
encoder.add_temporary(*dynamic_offset_in);
|
||||||
|
}
|
||||||
|
if (!dynamic_offset_out) {
|
||||||
|
dynamic_offset_out = array(0, int64);
|
||||||
|
encoder.add_temporary(*dynamic_offset_out);
|
||||||
|
}
|
||||||
|
encoder.set_input_array(*dynamic_offset_in);
|
||||||
|
encoder.set_input_array(*dynamic_offset_out);
|
||||||
copy_general_dynamic(
|
copy_general_dynamic(
|
||||||
encoder,
|
encoder,
|
||||||
ctype,
|
ctype,
|
||||||
@ -54,8 +64,8 @@ void copy_gpu_inplace(
|
|||||||
shape_collapsed,
|
shape_collapsed,
|
||||||
strides_vec[0],
|
strides_vec[0],
|
||||||
strides_vec[1],
|
strides_vec[1],
|
||||||
dynamic_offset_in ? *dynamic_offset_in : array(0, int64),
|
*dynamic_offset_in,
|
||||||
dynamic_offset_out ? *dynamic_offset_out : array(0, int64));
|
*dynamic_offset_out);
|
||||||
} else {
|
} else {
|
||||||
copy_general(
|
copy_general(
|
||||||
encoder,
|
encoder,
|
||||||
|
@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
args.append<int32_t>(src.ndim());
|
args.append<int32_t>(src.ndim());
|
||||||
args.append_ndim(slice_sizes_);
|
args.append_ndim(slice_sizes_);
|
||||||
args.append(slice_size);
|
args.append(slice_size);
|
||||||
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
|
args.append(axes_);
|
||||||
append_indices_arg(args, inputs, nidx, idx_ndim);
|
append_indices_arg(args, inputs, nidx, idx_ndim);
|
||||||
|
|
||||||
std::string kernel_name = fmt::format(
|
std::string kernel_name = fmt::format(
|
||||||
@ -211,7 +211,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
args.append_ndim(out.shape());
|
args.append_ndim(out.shape());
|
||||||
args.append_ndim(out.strides());
|
args.append_ndim(out.strides());
|
||||||
args.append<int32_t>(out.ndim());
|
args.append<int32_t>(out.ndim());
|
||||||
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
|
args.append(axes_);
|
||||||
append_indices_arg(args, inputs, nidx, idx_ndim);
|
append_indices_arg(args, inputs, nidx, idx_ndim);
|
||||||
|
|
||||||
std::string kernel_name = fmt::format(
|
std::string kernel_name = fmt::format(
|
||||||
|
@ -46,6 +46,11 @@ struct KernelArgs {
|
|||||||
append_ptr(std::get<SmallVector<T>>(storage_.back()).data());
|
append_ptr(std::get<SmallVector<T>>(storage_.back()).data());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void append(const std::vector<T>& vec) {
|
||||||
|
append(SmallVector<T>(vec.begin(), vec.end()));
|
||||||
|
}
|
||||||
|
|
||||||
// Make sure the arg is copied to an array with size of NDIM.
|
// Make sure the arg is copied to an array with size of NDIM.
|
||||||
template <size_t NDIM = MAX_NDIM, typename T>
|
template <size_t NDIM = MAX_NDIM, typename T>
|
||||||
void append_ndim(SmallVector<T> vec) {
|
void append_ndim(SmallVector<T> vec) {
|
||||||
|
@ -24,8 +24,6 @@ namespace mlx::core {
|
|||||||
}
|
}
|
||||||
|
|
||||||
NO_GPU(BlockMaskedMM)
|
NO_GPU(BlockMaskedMM)
|
||||||
NO_GPU(DynamicSlice)
|
|
||||||
NO_GPU(DynamicSliceUpdate)
|
|
||||||
NO_GPU(FFT)
|
NO_GPU(FFT)
|
||||||
NO_GPU(GatherMM)
|
NO_GPU(GatherMM)
|
||||||
NO_GPU(GatherQMM)
|
NO_GPU(GatherQMM)
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/common/slicing.h"
|
#include "mlx/backend/common/slicing.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/jit_module.h"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/backend/gpu/slicing.h"
|
#include "mlx/backend/gpu/slicing.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
@ -38,4 +41,71 @@ void concatenate_gpu(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array compute_dynamic_offset(
|
||||||
|
const array& indices,
|
||||||
|
const Strides& strides,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const Stream& s) {
|
||||||
|
Dtype dtype = indices.dtype();
|
||||||
|
int nidx = axes.size();
|
||||||
|
|
||||||
|
std::string module_name =
|
||||||
|
fmt::format("compute_dynamic_offset_{}_{}", dtype_to_string(dtype), nidx);
|
||||||
|
std::string kernel_name = fmt::format(
|
||||||
|
"mlx::core::cu::compute_dynamic_offset<{}, {}>",
|
||||||
|
dtype_to_cuda_type(dtype),
|
||||||
|
nidx);
|
||||||
|
|
||||||
|
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
|
||||||
|
std::string source = R"(
|
||||||
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
template <typename T, int NIDX>
|
||||||
|
__global__ void compute_dynamic_offset(
|
||||||
|
const T* indices,
|
||||||
|
int64_t* offset,
|
||||||
|
const __grid_constant__ Strides strides,
|
||||||
|
const __grid_constant__ cuda::std::array<int, NIDX> axes) {
|
||||||
|
int64_t acc = 0;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NIDX; ++i) {
|
||||||
|
acc += indices[i] * strides[axes[i]];
|
||||||
|
}
|
||||||
|
*offset = acc;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
||||||
|
)";
|
||||||
|
return std::make_tuple(false, std::move(source), std::vector{kernel_name});
|
||||||
|
});
|
||||||
|
|
||||||
|
// Prepare output.
|
||||||
|
array offset({1}, int64, nullptr, {});
|
||||||
|
bool donate = indices.is_donatable() &&
|
||||||
|
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
|
||||||
|
if (donate) {
|
||||||
|
offset.copy_shared_buffer(indices);
|
||||||
|
} else {
|
||||||
|
offset.set_data(allocator::malloc(offset.itemsize()));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.add_temporary(offset);
|
||||||
|
encoder.set_input_array(indices);
|
||||||
|
encoder.set_output_array(offset);
|
||||||
|
|
||||||
|
cu::KernelArgs args;
|
||||||
|
args.append(indices);
|
||||||
|
args.append(offset);
|
||||||
|
args.append_ndim(strides);
|
||||||
|
args.append(axes);
|
||||||
|
|
||||||
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
|
encoder.add_kernel_node(kernel, 1, 1, 0, args.args());
|
||||||
|
|
||||||
|
return offset;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -20,8 +20,8 @@ void copy_gpu_inplace(
|
|||||||
int64_t o_offset,
|
int64_t o_offset,
|
||||||
CopyType ctype,
|
CopyType ctype,
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
const std::optional<array>& dynamic_i_offset = std::nullopt,
|
std::optional<array> dynamic_i_offset = std::nullopt,
|
||||||
const std::optional<array>& dynamic_o_offset = std::nullopt);
|
std::optional<array> dynamic_o_offset = std::nullopt);
|
||||||
|
|
||||||
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
|
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
|
||||||
void copy_gpu(const array& src, array& out, CopyType ctype);
|
void copy_gpu(const array& src, array& out, CopyType ctype);
|
||||||
|
@ -80,6 +80,74 @@ void Depends::eval_gpu(
|
|||||||
eval(inputs, outputs);
|
eval(inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
MLX_PROFILER_RANGE("DynamicSlice::eval_gpu");
|
||||||
|
if (out.size() == 0) {
|
||||||
|
out.set_data(nullptr);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& in = inputs[0];
|
||||||
|
auto& start = inputs[1];
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
|
auto s = stream();
|
||||||
|
auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
|
||||||
|
copy_gpu_inplace(
|
||||||
|
/* const array& src = */ in,
|
||||||
|
/* array& dst = */ out,
|
||||||
|
/* const Shape& data_shape = */ out.shape(),
|
||||||
|
/* const Strides& i_strides = */ in.strides(),
|
||||||
|
/* const Strides& o_strides = */ out.strides(),
|
||||||
|
/* int64_t i_offset = */ 0,
|
||||||
|
/* int64_t o_offset = */ 0,
|
||||||
|
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
||||||
|
/* const Stream& s = */ s,
|
||||||
|
/* std::optional<array> dynamic_i_offset = */ std::move(in_offset),
|
||||||
|
/* std::optional<array> dynamic_o_offset = */ std::nullopt);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DynamicSliceUpdate::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
array& out) {
|
||||||
|
MLX_PROFILER_RANGE("DynamicSliceUpdate::eval_gpu");
|
||||||
|
if (out.size() == 0) {
|
||||||
|
out.set_data(nullptr);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& in = inputs[0];
|
||||||
|
auto& upd = inputs[1];
|
||||||
|
auto& start_indices = inputs[2];
|
||||||
|
|
||||||
|
if (upd.size() == 0) {
|
||||||
|
out.copy_shared_buffer(in);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy or donate input to output
|
||||||
|
auto s = stream();
|
||||||
|
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||||
|
? CopyType::Vector
|
||||||
|
: CopyType::General;
|
||||||
|
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s);
|
||||||
|
|
||||||
|
auto out_offset =
|
||||||
|
compute_dynamic_offset(start_indices, out.strides(), axes_, s);
|
||||||
|
copy_gpu_inplace(
|
||||||
|
/* const array& src = */ upd,
|
||||||
|
/* array& dst = */ out,
|
||||||
|
/* const Shape& data_shape = */ upd.shape(),
|
||||||
|
/* const Strides& i_strides = */ upd.strides(),
|
||||||
|
/* const Strides& o_strides = */ out.strides(),
|
||||||
|
/* int64_t i_offset = */ 0,
|
||||||
|
/* int64_t o_offset = */ 0,
|
||||||
|
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
||||||
|
/* const Stream& s = */ s,
|
||||||
|
/* std::optional<array> dynamic_i_offset = */ std::nullopt,
|
||||||
|
/* std::optional<array> dynamic_o_offset = */ std::move(out_offset));
|
||||||
|
}
|
||||||
|
|
||||||
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
MLX_PROFILER_RANGE("ExpandDims::eval_gpu");
|
MLX_PROFILER_RANGE("ExpandDims::eval_gpu");
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
|
@ -27,4 +27,10 @@ void pad_gpu(
|
|||||||
const Shape& low_pad_size,
|
const Shape& low_pad_size,
|
||||||
const Stream& s);
|
const Stream& s);
|
||||||
|
|
||||||
|
array compute_dynamic_offset(
|
||||||
|
const array& indices,
|
||||||
|
const Strides& strides,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const Stream& s);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -20,8 +20,8 @@ void copy_gpu_inplace(
|
|||||||
int64_t out_offset,
|
int64_t out_offset,
|
||||||
CopyType ctype,
|
CopyType ctype,
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
|
std::optional<array> dynamic_i_offset /* = std::nullopt */,
|
||||||
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
|
std::optional<array> dynamic_o_offset /* = std::nullopt */) {
|
||||||
if (out.size() == 0) {
|
if (out.size() == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,6 @@
|
|||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "mlx/backend/common/compiled.h"
|
|
||||||
#include "mlx/backend/common/slicing.h"
|
#include "mlx/backend/common/slicing.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
@ -25,60 +24,6 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
|
|||||||
enc.set_bytes(step, 1);
|
enc.set_bytes(step, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
static array compute_dynamic_offset(
|
|
||||||
const array& indices,
|
|
||||||
const Strides& strides,
|
|
||||||
const std::vector<int>& axes,
|
|
||||||
Stream s) {
|
|
||||||
auto& d = metal::device(s.device);
|
|
||||||
|
|
||||||
// Kernel to compute offset here.
|
|
||||||
array offset({1}, int64, nullptr, {});
|
|
||||||
bool donate = indices.is_donatable() &&
|
|
||||||
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
|
|
||||||
if (donate) {
|
|
||||||
offset.copy_shared_buffer(indices);
|
|
||||||
} else {
|
|
||||||
offset.set_data(allocator::malloc(offset.itemsize()));
|
|
||||||
}
|
|
||||||
d.add_temporary(offset, s.index);
|
|
||||||
|
|
||||||
auto dtype = indices.dtype();
|
|
||||||
std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype);
|
|
||||||
auto lib = d.get_library(lib_name, [dtype]() {
|
|
||||||
return fmt::format(
|
|
||||||
R"(
|
|
||||||
[[kernel]] void compute_dynamic_offset_{0}(
|
|
||||||
constant const {1}* indices [[buffer(0)]],
|
|
||||||
device int64_t& offset [[buffer(1)]],
|
|
||||||
constant const int64_t* strides [[buffer(2)]],
|
|
||||||
constant const int* axes [[buffer(3)]],
|
|
||||||
constant const int& n_axes [[buffer(4)]],
|
|
||||||
uint index [[thread_position_in_grid]]) {{
|
|
||||||
int64_t acc = 0;
|
|
||||||
for (int i = 0; i < n_axes; ++i) {{
|
|
||||||
acc += indices[i] * strides[axes[i]];
|
|
||||||
}}
|
|
||||||
offset = acc;
|
|
||||||
}})",
|
|
||||||
type_to_name(dtype),
|
|
||||||
get_type_string(dtype));
|
|
||||||
});
|
|
||||||
auto kernel = d.get_kernel(lib_name, lib);
|
|
||||||
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
|
||||||
compute_encoder.set_input_array(indices, 0);
|
|
||||||
compute_encoder.set_output_array(offset, 1);
|
|
||||||
compute_encoder.set_vector_bytes(strides, 2);
|
|
||||||
compute_encoder.set_vector_bytes(axes, 3);
|
|
||||||
int n_axes = axes.size();
|
|
||||||
compute_encoder.set_bytes(n_axes, 4);
|
|
||||||
MTL::Size dims = MTL::Size(1, 1, 1);
|
|
||||||
compute_encoder.dispatch_threads(dims, dims);
|
|
||||||
return offset;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 0);
|
assert(inputs.size() == 0);
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
@ -256,72 +201,6 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|
||||||
if (out.size() == 0) {
|
|
||||||
out.set_data(nullptr);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto& in = inputs[0];
|
|
||||||
auto& start = inputs[1];
|
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
|
||||||
auto s = stream();
|
|
||||||
auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
|
|
||||||
copy_gpu_inplace(
|
|
||||||
/* const array& src = */ in,
|
|
||||||
/* array& dst = */ out,
|
|
||||||
/* const Shape& data_shape = */ out.shape(),
|
|
||||||
/* const Strides& i_strides = */ in.strides(),
|
|
||||||
/* const Strides& o_strides = */ out.strides(),
|
|
||||||
/* int64_t i_offset = */ 0,
|
|
||||||
/* int64_t o_offset = */ 0,
|
|
||||||
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
|
||||||
/* const Stream& s = */ s,
|
|
||||||
/* const std::optional<array>& dynamic_i_offset = */ in_offset,
|
|
||||||
/* const std::optional<array>& dynamic_o_offset = */ std::nullopt);
|
|
||||||
}
|
|
||||||
|
|
||||||
void DynamicSliceUpdate::eval_gpu(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
array& out) {
|
|
||||||
if (out.size() == 0) {
|
|
||||||
out.set_data(nullptr);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto& in = inputs[0];
|
|
||||||
auto& upd = inputs[1];
|
|
||||||
auto& start_indices = inputs[2];
|
|
||||||
|
|
||||||
if (upd.size() == 0) {
|
|
||||||
out.copy_shared_buffer(in);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy or donate input to output
|
|
||||||
auto s = stream();
|
|
||||||
auto& d = metal::device(s.device);
|
|
||||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
|
||||||
? CopyType::Vector
|
|
||||||
: CopyType::General;
|
|
||||||
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s);
|
|
||||||
|
|
||||||
auto out_offset =
|
|
||||||
compute_dynamic_offset(start_indices, out.strides(), axes_, s);
|
|
||||||
copy_gpu_inplace(
|
|
||||||
/* const array& src = */ upd,
|
|
||||||
/* array& dst = */ out,
|
|
||||||
/* const Shape& data_shape = */ upd.shape(),
|
|
||||||
/* const Strides& i_strides = */ upd.strides(),
|
|
||||||
/* const Strides& o_strides = */ out.strides(),
|
|
||||||
/* int64_t i_offset = */ 0,
|
|
||||||
/* int64_t o_offset = */ 0,
|
|
||||||
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
|
||||||
/* const Stream& s = */ s,
|
|
||||||
/* const std::optional<array>& dynamic_i_offset = */ std::nullopt,
|
|
||||||
/* const std::optional<array>& dynamic_o_offset = */ out_offset);
|
|
||||||
}
|
|
||||||
|
|
||||||
void QRF::eval_gpu(
|
void QRF::eval_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
|
@ -2,9 +2,12 @@
|
|||||||
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/compiled.h"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/backend/gpu/slicing.h"
|
#include "mlx/backend/gpu/slicing.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/kernels.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@ -39,4 +42,58 @@ void concatenate_gpu(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array compute_dynamic_offset(
|
||||||
|
const array& indices,
|
||||||
|
const Strides& strides,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const Stream& s) {
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
|
// Kernel to compute offset here.
|
||||||
|
array offset({1}, int64, nullptr, {});
|
||||||
|
bool donate = indices.is_donatable() &&
|
||||||
|
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
|
||||||
|
if (donate) {
|
||||||
|
offset.copy_shared_buffer(indices);
|
||||||
|
} else {
|
||||||
|
offset.set_data(allocator::malloc(offset.itemsize()));
|
||||||
|
}
|
||||||
|
d.add_temporary(offset, s.index);
|
||||||
|
|
||||||
|
auto dtype = indices.dtype();
|
||||||
|
std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype);
|
||||||
|
auto lib = d.get_library(lib_name, [dtype]() {
|
||||||
|
return fmt::format(
|
||||||
|
R"(
|
||||||
|
[[kernel]] void compute_dynamic_offset_{0}(
|
||||||
|
constant const {1}* indices [[buffer(0)]],
|
||||||
|
device int64_t& offset [[buffer(1)]],
|
||||||
|
constant const int64_t* strides [[buffer(2)]],
|
||||||
|
constant const int* axes [[buffer(3)]],
|
||||||
|
constant const int& n_axes [[buffer(4)]],
|
||||||
|
uint index [[thread_position_in_grid]]) {{
|
||||||
|
int64_t acc = 0;
|
||||||
|
for (int i = 0; i < n_axes; ++i) {{
|
||||||
|
acc += indices[i] * strides[axes[i]];
|
||||||
|
}}
|
||||||
|
offset = acc;
|
||||||
|
}})",
|
||||||
|
type_to_name(dtype),
|
||||||
|
get_type_string(dtype));
|
||||||
|
});
|
||||||
|
auto kernel = d.get_kernel(lib_name, lib);
|
||||||
|
|
||||||
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
compute_encoder.set_input_array(indices, 0);
|
||||||
|
compute_encoder.set_output_array(offset, 1);
|
||||||
|
compute_encoder.set_vector_bytes(strides, 2);
|
||||||
|
compute_encoder.set_vector_bytes(axes, 3);
|
||||||
|
int n_axes = axes.size();
|
||||||
|
compute_encoder.set_bytes(n_axes, 4);
|
||||||
|
MTL::Size dims = MTL::Size(1, 1, 1);
|
||||||
|
compute_encoder.dispatch_threads(dims, dims);
|
||||||
|
return offset;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
cuda_skip = {
|
cuda_skip = {
|
||||||
"TestLoad.test_load_f8_e4m3",
|
"TestLoad.test_load_f8_e4m3",
|
||||||
"TestLayers.test_quantized_embedding",
|
"TestLayers.test_quantized_embedding",
|
||||||
"TestOps.test_dynamic_slicing",
|
|
||||||
# Block masked matmul NYI
|
# Block masked matmul NYI
|
||||||
"TestBlas.test_block_masked_matmul",
|
"TestBlas.test_block_masked_matmul",
|
||||||
# Gather matmul NYI
|
# Gather matmul NYI
|
||||||
|
Loading…
Reference in New Issue
Block a user