mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add dynamic shared memory
This commit is contained in:
@@ -171,6 +171,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
in.data<T>(),
|
in.data<T>(),
|
||||||
out.data<uint32_t>(),
|
out.data<uint32_t>(),
|
||||||
out.size(),
|
out.size(),
|
||||||
|
|||||||
@@ -222,6 +222,7 @@ void binary_op_gpu_inplace(
|
|||||||
dims_constant()>,
|
dims_constant()>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
@@ -236,6 +237,7 @@ void binary_op_gpu_inplace(
|
|||||||
cu::binary_g<Op, InType, OutType, IdxT>,
|
cu::binary_g<Op, InType, OutType, IdxT>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
@@ -264,6 +266,7 @@ void binary_op_gpu_inplace(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
|
|||||||
@@ -238,6 +238,7 @@ void binary_two_op_gpu_inplace(
|
|||||||
dims_constant()>,
|
dims_constant()>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out_a.data<OutType>(),
|
out_a.data<OutType>(),
|
||||||
@@ -254,6 +255,7 @@ void binary_two_op_gpu_inplace(
|
|||||||
cu::binary_two_g<Op, InType, OutType, IdxT>,
|
cu::binary_two_g<Op, InType, OutType, IdxT>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out_a.data<OutType>(),
|
out_a.data<OutType>(),
|
||||||
@@ -287,6 +289,7 @@ void binary_two_op_gpu_inplace(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out_a.data<OutType>(),
|
out_a.data<OutType>(),
|
||||||
|
|||||||
@@ -295,7 +295,7 @@ void Compiled::eval_gpu(
|
|||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] =
|
auto [num_blocks, block_dims] =
|
||||||
get_launch_args(outputs[0], large, work_per_thread);
|
get_launch_args(outputs[0], large, work_per_thread);
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ void copy_contiguous(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in.data<InType>() + in_offset,
|
in.data<InType>() + in_offset,
|
||||||
out.data<OutType>() + out_offset,
|
out.data<OutType>() + out_offset,
|
||||||
out.data_size());
|
out.data_size());
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ void copy_general(
|
|||||||
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
|
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
data_size,
|
data_size,
|
||||||
@@ -91,6 +92,7 @@ void copy_general(
|
|||||||
cu::copy_gg<InType, OutType, IdxT>,
|
cu::copy_gg<InType, OutType, IdxT>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
data_size,
|
data_size,
|
||||||
|
|||||||
@@ -83,6 +83,7 @@ void copy_general_dynamic(
|
|||||||
dims_constant()>,
|
dims_constant()>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
out.size(),
|
||||||
@@ -98,6 +99,7 @@ void copy_general_dynamic(
|
|||||||
cu::copy_gg_dynamic<InType, OutType, IdxT>,
|
cu::copy_gg_dynamic<InType, OutType, IdxT>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
out.size(),
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ void copy_general_input(
|
|||||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
|
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
out.size(),
|
||||||
@@ -80,6 +81,7 @@ void copy_general_input(
|
|||||||
cu::copy_g<InType, OutType, IdxT>,
|
cu::copy_g<InType, OutType, IdxT>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.size(),
|
out.size(),
|
||||||
|
|||||||
@@ -224,12 +224,14 @@ void CommandEncoder::add_kernel_node(
|
|||||||
void* func,
|
void* func,
|
||||||
dim3 grid_dim,
|
dim3 grid_dim,
|
||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
void** params) {
|
void** params) {
|
||||||
cudaKernelNodeParams kernel_params = {0};
|
cudaKernelNodeParams kernel_params = {0};
|
||||||
kernel_params.func = func;
|
kernel_params.func = func;
|
||||||
kernel_params.gridDim = grid_dim;
|
kernel_params.gridDim = grid_dim;
|
||||||
kernel_params.blockDim = block_dim;
|
kernel_params.blockDim = block_dim;
|
||||||
kernel_params.kernelParams = params;
|
kernel_params.kernelParams = params;
|
||||||
|
kernel_params.sharedMemBytes = smem_bytes;
|
||||||
add_kernel_node(kernel_params);
|
add_kernel_node(kernel_params);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -237,6 +239,7 @@ void CommandEncoder::add_kernel_node(
|
|||||||
CUfunction func,
|
CUfunction func,
|
||||||
dim3 grid_dim,
|
dim3 grid_dim,
|
||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
void** params) {
|
void** params) {
|
||||||
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
||||||
kernel_params.func = func;
|
kernel_params.func = func;
|
||||||
@@ -247,6 +250,7 @@ void CommandEncoder::add_kernel_node(
|
|||||||
kernel_params.blockDimY = block_dim.y;
|
kernel_params.blockDimY = block_dim.y;
|
||||||
kernel_params.blockDimZ = block_dim.z;
|
kernel_params.blockDimZ = block_dim.z;
|
||||||
kernel_params.kernelParams = params;
|
kernel_params.kernelParams = params;
|
||||||
|
kernel_params.sharedMemBytes = smem_bytes;
|
||||||
add_kernel_node(kernel_params);
|
add_kernel_node(kernel_params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -47,25 +47,34 @@ class CommandEncoder {
|
|||||||
void set_output_array(const array& arr);
|
void set_output_array(const array& arr);
|
||||||
|
|
||||||
template <typename F, typename... Params>
|
template <typename F, typename... Params>
|
||||||
void
|
void add_kernel_node(
|
||||||
add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) {
|
F* func,
|
||||||
|
dim3 grid_dim,
|
||||||
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
|
Params&&... params) {
|
||||||
constexpr size_t num = sizeof...(Params);
|
constexpr size_t num = sizeof...(Params);
|
||||||
void* ptrs[num];
|
void* ptrs[num];
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
|
([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
|
||||||
std::forward<Params>(params)),
|
std::forward<Params>(params)),
|
||||||
...);
|
...);
|
||||||
add_kernel_node((void*)func, grid_dim, block_dim, ptrs);
|
add_kernel_node((void*)func, grid_dim, block_dim, smem_bytes, ptrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
void add_kernel_node(
|
void add_kernel_node(
|
||||||
CUfunction func,
|
CUfunction func,
|
||||||
dim3 grid_dim,
|
dim3 grid_dim,
|
||||||
dim3 block_dim,
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
void** params);
|
void** params);
|
||||||
|
|
||||||
void
|
void add_kernel_node(
|
||||||
add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
|
void* func,
|
||||||
|
dim3 grid_dim,
|
||||||
|
dim3 block_dim,
|
||||||
|
uint32_t smem_bytes,
|
||||||
|
void** params);
|
||||||
|
|
||||||
// Low-level graph helpers.
|
// Low-level graph helpers.
|
||||||
void add_kernel_node(const cudaKernelNodeParams& params);
|
void add_kernel_node(const cudaKernelNodeParams& params);
|
||||||
|
|||||||
@@ -108,6 +108,7 @@ void Matmul::run_batched(
|
|||||||
cu::set_mm_device_pointers,
|
cu::set_mm_device_pointers,
|
||||||
cuda::ceil_div(pointers.size(), block_size),
|
cuda::ceil_div(pointers.size(), block_size),
|
||||||
block_size,
|
block_size,
|
||||||
|
0,
|
||||||
pointers.data<int8_t*>(),
|
pointers.data<int8_t*>(),
|
||||||
a.data<int8_t>(),
|
a.data<int8_t>(),
|
||||||
b.data<int8_t>(),
|
b.data<int8_t>(),
|
||||||
@@ -168,6 +169,7 @@ void Matmul::run_batched(
|
|||||||
cu::set_addmm_device_pointers,
|
cu::set_addmm_device_pointers,
|
||||||
cuda::ceil_div(pointers.size(), block_size),
|
cuda::ceil_div(pointers.size(), block_size),
|
||||||
block_size,
|
block_size,
|
||||||
|
0,
|
||||||
pointers.data<int8_t*>(),
|
pointers.data<int8_t*>(),
|
||||||
a.data<int8_t>(),
|
a.data<int8_t>(),
|
||||||
b.data<int8_t>(),
|
b.data<int8_t>(),
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] = get_launch_args(out, large);
|
auto [num_blocks, block_dims] = get_launch_args(out, large);
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -230,7 +230,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] = get_launch_args(upd, large);
|
auto [num_blocks, block_dims] = get_launch_args(upd, large);
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -318,7 +318,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] = get_launch_args(idx, large);
|
auto [num_blocks, block_dims] = get_launch_args(idx, large);
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -422,7 +422,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
auto kernel = mod.get_kernel(kernel_name);
|
auto kernel = mod.get_kernel(kernel_name);
|
||||||
auto [num_blocks, block_dims] = get_launch_args(idx, large);
|
auto [num_blocks, block_dims] = get_launch_args(idx, large);
|
||||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -279,6 +279,7 @@ void LayerNorm::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
x.data<DataType>(),
|
x.data<DataType>(),
|
||||||
w.data<DataType>(),
|
w.data<DataType>(),
|
||||||
b.data<DataType>(),
|
b.data<DataType>(),
|
||||||
@@ -391,6 +392,7 @@ void LayerNormVJP::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
x.data<DataType>(),
|
x.data<DataType>(),
|
||||||
w.data<DataType>(),
|
w.data<DataType>(),
|
||||||
g.data<DataType>(),
|
g.data<DataType>(),
|
||||||
|
|||||||
@@ -150,6 +150,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
in.data<DataType>(),
|
in.data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
axis_size);
|
axis_size);
|
||||||
|
|||||||
@@ -261,6 +261,7 @@ void affine_quantize(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
w.data<T>(),
|
w.data<T>(),
|
||||||
wq.data<uint8_t>(),
|
wq.data<uint8_t>(),
|
||||||
scales.data<T>(),
|
scales.data<T>(),
|
||||||
@@ -316,6 +317,7 @@ void affine_dequantize(
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
wq.data<uint8_t>(),
|
wq.data<uint8_t>(),
|
||||||
scales.data<T>(),
|
scales.data<T>(),
|
||||||
biases.data<T>(),
|
biases.data<T>(),
|
||||||
|
|||||||
@@ -170,6 +170,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
cu::rbitsc,
|
cu::rbitsc,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
keys.data<uint32_t>(),
|
keys.data<uint32_t>(),
|
||||||
out.data<uint8_t>(),
|
out.data<uint8_t>(),
|
||||||
grid_dims,
|
grid_dims,
|
||||||
@@ -180,6 +181,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
cu::rbits,
|
cu::rbits,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
keys.data<uint32_t>(),
|
keys.data<uint32_t>(),
|
||||||
out.data<uint8_t>(),
|
out.data<uint8_t>(),
|
||||||
grid_dims,
|
grid_dims,
|
||||||
|
|||||||
@@ -120,6 +120,7 @@ void all_reduce(
|
|||||||
kernel,
|
kernel,
|
||||||
blocks,
|
blocks,
|
||||||
threads,
|
threads,
|
||||||
|
0,
|
||||||
static_cast<T*>(indata),
|
static_cast<T*>(indata),
|
||||||
intermediate.data<U>(),
|
intermediate.data<U>(),
|
||||||
block_step,
|
block_step,
|
||||||
@@ -146,6 +147,7 @@ void all_reduce(
|
|||||||
kernel,
|
kernel,
|
||||||
blocks,
|
blocks,
|
||||||
threads,
|
threads,
|
||||||
|
0,
|
||||||
static_cast<T*>(indata),
|
static_cast<T*>(indata),
|
||||||
out.data<U>(),
|
out.data<U>(),
|
||||||
block_step,
|
block_step,
|
||||||
|
|||||||
@@ -230,7 +230,7 @@ void col_reduce_looped(
|
|||||||
auto kernel =
|
auto kernel =
|
||||||
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel, grid, blocks, indata, out.data<U>(), args);
|
kernel, grid, blocks, 0, indata, out.data<U>(), args);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -41,7 +41,8 @@ void init_reduce(
|
|||||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||||
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
|
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
|
||||||
grid.x = (grid.x + 1023) / 1024;
|
grid.x = (grid.x + 1023) / 1024;
|
||||||
encoder.add_kernel_node(kernel, grid, block, out.data<U>(), out.size());
|
encoder.add_kernel_node(
|
||||||
|
kernel, grid, block, 0, out.data<U>(), out.size());
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -269,7 +269,7 @@ void row_reduce_simple(
|
|||||||
|
|
||||||
int size = plan.shape.back();
|
int size = plan.shape.back();
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel, grid, block, indata, out.data<U>(), out.size(), size);
|
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -322,7 +322,7 @@ void row_reduce_looped(
|
|||||||
});
|
});
|
||||||
|
|
||||||
encoder.add_kernel_node(
|
encoder.add_kernel_node(
|
||||||
kernel, grid, block, indata, out.data<U>(), out.size(), args);
|
kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -222,6 +222,7 @@ void RMSNorm::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
x.data<DataType>(),
|
x.data<DataType>(),
|
||||||
w.data<DataType>(),
|
w.data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
@@ -316,6 +317,7 @@ void RMSNormVJP::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
x.data<DataType>(),
|
x.data<DataType>(),
|
||||||
w.data<DataType>(),
|
w.data<DataType>(),
|
||||||
g.data<DataType>(),
|
g.data<DataType>(),
|
||||||
|
|||||||
@@ -325,6 +325,7 @@ void RoPE::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
(donated ? out : in).data<DataType>(),
|
(donated ? out : in).data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
offset.data<int32_t>(),
|
offset.data<int32_t>(),
|
||||||
@@ -341,6 +342,7 @@ void RoPE::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
(donated ? out : in).data<DataType>(),
|
(donated ? out : in).data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
offset.data<int32_t>(),
|
offset.data<int32_t>(),
|
||||||
@@ -360,6 +362,7 @@ void RoPE::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
(donated ? out : in).data<DataType>(),
|
(donated ? out : in).data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
offset.data<int32_t>(),
|
offset.data<int32_t>(),
|
||||||
@@ -381,6 +384,7 @@ void RoPE::eval_gpu(
|
|||||||
kernel,
|
kernel,
|
||||||
grid,
|
grid,
|
||||||
block,
|
block,
|
||||||
|
0,
|
||||||
(donated ? out : in).data<DataType>(),
|
(donated ? out : in).data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
offset.data<int32_t>(),
|
offset.data<int32_t>(),
|
||||||
|
|||||||
@@ -414,6 +414,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel,
|
kernel,
|
||||||
in.data_size() / axis_size,
|
in.data_size() / axis_size,
|
||||||
block_dim,
|
block_dim,
|
||||||
|
0,
|
||||||
in.data<T>(),
|
in.data<T>(),
|
||||||
out.data<U>(),
|
out.data<U>(),
|
||||||
axis_size);
|
axis_size);
|
||||||
@@ -443,6 +444,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel,
|
kernel,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dim,
|
block_dim,
|
||||||
|
0,
|
||||||
in.data<T>(),
|
in.data<T>(),
|
||||||
out.data<U>(),
|
out.data<U>(),
|
||||||
axis_size,
|
axis_size,
|
||||||
|
|||||||
@@ -151,6 +151,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel,
|
kernel,
|
||||||
n_rows,
|
n_rows,
|
||||||
block_dim(),
|
block_dim(),
|
||||||
|
0,
|
||||||
in.data<DataType>(),
|
in.data<DataType>(),
|
||||||
out.data<DataType>(),
|
out.data<DataType>(),
|
||||||
axis_size);
|
axis_size);
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/backend/steel/utils.cuh"
|
#include "mlx/backend/cuda/steel/utils.cuh"
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
@@ -223,6 +223,57 @@ struct RegisterTile {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A simple container of multiple Tile16x16.
|
||||||
|
*
|
||||||
|
* Provides utility functions for loading and manipulating collections of basic
|
||||||
|
* tiles.
|
||||||
|
*/
|
||||||
|
template <typename T, int ROWS_, int COLS_>
|
||||||
|
struct RegisterTile {
|
||||||
|
static constexpr int ROWS = ROWS_;
|
||||||
|
static constexpr int COLS = COLS_;
|
||||||
|
static constexpr int TILES_X = COLS / 16;
|
||||||
|
static constexpr int TILES_Y = ROWS / 16;
|
||||||
|
|
||||||
|
Tile16x16<T> data[TILES_X * TILES_Y];
|
||||||
|
|
||||||
|
__device__ inline void fill(T v) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
data[i * TILES_X + j].fill(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Tile>
|
||||||
|
__device__ inline void
|
||||||
|
load(Tile& tile, uint32_t base_address, int row, int col) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
data[i * TILES_X + j].load(
|
||||||
|
tile.loc(base_address, row + i * 16, col + j * 16));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
__device__ inline void store_global(U* x, int N, int row, int col) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int i = 0; i < TILES_Y; i++) {
|
||||||
|
MLX_UNROLL
|
||||||
|
for (int j = 0; j < TILES_X; j++) {
|
||||||
|
data[i * TILES_X + j].store_global(
|
||||||
|
x + (row + i * 16) * N + col + j * 16, N);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T, int ROWS_, int COLS_>
|
template <typename T, int ROWS_, int COLS_>
|
||||||
struct SharedTile {
|
struct SharedTile {
|
||||||
static constexpr int ROWS = ROWS_;
|
static constexpr int ROWS = ROWS_;
|
||||||
|
|||||||
@@ -130,6 +130,7 @@ void ternary_op_gpu_inplace(
|
|||||||
cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>,
|
cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<bool>(),
|
a.data<bool>(),
|
||||||
b.data<DType>(),
|
b.data<DType>(),
|
||||||
c.data<DType>(),
|
c.data<DType>(),
|
||||||
@@ -146,6 +147,7 @@ void ternary_op_gpu_inplace(
|
|||||||
cu::ternary_g<Op, DType, IdxT>,
|
cu::ternary_g<Op, DType, IdxT>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<bool>(),
|
a.data<bool>(),
|
||||||
b.data<DType>(),
|
b.data<DType>(),
|
||||||
c.data<DType>(),
|
c.data<DType>(),
|
||||||
@@ -168,6 +170,7 @@ void ternary_op_gpu_inplace(
|
|||||||
cu::ternary_v<Op, DType, IdxT, N_READS>,
|
cu::ternary_v<Op, DType, IdxT, N_READS>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
a.data<bool>(),
|
a.data<bool>(),
|
||||||
b.data<DType>(),
|
b.data<DType>(),
|
||||||
c.data<DType>(),
|
c.data<DType>(),
|
||||||
|
|||||||
@@ -135,6 +135,7 @@ void unary_op_gpu_inplace(
|
|||||||
cu::unary_v<Op, InType, OutType, IdxT, N_READS>,
|
cu::unary_v<Op, InType, OutType, IdxT, N_READS>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in.data<InType>(),
|
in.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
out.data_size());
|
out.data_size());
|
||||||
@@ -146,6 +147,7 @@ void unary_op_gpu_inplace(
|
|||||||
cu::unary_g<Op, InType, OutType, IdxT>,
|
cu::unary_g<Op, InType, OutType, IdxT>,
|
||||||
num_blocks,
|
num_blocks,
|
||||||
block_dims,
|
block_dims,
|
||||||
|
0,
|
||||||
in.data<InType>(),
|
in.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
out.data_size(),
|
out.data_size(),
|
||||||
|
|||||||
Reference in New Issue
Block a user