mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-09 10:46:39 +08:00
[CUDA] Matmul utils initial commit (#2441)
This commit is contained in:
parent
86258f292f
commit
be9bc96da4
@ -171,6 +171,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
in.data<T>(),
|
in.data<T>(),
|
||||||
out.data<uint32_t>(),
|
out.data<uint32_t>(),
|
||||||
out.size(),
|
out.size(),
|
||||||
|
@ -222,6 +222,7 @@ void binary_op_gpu_inplace(
|
|||||||
dims_constant()>,
|
dims_constant()>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
@ -236,6 +237,7 @@ void binary_op_gpu_inplace(
|
|||||||
cu::binary_g<Op, InType, OutType, IdxT>,
|
cu::binary_g<Op, InType, OutType, IdxT>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
@ -264,6 +266,7 @@ void binary_op_gpu_inplace(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
|
@ -238,6 +238,7 @@ void binary_two_op_gpu_inplace(
|
|||||||
dims_constant()>,
|
dims_constant()>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out_a.data<OutType>(),
|
out_a.data<OutType>(),
|
||||||
@ -254,6 +255,7 @@ void binary_two_op_gpu_inplace(
|
|||||||
cu::binary_two_g<Op, InType, OutType, IdxT>,
|
cu::binary_two_g<Op, InType, OutType, IdxT>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out_a.data<OutType>(),
|
out_a.data<OutType>(),
|
||||||
@ -287,6 +289,7 @@ void binary_two_op_gpu_inplace(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out_a.data<OutType>(),
|
out_a.data<OutType>(),
|
||||||
|
@ -334,7 +334,7 @@ void Compiled::eval_gpu(
|
|||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] =
|
auto [num_blocks, block_dims] =
|
||||||
get_launch_args(outputs[0], large, work_per_thread);
|
get_launch_args(outputs[0], large, work_per_thread);
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -76,6 +76,7 @@ void copy_contiguous(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in.data<InType>() + in_offset,
|
in.data<InType>() + in_offset,
|
||||||
out.data<OutType>() + out_offset,
|
out.data<OutType>() + out_offset,
|
||||||
out.data_size());
|
out.data_size());
|
||||||
|
@ -77,6 +77,7 @@ void copy_general(
|
|||||||
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
|
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
data_size,
|
data_size,
|
||||||
@ -91,6 +92,7 @@ void copy_general(
|
|||||||
cu::copy_gg<InType, OutType, IdxT>,
|
cu::copy_gg<InType, OutType, IdxT>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
data_size,
|
data_size,
|
||||||
|
@ -83,6 +83,7 @@ void copy_general_dynamic(
|
|||||||
dims_constant()>,
|
dims_constant()>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
out.size(),
|
||||||
@ -98,6 +99,7 @@ void copy_general_dynamic(
|
|||||||
cu::copy_gg_dynamic<InType, OutType, IdxT>,
|
cu::copy_gg_dynamic<InType, OutType, IdxT>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
out.size(),
|
||||||
|
@ -68,6 +68,7 @@ void copy_general_input(
|
|||||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
|
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
out.size(),
|
||||||
@ -80,6 +81,7 @@ void copy_general_input(
|
|||||||
cu::copy_g<InType, OutType, IdxT>,
|
cu::copy_g<InType, OutType, IdxT>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
out.size(),
|
||||||
|
@ -224,12 +224,14 @@ void CommandEncoder::add_kernel_node(
|
|||||||
void* func,
|
void* func,
|
||||||
dim3 grid_dim,
|
dim3 grid_dim,
|
||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
void** params) {
|
void** params) {
|
||||||
cudaKernelNodeParams kernel_params = {0};
|
cudaKernelNodeParams kernel_params = {0};
|
||||||
kernel_params.func = func;
|
kernel_params.func = func;
|
||||||
kernel_params.gridDim = grid_dim;
|
kernel_params.gridDim = grid_dim;
|
||||||
kernel_params.blockDim = block_dim;
|
kernel_params.blockDim = block_dim;
|
||||||
kernel_params.kernelParams = params;
|
kernel_params.kernelParams = params;
|
||||||
|
kernel_params.sharedMemBytes = smem_bytes;
|
||||||
add_kernel_node(kernel_params);
|
add_kernel_node(kernel_params);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -237,6 +239,7 @@ void CommandEncoder::add_kernel_node(
|
|||||||
CUfunction func,
|
CUfunction func,
|
||||||
dim3 grid_dim,
|
dim3 grid_dim,
|
||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
void** params) {
|
void** params) {
|
||||||
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
||||||
kernel_params.func = func;
|
kernel_params.func = func;
|
||||||
@ -247,6 +250,7 @@ void CommandEncoder::add_kernel_node(
|
|||||||
kernel_params.blockDimY = block_dim.y;
|
kernel_params.blockDimY = block_dim.y;
|
||||||
kernel_params.blockDimZ = block_dim.z;
|
kernel_params.blockDimZ = block_dim.z;
|
||||||
kernel_params.kernelParams = params;
|
kernel_params.kernelParams = params;
|
||||||
|
kernel_params.sharedMemBytes = smem_bytes;
|
||||||
add_kernel_node(kernel_params);
|
add_kernel_node(kernel_params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,25 +47,34 @@ class CommandEncoder {
|
|||||||
void set_output_array(const array& arr);
|
void set_output_array(const array& arr);
|
||||||
|
|
||||||
template <typename F, typename... Params>
|
template <typename F, typename... Params>
|
||||||
void
|
void add_kernel_node(
|
||||||
add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) {
|
F* func,
|
||||||
|
dim3 grid_dim,
|
||||||
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
|
Params&&... params) {
|
||||||
constexpr size_t num = sizeof...(Params);
|
constexpr size_t num = sizeof...(Params);
|
||||||
void* ptrs[num];
|
void* ptrs[num];
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
|
([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
|
||||||
std::forward<Params>(params)),
|
std::forward<Params>(params)),
|
||||||
...);
|
...);
|
||||||
add_kernel_node((void*)func, grid_dim, block_dim, ptrs);
|
add_kernel_node((void*)func, grid_dim, block_dim, smem_bytes, ptrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
void add_kernel_node(
|
void add_kernel_node(
|
||||||
CUfunction func,
|
CUfunction func,
|
||||||
dim3 grid_dim,
|
dim3 grid_dim,
|
||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
void** params);
|
void** params);
|
||||||
|
|
||||||
void
|
void add_kernel_node(
|
||||||
add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
|
void* func,
|
||||||
|
dim3 grid_dim,
|
||||||
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
|
void** params);
|
||||||
|
|
||||||
// Low-level graph helpers.
|
// Low-level graph helpers.
|
||||||
void add_kernel_node(const cudaKernelNodeParams& params);
|
void add_kernel_node(const cudaKernelNodeParams& params);
|
||||||
|
@ -108,6 +108,7 @@ void Matmul::run_batched(
|
|||||||
cu::set_mm_device_pointers,
|
cu::set_mm_device_pointers,
|
||||||
cuda::ceil_div(pointers.size(), block_size),
|
cuda::ceil_div(pointers.size(), block_size),
|
||||||
block_size,
|
block_size,
|
||||||
|
0,
|
||||||
pointers.data<int8_t*>(),
|
pointers.data<int8_t*>(),
|
||||||
a.data<int8_t>(),
|
a.data<int8_t>(),
|
||||||
b.data<int8_t>(),
|
b.data<int8_t>(),
|
||||||
@ -168,6 +169,7 @@ void Matmul::run_batched(
|
|||||||
cu::set_addmm_device_pointers,
|
cu::set_addmm_device_pointers,
|
||||||
cuda::ceil_div(pointers.size(), block_size),
|
cuda::ceil_div(pointers.size(), block_size),
|
||||||
block_size,
|
block_size,
|
||||||
|
0,
|
||||||
pointers.data<int8_t*>(),
|
pointers.data<int8_t*>(),
|
||||||
a.data<int8_t>(),
|
a.data<int8_t>(),
|
||||||
b.data<int8_t>(),
|
b.data<int8_t>(),
|
||||||
|
@ -143,6 +143,7 @@ void gemv(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks_x,
|
num_blocks_x,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
mat,
|
mat,
|
||||||
vec,
|
vec,
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
@ -154,6 +155,7 @@ void gemv(
|
|||||||
kernel,
|
kernel,
|
||||||
dim3{num_blocks_x, batch_count},
|
dim3{num_blocks_x, batch_count},
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
mat,
|
mat,
|
||||||
vec,
|
vec,
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
|
@ -129,7 +129,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] = get_launch_args(out, large);
|
auto [num_blocks, block_dims] = get_launch_args(out, large);
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
@ -230,7 +230,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] = get_launch_args(upd, large);
|
auto [num_blocks, block_dims] = get_launch_args(upd, large);
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
@ -318,7 +318,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] = get_launch_args(idx, large);
|
auto [num_blocks, block_dims] = get_launch_args(idx, large);
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
@ -422,7 +422,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] = get_launch_args(idx, large);
|
auto [num_blocks, block_dims] = get_launch_args(idx, large);
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -279,6 +279,7 @@ void LayerNorm::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
x.data<DataType>(),
|
x.data<DataType>(),
|
||||||
w.data<DataType>(),
|
w.data<DataType>(),
|
||||||
b.data<DataType>(),
|
b.data<DataType>(),
|
||||||
@ -391,6 +392,7 @@ void LayerNormVJP::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
x.data<DataType>(),
|
x.data<DataType>(),
|
||||||
w.data<DataType>(),
|
w.data<DataType>(),
|
||||||
g.data<DataType>(),
|
g.data<DataType>(),
|
||||||
|
@ -150,6 +150,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
in.data<DataType>(),
|
in.data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
axis_size);
|
axis_size);
|
||||||
|
@ -261,6 +261,7 @@ void affine_quantize(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
w.data<T>(),
|
w.data<T>(),
|
||||||
wq.data<uint8_t>(),
|
wq.data<uint8_t>(),
|
||||||
scales.data<T>(),
|
scales.data<T>(),
|
||||||
@ -316,6 +317,7 @@ void affine_dequantize(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
wq.data<uint8_t>(),
|
wq.data<uint8_t>(),
|
||||||
scales.data<T>(),
|
scales.data<T>(),
|
||||||
biases.data<T>(),
|
biases.data<T>(),
|
||||||
|
@ -170,6 +170,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
cu::rbitsc,
|
cu::rbitsc,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
keys.data<uint32_t>(),
|
keys.data<uint32_t>(),
|
||||||
out.data<uint8_t>(),
|
out.data<uint8_t>(),
|
||||||
grid_dims,
|
grid_dims,
|
||||||
@ -180,6 +181,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
cu::rbits,
|
cu::rbits,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
keys.data<uint32_t>(),
|
keys.data<uint32_t>(),
|
||||||
out.data<uint8_t>(),
|
out.data<uint8_t>(),
|
||||||
grid_dims,
|
grid_dims,
|
||||||
|
@ -120,6 +120,7 @@ void all_reduce(
|
|||||||
kernel,
|
kernel,
|
||||||
blocks,
|
blocks,
|
||||||
threads,
|
threads,
|
||||||
|
0,
|
||||||
static_cast<T*>(indata),
|
static_cast<T*>(indata),
|
||||||
intermediate.data<U>(),
|
intermediate.data<U>(),
|
||||||
block_step,
|
block_step,
|
||||||
@ -146,6 +147,7 @@ void all_reduce(
|
|||||||
kernel,
|
kernel,
|
||||||
blocks,
|
blocks,
|
||||||
threads,
|
threads,
|
||||||
|
0,
|
||||||
static_cast<T*>(indata),
|
static_cast<T*>(indata),
|
||||||
out.data<U>(),
|
out.data<U>(),
|
||||||
block_step,
|
block_step,
|
||||||
|
@ -230,7 +230,7 @@ void col_reduce_looped(
|
|||||||
auto kernel =
|
auto kernel =
|
||||||
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel, grid, blocks, indata, out.data<U>(), args);
|
kernel, grid, blocks, 0, indata, out.data<U>(), args);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
@ -41,7 +41,8 @@ void init_reduce(
|
|||||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||||
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
|
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
|
||||||
grid.x = (grid.x + 1023) / 1024;
|
grid.x = (grid.x + 1023) / 1024;
|
||||||
encoder.add_kernel_node(kernel, grid, block, out.data<U>(), out.size());
|
encoder.add_kernel_node(
|
||||||
|
kernel, grid, block, 0, out.data<U>(), out.size());
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -269,7 +269,7 @@ void row_reduce_simple(
|
|||||||
|
|
||||||
int size = plan.shape.back();
|
int size = plan.shape.back();
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel, grid, block, indata, out.data<U>(), out.size(), size);
|
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -322,7 +322,7 @@ void row_reduce_looped(
|
|||||||
});
|
});
|
||||||
|
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel, grid, block, indata, out.data<U>(), out.size(), args);
|
kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -222,6 +222,7 @@ void RMSNorm::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
x.data<DataType>(),
|
x.data<DataType>(),
|
||||||
w.data<DataType>(),
|
w.data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
@ -316,6 +317,7 @@ void RMSNormVJP::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
x.data<DataType>(),
|
x.data<DataType>(),
|
||||||
w.data<DataType>(),
|
w.data<DataType>(),
|
||||||
g.data<DataType>(),
|
g.data<DataType>(),
|
||||||
|
@ -325,6 +325,7 @@ void RoPE::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
(donated ? out : in).data<DataType>(),
|
(donated ? out : in).data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
offset.data<int32_t>(),
|
offset.data<int32_t>(),
|
||||||
@ -341,6 +342,7 @@ void RoPE::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
(donated ? out : in).data<DataType>(),
|
(donated ? out : in).data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
offset.data<int32_t>(),
|
offset.data<int32_t>(),
|
||||||
@ -360,6 +362,7 @@ void RoPE::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
(donated ? out : in).data<DataType>(),
|
(donated ? out : in).data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
offset.data<int32_t>(),
|
offset.data<int32_t>(),
|
||||||
@ -381,6 +384,7 @@ void RoPE::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
(donated ? out : in).data<DataType>(),
|
(donated ? out : in).data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
offset.data<int32_t>(),
|
offset.data<int32_t>(),
|
||||||
|
@ -414,6 +414,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel,
|
kernel,
|
||||||
in.data_size() / axis_size,
|
in.data_size() / axis_size,
|
||||||
block_dim,
|
block_dim,
|
||||||
|
0,
|
||||||
in.data<T>(),
|
in.data<T>(),
|
||||||
out.data<U>(),
|
out.data<U>(),
|
||||||
axis_size);
|
axis_size);
|
||||||
@ -443,6 +444,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dim,
|
block_dim,
|
||||||
|
0,
|
||||||
in.data<T>(),
|
in.data<T>(),
|
||||||
out.data<U>(),
|
out.data<U>(),
|
||||||
axis_size,
|
axis_size,
|
||||||
|
@ -151,6 +151,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
in.data<DataType>(),
|
in.data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
axis_size);
|
axis_size);
|
||||||
|
9
mlx/backend/cuda/steel/defines.cuh
Normal file
9
mlx/backend/cuda/steel/defines.cuh
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#define MLX_UNROLL _Pragma("unroll")
|
||||||
|
|
||||||
|
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||||
|
#define MLX_CUDA_SM_80_ENABLED
|
||||||
|
#endif
|
101
mlx/backend/cuda/steel/gemm.cuh
Normal file
101
mlx/backend/cuda/steel/gemm.cuh
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
|
||||||
|
#include "mlx/backend/cuda/steel/mma.cuh"
|
||||||
|
#include "mlx/backend/cuda/steel/tiles.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An example gemm written with the utils.
|
||||||
|
*
|
||||||
|
* Computes A @ B.T when A and B are all aligned with the block sizes.
|
||||||
|
*/
|
||||||
|
template <typename T, int BM, int BN, int BK>
|
||||||
|
__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
||||||
|
constexpr int WARPS_M = 2;
|
||||||
|
constexpr int WARPS_N = 2;
|
||||||
|
constexpr int NUM_WARPS = WARPS_M * WARPS_N;
|
||||||
|
constexpr int WARP_STEP_M = BM / WARPS_M;
|
||||||
|
constexpr int WARP_STEP_N = BN / WARPS_N;
|
||||||
|
|
||||||
|
// Precompute some offsets for each thread
|
||||||
|
const int warpid = threadIdx.x / 32;
|
||||||
|
const int laneid = threadIdx.x % 32;
|
||||||
|
const int wm = warpid / WARPS_N;
|
||||||
|
const int wn = warpid % WARPS_N;
|
||||||
|
const int offset_m = wm * WARP_STEP_M;
|
||||||
|
const int offset_n = wn * WARP_STEP_N;
|
||||||
|
|
||||||
|
// Allocate shared memory
|
||||||
|
extern __shared__ char shmem[];
|
||||||
|
SharedTile<T, BM, BK>(&as)[2] = *(SharedTile<T, BM, BK>(*)[2])(&shmem[0]);
|
||||||
|
SharedTile<T, BN, BK>(&bs)[2] =
|
||||||
|
*(SharedTile<T, BN, BK>(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]);
|
||||||
|
|
||||||
|
// Allocate registers for the MMA
|
||||||
|
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
|
||||||
|
RegisterTile<T, BM / WARPS_M, 16> A;
|
||||||
|
RegisterTile<T, BN / WARPS_N, 16> B;
|
||||||
|
|
||||||
|
// Move the global pointers to the tile
|
||||||
|
a += blockIdx.y * BM * K;
|
||||||
|
b += blockIdx.x * BN * K;
|
||||||
|
y += blockIdx.y * BM * N + blockIdx.x * BN;
|
||||||
|
|
||||||
|
// Zero the accumulators
|
||||||
|
C.fill(0);
|
||||||
|
|
||||||
|
// Start the SM pipeline
|
||||||
|
load_async<NUM_WARPS>(as[0], as[0].base_addr(), a, K);
|
||||||
|
load_async<NUM_WARPS>(bs[0], bs[0].base_addr(), b, K);
|
||||||
|
cp_async_commit();
|
||||||
|
|
||||||
|
int tic = 0;
|
||||||
|
for (int k_block = BK; k_block < K; k_block += BK) {
|
||||||
|
load_async<NUM_WARPS>(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K);
|
||||||
|
load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
|
||||||
|
cp_async_commit();
|
||||||
|
cp_async_wait<1>();
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int k = 0; k < BK / 16; k++) {
|
||||||
|
A.load(
|
||||||
|
as[tic],
|
||||||
|
as[tic].base_addr(),
|
||||||
|
offset_m + laneid % 16,
|
||||||
|
k * 16 + laneid / 16 * 8);
|
||||||
|
B.load(
|
||||||
|
bs[tic],
|
||||||
|
bs[tic].base_addr(),
|
||||||
|
offset_n + laneid % 16,
|
||||||
|
k * 16 + laneid / 16 * 8);
|
||||||
|
|
||||||
|
mma_t(C, A, B);
|
||||||
|
}
|
||||||
|
|
||||||
|
tic ^= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty the pipeline
|
||||||
|
cp_async_wait_all();
|
||||||
|
__syncthreads();
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int k = 0; k < BK / 16; k++) {
|
||||||
|
A.load(
|
||||||
|
as[tic],
|
||||||
|
as[tic].base_addr(),
|
||||||
|
offset_m + laneid % 16,
|
||||||
|
k * 16 + laneid / 16 * 8);
|
||||||
|
B.load(
|
||||||
|
bs[tic],
|
||||||
|
bs[tic].base_addr(),
|
||||||
|
offset_n + laneid % 16,
|
||||||
|
k * 16 + laneid / 16 * 8);
|
||||||
|
|
||||||
|
mma_t(C, A, B);
|
||||||
|
}
|
||||||
|
|
||||||
|
C.store_global(y, N, offset_m, offset_n);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
117
mlx/backend/cuda/steel/mma.cuh
Normal file
117
mlx/backend/cuda/steel/mma.cuh
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/steel/defines.cuh"
|
||||||
|
#include "mlx/backend/cuda/steel/tiles.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fallback mma.
|
||||||
|
*
|
||||||
|
* We should probably a) implement a fallback or complain about it to the
|
||||||
|
* compiler.
|
||||||
|
*/
|
||||||
|
template <typename U, typename T>
|
||||||
|
__device__ inline void
|
||||||
|
mma_t(Tile16x16<U>& C, Tile16x16<T>& A, Tile16x16<T>& B) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiply the 16x16 bfloat16 tiles and accumulate the result in one 16x16
|
||||||
|
* float tile.
|
||||||
|
*
|
||||||
|
* We actually perform C += A @ B.T
|
||||||
|
*/
|
||||||
|
__device__ __forceinline__ void mma_t(
|
||||||
|
Tile16x16<float>& C,
|
||||||
|
Tile16x16<__nv_bfloat16>& A,
|
||||||
|
Tile16x16<__nv_bfloat16>& B) {
|
||||||
|
#if defined(MLX_CUDA_SM_80_ENABLED)
|
||||||
|
asm volatile(
|
||||||
|
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||||
|
"{%0, %1, %2, %3}, "
|
||||||
|
"{%4, %5, %6, %7}, "
|
||||||
|
"{%8, %9}, "
|
||||||
|
"{%10, %11, %12, %13};"
|
||||||
|
|
||||||
|
// D matrix
|
||||||
|
: "+f"(C.values[0].x),
|
||||||
|
"+f"(C.values[0].y),
|
||||||
|
"+f"(C.values[1].x),
|
||||||
|
"+f"(C.values[1].y)
|
||||||
|
|
||||||
|
// A matrix
|
||||||
|
: "r"(*(uint32_t*)(&A.values[0])),
|
||||||
|
"r"(*(uint32_t*)(&A.values[1])),
|
||||||
|
"r"(*(uint32_t*)(&A.values[2])),
|
||||||
|
"r"(*(uint32_t*)(&A.values[3])),
|
||||||
|
|
||||||
|
// B matrix
|
||||||
|
"r"(*(uint32_t*)(&B.values[0])),
|
||||||
|
"r"(*(uint32_t*)(&B.values[2])),
|
||||||
|
|
||||||
|
// C matrix
|
||||||
|
"f"(C.values[0].x),
|
||||||
|
"f"(C.values[0].y),
|
||||||
|
"f"(C.values[1].x),
|
||||||
|
"f"(C.values[1].y));
|
||||||
|
asm volatile(
|
||||||
|
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||||
|
"{%0, %1, %2, %3}, "
|
||||||
|
"{%4, %5, %6, %7}, "
|
||||||
|
"{%8, %9}, "
|
||||||
|
"{%10, %11, %12, %13};"
|
||||||
|
|
||||||
|
// D matrix
|
||||||
|
: "+f"(C.values[2].x),
|
||||||
|
"+f"(C.values[2].y),
|
||||||
|
"+f"(C.values[3].x),
|
||||||
|
"+f"(C.values[3].y)
|
||||||
|
|
||||||
|
// A matrix
|
||||||
|
: "r"(*(uint32_t*)(&A.values[0])),
|
||||||
|
"r"(*(uint32_t*)(&A.values[1])),
|
||||||
|
"r"(*(uint32_t*)(&A.values[2])),
|
||||||
|
"r"(*(uint32_t*)(&A.values[3])),
|
||||||
|
|
||||||
|
// B matrix
|
||||||
|
"r"(*(uint32_t*)(&B.values[1])),
|
||||||
|
"r"(*(uint32_t*)(&B.values[3])),
|
||||||
|
|
||||||
|
// C matrix
|
||||||
|
"f"(C.values[2].x),
|
||||||
|
"f"(C.values[2].y),
|
||||||
|
"f"(C.values[3].x),
|
||||||
|
"f"(C.values[3].y));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiply larger register tiles by delegating to mma_t.
|
||||||
|
*/
|
||||||
|
template <typename U, typename T, int M, int N, int K>
|
||||||
|
__device__ __forceinline__ void mma_t(
|
||||||
|
RegisterTile<U, M, N>& C,
|
||||||
|
RegisterTile<T, M, K>& A,
|
||||||
|
RegisterTile<T, N, K>& B) {
|
||||||
|
constexpr int TILES_M = RegisterTile<T, M, K>::TILES_Y;
|
||||||
|
constexpr int TILES_K = RegisterTile<T, M, K>::TILES_X;
|
||||||
|
constexpr int TILES_N = RegisterTile<T, N, K>::TILES_Y;
|
||||||
|
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int k = 0; k < TILES_K; k++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int m = 0; m < TILES_M; m++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int n = 0; n < TILES_N; n++) {
|
||||||
|
mma_t(
|
||||||
|
C.data[m * TILES_N + n],
|
||||||
|
A.data[m * TILES_K + k],
|
||||||
|
B.data[n * TILES_K + k]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
471
mlx/backend/cuda/steel/tiles.cuh
Normal file
471
mlx/backend/cuda/steel/tiles.cuh
Normal file
@ -0,0 +1,471 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/steel/utils.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
// Map types to their vector of 2 type float -> float2, double -> double2 etc
|
||||||
|
template <typename T>
|
||||||
|
struct Vector2;
|
||||||
|
template <>
|
||||||
|
struct Vector2<double> {
|
||||||
|
using type = double2;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct Vector2<float> {
|
||||||
|
using type = float2;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct Vector2<__half> {
|
||||||
|
using type = __half2;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct Vector2<__nv_bfloat16> {
|
||||||
|
using type = __nv_bfloat162;
|
||||||
|
};
|
||||||
|
template <typename T>
|
||||||
|
using Vector2_t = typename Vector2<T>::type;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The basic building block for Ampere mmas. A 16x16 tile distributed across
|
||||||
|
* the warp.
|
||||||
|
*
|
||||||
|
* Each thread holds 8 values. They are distributed according to
|
||||||
|
* https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
|
||||||
|
*
|
||||||
|
* For use instructions see the individual methods eg load().
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
struct Tile16x16 {
|
||||||
|
using T2 = Vector2_t<T>;
|
||||||
|
|
||||||
|
T2 values[4];
|
||||||
|
|
||||||
|
__device__ inline void fill(T v) {
|
||||||
|
T2 v2 = {v, v};
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
values[i] = v2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load a 16x16 tile from shared memory.
|
||||||
|
*
|
||||||
|
* The instruction is a bit weird in the sense that the address provided by
|
||||||
|
* each thread and the elements loaded are not the same.
|
||||||
|
*
|
||||||
|
* We load 4 8x8 tiles. The tile rows are stored contiguously in memory. As a
|
||||||
|
* result the warp provides 4*8 = 32 addresses one per row.
|
||||||
|
*
|
||||||
|
* Threads 0-7 provide the addresses for the first tile, 8-15 for the second
|
||||||
|
* and so on. For instance to load a non swizzled tile we would do
|
||||||
|
*
|
||||||
|
* base_addr + (laneid % 16) * BK + (laneid / 2) * 8
|
||||||
|
*
|
||||||
|
* See
|
||||||
|
* https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
|
||||||
|
*/
|
||||||
|
__device__ __forceinline__ void load(uint32_t row_address) {
|
||||||
|
if constexpr (
|
||||||
|
std::is_same_v<T2, __nv_bfloat162> || std::is_same_v<T2, __half2>) {
|
||||||
|
asm volatile(
|
||||||
|
"ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||||
|
: "=r"(*(uint32_t*)&(values[0])),
|
||||||
|
"=r"(*(uint32_t*)&(values[1])),
|
||||||
|
"=r"(*(uint32_t*)&(values[2])),
|
||||||
|
"=r"(*(uint32_t*)&(values[3]))
|
||||||
|
: "r"(row_address));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Store the tile to the address pointed to by `x`.
|
||||||
|
*
|
||||||
|
* The provided pointer is a generic pointer but this is meant to be used to
|
||||||
|
* store to global memory. For storing to shared memory we should use
|
||||||
|
* `stmatrix`.
|
||||||
|
*
|
||||||
|
* This also showcases the format of the tile quite nicely. Each register is
|
||||||
|
* holding to adjacent values. The indices are
|
||||||
|
*
|
||||||
|
* row + 0, col + 0
|
||||||
|
* row + 8, col + 0
|
||||||
|
* row + 0, col + 8
|
||||||
|
* row + 8, col + 8
|
||||||
|
*
|
||||||
|
* Given that we are dealing with Vector2_t<U> the column offsets are 4
|
||||||
|
* instead of 8.
|
||||||
|
*/
|
||||||
|
template <typename U>
|
||||||
|
__device__ inline void store_global(U* x, int N) {
|
||||||
|
using U2 = Vector2_t<U>;
|
||||||
|
U2* x2 = reinterpret_cast<U2*>(x);
|
||||||
|
const int laneid = threadIdx.x % 32;
|
||||||
|
const int row = laneid / 4;
|
||||||
|
const int col = laneid % 4;
|
||||||
|
if constexpr (std::is_same_v<U2, T2>) {
|
||||||
|
x2[(row + 0) * (N / 2) + col + 0] = values[0];
|
||||||
|
x2[(row + 0) * (N / 2) + col + 4] = values[2];
|
||||||
|
x2[(row + 8) * (N / 2) + col + 0] = values[1];
|
||||||
|
x2[(row + 8) * (N / 2) + col + 4] = values[3];
|
||||||
|
} else if constexpr (
|
||||||
|
std::is_same_v<T2, float2> && std::is_same_v<U, __nv_bfloat16>) {
|
||||||
|
x2[(row + 0) * (N / 2) + col + 0] =
|
||||||
|
__floats2bfloat162_rn(values[0].x, values[0].y);
|
||||||
|
x2[(row + 0) * (N / 2) + col + 4] =
|
||||||
|
__floats2bfloat162_rn(values[2].x, values[2].y);
|
||||||
|
x2[(row + 8) * (N / 2) + col + 0] =
|
||||||
|
__floats2bfloat162_rn(values[1].x, values[1].y);
|
||||||
|
x2[(row + 8) * (N / 2) + col + 4] =
|
||||||
|
__floats2bfloat162_rn(values[3].x, values[3].y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
__device__ inline void store_global_safe(U* x, int N, int max_rows) {
|
||||||
|
const int laneid = threadIdx.x % 32;
|
||||||
|
const int row = laneid / 4;
|
||||||
|
const int col = laneid % 4;
|
||||||
|
if (row < max_rows) {
|
||||||
|
x[(row + 0) * N + 2 * col + 0] = static_cast<U>(values[0].x);
|
||||||
|
x[(row + 0) * N + 2 * col + 1] = static_cast<U>(values[0].y);
|
||||||
|
x[(row + 0) * N + 2 * col + 8] = static_cast<U>(values[2].x);
|
||||||
|
x[(row + 0) * N + 2 * col + 9] = static_cast<U>(values[2].y);
|
||||||
|
}
|
||||||
|
if (row + 8 < max_rows) {
|
||||||
|
x[(row + 8) * N + 2 * col + 0] = static_cast<U>(values[1].x);
|
||||||
|
x[(row + 8) * N + 2 * col + 1] = static_cast<U>(values[1].y);
|
||||||
|
x[(row + 8) * N + 2 * col + 8] = static_cast<U>(values[3].x);
|
||||||
|
x[(row + 8) * N + 2 * col + 9] = static_cast<U>(values[3].y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A simple container of multiple Tile16x16.
|
||||||
|
*
|
||||||
|
* Provides utility functions for loading and manipulating collections of basic
|
||||||
|
* tiles.
|
||||||
|
*/
|
||||||
|
template <typename T, int ROWS_, int COLS_>
|
||||||
|
struct RegisterTile {
|
||||||
|
static constexpr int ROWS = ROWS_;
|
||||||
|
static constexpr int COLS = COLS_;
|
||||||
|
static constexpr int TILES_X = COLS / 16;
|
||||||
|
static constexpr int TILES_Y = ROWS / 16;
|
||||||
|
|
||||||
|
Tile16x16<T> data[TILES_X * TILES_Y];
|
||||||
|
|
||||||
|
__device__ inline void fill(T v) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
data[i * TILES_X + j].fill(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Tile>
|
||||||
|
__device__ __forceinline__ void
|
||||||
|
load(Tile& tile, uint32_t base_address, int row, int col) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
data[i * TILES_X + j].load(
|
||||||
|
tile.loc(base_address, row + i * 16, col + j * 16));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Tile, typename F>
|
||||||
|
__device__ __forceinline__ void
|
||||||
|
load(Tile& tile, F f, uint32_t base_address, int row, int col) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
f(data[i * TILES_X + j],
|
||||||
|
tile,
|
||||||
|
base_address,
|
||||||
|
row + i * 16,
|
||||||
|
col + j * 16);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
__device__ inline void store_global(U* x, int N, int row, int col) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
data[i * TILES_X + j].store_global(
|
||||||
|
x + (row + i * 16) * N + col + j * 16, N);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
__device__ inline void
|
||||||
|
store_global_safe(U* x, int N, int row, int col, int max_rows) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
data[i * TILES_X + j].store_global_safe(
|
||||||
|
x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A simple container of multiple Tile16x16.
|
||||||
|
*
|
||||||
|
* Provides utility functions for loading and manipulating collections of basic
|
||||||
|
* tiles.
|
||||||
|
*/
|
||||||
|
template <typename T, int ROWS_, int COLS_>
|
||||||
|
struct RegisterTile {
|
||||||
|
static constexpr int ROWS = ROWS_;
|
||||||
|
static constexpr int COLS = COLS_;
|
||||||
|
static constexpr int TILES_X = COLS / 16;
|
||||||
|
static constexpr int TILES_Y = ROWS / 16;
|
||||||
|
|
||||||
|
Tile16x16<T> data[TILES_X * TILES_Y];
|
||||||
|
|
||||||
|
__device__ inline void fill(T v) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
data[i * TILES_X + j].fill(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Tile>
|
||||||
|
__device__ inline void
|
||||||
|
load(Tile& tile, uint32_t base_address, int row, int col) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
data[i * TILES_X + j].load(
|
||||||
|
tile.loc(base_address, row + i * 16, col + j * 16));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
__device__ inline void store_global(U* x, int N, int row, int col) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
data[i * TILES_X + j].store_global(
|
||||||
|
x + (row + i * 16) * N + col + j * 16, N);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, int ROWS_, int COLS_>
|
||||||
|
struct SharedTile {
|
||||||
|
static constexpr int ROWS = ROWS_;
|
||||||
|
static constexpr int COLS = COLS_;
|
||||||
|
static constexpr int TILES_X = COLS / 16;
|
||||||
|
static constexpr int TILES_Y = ROWS / 16;
|
||||||
|
static constexpr int NUMEL = ROWS * COLS;
|
||||||
|
|
||||||
|
// Swizzle taken from ThunderKittens. Should be changed when we switch to
|
||||||
|
// cute Layouts.
|
||||||
|
//
|
||||||
|
// See inludes/types/shared/st.cuh
|
||||||
|
//
|
||||||
|
// I do feel that it is too math heavy and can be improved. Also the math is
|
||||||
|
// done every time although the addresses don't change from load to load. I
|
||||||
|
// guess we are expecting the compiler to figure that out.
|
||||||
|
static constexpr int swizzle_bytes =
|
||||||
|
(sizeof(T) == 2 ? (TILES_X % 4 == 0 ? 128 : (TILES_X % 2 == 0 ? 64 : 32))
|
||||||
|
: (sizeof(T) == 4 ? (TILES_X % 2 == 0 ? 128 : 64) : 0));
|
||||||
|
|
||||||
|
T data[ROWS * COLS];
|
||||||
|
|
||||||
|
__device__ inline uint32_t base_addr() const {
|
||||||
|
return __cvta_generic_to_shared(&data[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a pointer to the element at (row, col) using the swizzle.
|
||||||
|
__device__ static inline T* ptr(T* ptr, int row, int col) {
|
||||||
|
if constexpr (swizzle_bytes > 0) {
|
||||||
|
static constexpr int swizzle_repeat = swizzle_bytes * 8;
|
||||||
|
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
|
||||||
|
const int outer_idx = col / subtile_cols;
|
||||||
|
const uint64_t addr =
|
||||||
|
(uint64_t)(&ptr
|
||||||
|
[outer_idx * ROWS * subtile_cols + row * subtile_cols +
|
||||||
|
col % subtile_cols]);
|
||||||
|
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
|
||||||
|
return (T*)(addr ^ swizzle);
|
||||||
|
} else {
|
||||||
|
return ptr + row * COLS + col;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the location of the element at (row, col) using the swizzle.
|
||||||
|
__device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
|
||||||
|
if constexpr (swizzle_bytes > 0) {
|
||||||
|
static constexpr int swizzle_repeat = swizzle_bytes * 8;
|
||||||
|
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
|
||||||
|
const int outer_idx = col / subtile_cols;
|
||||||
|
const uint32_t addr = ptr +
|
||||||
|
sizeof(T) *
|
||||||
|
(outer_idx * ROWS * subtile_cols + row * subtile_cols +
|
||||||
|
col % subtile_cols);
|
||||||
|
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
|
||||||
|
return (addr ^ swizzle);
|
||||||
|
} else {
|
||||||
|
return ptr + sizeof(T) * (row * COLS + col);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convenience functions to edit elements going through the swizzle.
|
||||||
|
__device__ inline T& operator()(int row, int col) {
|
||||||
|
return *ptr(data, row, col);
|
||||||
|
}
|
||||||
|
__device__ inline void store(float4& v, int row, int col) {
|
||||||
|
*(reinterpret_cast<float4*>(ptr(data, row, col))) = v;
|
||||||
|
}
|
||||||
|
__device__ inline void store(float2& v, int row, int col) {
|
||||||
|
*(reinterpret_cast<float2*>(ptr(data, row, col))) = v;
|
||||||
|
}
|
||||||
|
__device__ inline void store(float& v, int row, int col) {
|
||||||
|
*(reinterpret_cast<float*>(ptr(data, row, col))) = v;
|
||||||
|
}
|
||||||
|
template <int N>
|
||||||
|
__device__ inline void store(T (&v)[N], int row, int col) {
|
||||||
|
if constexpr (sizeof(T) * N == 4) {
|
||||||
|
store(*(reinterpret_cast<float*>(&v[0])), row, col);
|
||||||
|
} else if constexpr (sizeof(T) * N == 8) {
|
||||||
|
store(*(reinterpret_cast<float2*>(&v[0])), row, col);
|
||||||
|
} else if constexpr (sizeof(T) * N == 16) {
|
||||||
|
store(*(reinterpret_cast<float4*>(&v[0])), row, col);
|
||||||
|
} else {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
*ptr(data, row, col + i) = v[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load the tile from global memory by loading 16 bytes at a time and storing
|
||||||
|
* them immediately.
|
||||||
|
*
|
||||||
|
* Can also be used as a fallback for architectures before sm_80.
|
||||||
|
*/
|
||||||
|
template <int NUM_WARPS, typename T, typename Tile>
|
||||||
|
__device__ inline void load(Tile& tile, const T* x, int N) {
|
||||||
|
constexpr int NUM_THREADS = NUM_WARPS * 32;
|
||||||
|
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
|
||||||
|
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
|
||||||
|
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
|
||||||
|
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
|
||||||
|
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
|
||||||
|
|
||||||
|
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
|
||||||
|
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
|
||||||
|
|
||||||
|
x += row * N + col * ELEMENTS_PER_LOAD;
|
||||||
|
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
|
||||||
|
float4 tmp;
|
||||||
|
tmp = *(reinterpret_cast<const float4*>(&x[i * STEP_ROWS * N]));
|
||||||
|
tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The asynchronous equivalent of load.
|
||||||
|
*
|
||||||
|
* Loads the tile from global memory by submitting a bunch of async copy
|
||||||
|
* instructions. The copy won't start until commit is called and we don't have
|
||||||
|
* a guarantee it will finish until wait is called.
|
||||||
|
*
|
||||||
|
* It should be used as follows
|
||||||
|
*
|
||||||
|
* load(...)
|
||||||
|
* load(...)
|
||||||
|
* cp_async_commit()
|
||||||
|
* do_other_stuff()
|
||||||
|
* cp_async_wait_all()
|
||||||
|
* do_stuff_with_shmem()
|
||||||
|
*/
|
||||||
|
template <int NUM_WARPS, typename T, typename Tile>
|
||||||
|
__device__ inline void
|
||||||
|
load_async(Tile& tile, uint32_t base_address, const T* x, int N) {
|
||||||
|
constexpr int NUM_THREADS = NUM_WARPS * 32;
|
||||||
|
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
|
||||||
|
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
|
||||||
|
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
|
||||||
|
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
|
||||||
|
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
|
||||||
|
|
||||||
|
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
|
||||||
|
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
|
||||||
|
|
||||||
|
x += row * N + col * ELEMENTS_PER_LOAD;
|
||||||
|
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
|
||||||
|
cp_async<16>(
|
||||||
|
tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
|
||||||
|
x + i * STEP_ROWS * N);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Same as load_async but checks if we can load the row.
|
||||||
|
*
|
||||||
|
* NOTE: It should be changed to use a predicated cp async instead.
|
||||||
|
*/
|
||||||
|
template <int NUM_WARPS, typename T, typename Tile>
|
||||||
|
__device__ inline void load_async_safe(
|
||||||
|
Tile& tile,
|
||||||
|
uint32_t base_address,
|
||||||
|
const T* x,
|
||||||
|
int N,
|
||||||
|
int max_rows) {
|
||||||
|
constexpr int NUM_THREADS = NUM_WARPS * 32;
|
||||||
|
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
|
||||||
|
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
|
||||||
|
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
|
||||||
|
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
|
||||||
|
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
|
||||||
|
|
||||||
|
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
|
||||||
|
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
|
||||||
|
|
||||||
|
x += row * N + col * ELEMENTS_PER_LOAD;
|
||||||
|
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
|
||||||
|
if (row + i * STEP_ROWS < max_rows) {
|
||||||
|
cp_async<16>(
|
||||||
|
tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
|
||||||
|
x + i * STEP_ROWS * N);
|
||||||
|
} else {
|
||||||
|
float4 tmp = {0, 0, 0, 0};
|
||||||
|
tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
89
mlx/backend/cuda/steel/utils.cuh
Normal file
89
mlx/backend/cuda/steel/utils.cuh
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
#include "mlx/backend/cuda/steel/defines.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Copy bytes from the global memory address pointed to by x to the smem
|
||||||
|
* address pointed to by row_address.
|
||||||
|
*
|
||||||
|
* A simple wrapper over the PTX.
|
||||||
|
*/
|
||||||
|
template <int N, typename T>
|
||||||
|
__device__ inline void cp_async(uint32_t row_address, const T* x) {
|
||||||
|
static_assert(
|
||||||
|
N == 16 || N == 8 || N == 4,
|
||||||
|
"cp.async is only supported for N in {4, 8, 16}.");
|
||||||
|
#if defined(MLX_CUDA_SM_80_ENABLED)
|
||||||
|
if constexpr (N == 16) {
|
||||||
|
asm volatile(
|
||||||
|
"cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
|
||||||
|
"l"(reinterpret_cast<const int4*>(x)));
|
||||||
|
} else if constexpr (N == 8) {
|
||||||
|
asm volatile(
|
||||||
|
"cp.async.ca.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
|
||||||
|
"l"(reinterpret_cast<const int2*>(x)));
|
||||||
|
} else if constexpr (N == 4) {
|
||||||
|
asm volatile(
|
||||||
|
"cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
|
||||||
|
"l"(reinterpret_cast<const int*>(x)));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Submit all the previous async copies to be executed.
|
||||||
|
*/
|
||||||
|
__device__ inline void cp_async_commit() {
|
||||||
|
#if defined(MLX_CUDA_SM_80_ENABLED)
|
||||||
|
asm volatile("cp.async.commit_group;\n" ::);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wait for all but N of the async copies to finish.
|
||||||
|
*/
|
||||||
|
template <int N>
|
||||||
|
__device__ inline void cp_async_wait() {
|
||||||
|
#if defined(MLX_CUDA_SM_80_ENABLED)
|
||||||
|
if constexpr (N == 0) {
|
||||||
|
asm volatile("cp.async.wait_all;\n" ::);
|
||||||
|
} else {
|
||||||
|
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wait for all the async copies to finish.
|
||||||
|
*/
|
||||||
|
__device__ inline void cp_async_wait_all() {
|
||||||
|
cp_async_wait<0>();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extract ``bits`` bits from the 32 bit value.
|
||||||
|
*
|
||||||
|
* Single instruction shift and mask.
|
||||||
|
*/
|
||||||
|
template <int bits>
|
||||||
|
__device__ inline uint32_t extract_bits(uint32_t value, int start_bit) {
|
||||||
|
static_assert(
|
||||||
|
bits == 2 || bits == 4 || bits == 8,
|
||||||
|
"extract_bits only supports 2, 4, 8 for now.");
|
||||||
|
uint32_t result;
|
||||||
|
if constexpr (bits == 2) {
|
||||||
|
asm("bfe.u32 %0, %1, %2, 2;" : "=r"(result) : "r"(value), "r"(start_bit));
|
||||||
|
} else if constexpr (bits == 4) {
|
||||||
|
asm("bfe.u32 %0, %1, %2, 4;" : "=r"(result) : "r"(value), "r"(start_bit));
|
||||||
|
} else if constexpr (bits == 8) {
|
||||||
|
asm("bfe.u32 %0, %1, %2, 8;" : "=r"(result) : "r"(value), "r"(start_bit));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
@ -130,6 +130,7 @@ void ternary_op_gpu_inplace(
|
|||||||
cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>,
|
cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<bool>(),
|
a.data<bool>(),
|
||||||
b.data<DType>(),
|
b.data<DType>(),
|
||||||
c.data<DType>(),
|
c.data<DType>(),
|
||||||
@ -146,6 +147,7 @@ void ternary_op_gpu_inplace(
|
|||||||
cu::ternary_g<Op, DType, IdxT>,
|
cu::ternary_g<Op, DType, IdxT>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<bool>(),
|
a.data<bool>(),
|
||||||
b.data<DType>(),
|
b.data<DType>(),
|
||||||
c.data<DType>(),
|
c.data<DType>(),
|
||||||
@ -168,6 +170,7 @@ void ternary_op_gpu_inplace(
|
|||||||
cu::ternary_v<Op, DType, IdxT, N_READS>,
|
cu::ternary_v<Op, DType, IdxT, N_READS>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<bool>(),
|
a.data<bool>(),
|
||||||
b.data<DType>(),
|
b.data<DType>(),
|
||||||
c.data<DType>(),
|
c.data<DType>(),
|
||||||
|
@ -135,6 +135,7 @@ void unary_op_gpu_inplace(
|
|||||||
cu::unary_v<Op, InType, OutType, IdxT, N_READS>,
|
cu::unary_v<Op, InType, OutType, IdxT, N_READS>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in.data<InType>(),
|
in.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
out.data_size());
|
out.data_size());
|
||||||
@ -146,6 +147,7 @@ void unary_op_gpu_inplace(
|
|||||||
cu::unary_g<Op, InType, OutType, IdxT>,
|
cu::unary_g<Op, InType, OutType, IdxT>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in.data<InType>(),
|
in.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
out.data_size(),
|
out.data_size(),
|
||||||
|
Loading…
Reference in New Issue
Block a user