Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00

Compare commits: 1 commit on jagrit06/c... (SHA1 1034009b82)
@@ -24,7 +24,6 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/steel_gemm.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
@@ -1,301 +0,0 @@
#include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/gemms/steel_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/primitives.h"

#include <nvtx3/nvtx3.hpp>
#include <numeric>

#include <cooperative_groups.h>

#include "mlx/backend/cuda/steel/gemm.cuh"
#include "mlx/backend/cuda/steel/mma.cuh"
#include "mlx/backend/cuda/steel/tiles.cuh"

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

struct GemmParams {
  int M;
  int N;
  int K;
  int lda;
  int ldb;
  int ldd;

  int NblockM;
  int NblockN;
  int NblockK;
};

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    int SL,
    int Nstages>
__global__ void kernel_steel_gemm(
    const T* a,
    const T* b,
    T* d,
    __grid_constant__ const GemmParams params) {
  const int bM_idx = (blockIdx.y << SL) + (blockIdx.x & ((1 << SL) - 1));
  const int bN_idx = blockIdx.x >> SL;

  if (params.NblockN <= bN_idx || params.NblockM <= bM_idx) {
    return;
  }

  const int d_row = bM_idx * BM;
  const int d_col = bN_idx * BN;
  const size_t d_row_long = size_t(d_row);
  const size_t d_col_long = size_t(d_col);

  a += transpose_a ? d_row_long : d_row_long * params.K;
  b += transpose_b ? d_col_long * params.K : d_col_long;
  d += d_row_long * params.ldd + d_col_long;

  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<32>(block);

  const int lane_idx = warp.thread_rank();
  const int warp_idx = warp.meta_group_rank();

  const int wm = warp_idx / WN;
  const int wn = warp_idx % WN;

  constexpr int SM = BM / WM;
  constexpr int SN = BN / WN;
  constexpr int SK = BK;
  constexpr int TK = SK / 16;

  constexpr int NUM_WARPS = WM * WN;

  // Allocate shared memory
  extern __shared__ char shmem[];
  SharedTile<T, BM, BK>(&as)[Nstages] =
      *(SharedTile<T, BM, BK>(*)[Nstages])(&shmem[0]);
  SharedTile<T, BN, BK>(&bs)[Nstages] = *(SharedTile<T, BN, BK>(*)[Nstages])(
      &shmem[sizeof(T) * Nstages * BM * BK]);

  // Allocate registers for the MMA
  RegisterTile<float, SM, SN> C;
  RegisterTile<T, SM, 16> A[TK];
  RegisterTile<T, SN, 16> B[TK];

  // Zero the accumulators
  C.fill(0);

  // Start gmem -> smem copies
  int k_block_read = 0;

  MLX_UNROLL
  for (int bk = 0; bk < (Nstages - 1); bk++) {
    load_async<NUM_WARPS>(
        as[bk], as[bk].base_addr(), a + k_block_read, params.K);
    load_async<NUM_WARPS>(
        bs[bk], bs[bk].base_addr(), b + k_block_read, params.K);
    k_block_read += BK;
    cp_async_commit();
  }

  int smem_pipe_read = 0;
  int smem_pipe_write = Nstages - 1;

  // Wait until only 1 copy is still loading
  cp_async_wait<1>();
  block.sync();

  const int offset_m = wm * SM;
  const int offset_n = wn * SN;

  // Start smem -> register copy
  A[0].load(
      as[smem_pipe_read],
      as[smem_pipe_read].base_addr(),
      offset_m + lane_idx % 16,
      lane_idx / 16 * 8);
  B[0].load(
      bs[smem_pipe_read],
      bs[smem_pipe_read].base_addr(),
      offset_n + lane_idx % 16,
      lane_idx / 16 * 8);

  // Main loop
  for (int kb = 0; kb < params.NblockK; kb++) {
    // Prepare the next registers
    {
      A[1].load(
          as[smem_pipe_read],
          as[smem_pipe_read].base_addr(),
          offset_m + lane_idx % 16,
          16 + lane_idx / 16 * 8);
      B[1].load(
          bs[smem_pipe_read],
          bs[smem_pipe_read].base_addr(),
          offset_n + lane_idx % 16,
          16 + lane_idx / 16 * 8);
    }

    // Prepare the next smem stage
    if ((kb + Nstages - 1) < params.NblockK) {
      load_async<NUM_WARPS>(
          as[smem_pipe_write],
          as[smem_pipe_write].base_addr(),
          a + k_block_read,
          params.K);
      load_async<NUM_WARPS>(
          bs[smem_pipe_write],
          bs[smem_pipe_write].base_addr(),
          b + k_block_read,
          params.K);
    }
    k_block_read += BK;

    cp_async_commit();

    smem_pipe_write = smem_pipe_read;
    smem_pipe_read = smem_pipe_read + 1;
    smem_pipe_read = (smem_pipe_read == Nstages) ? 0 : smem_pipe_read;

    // Do the current gemm
    mma_t(C, A[0], B[0]);

    // Wait for the next stage's smem before loading registers from it
    cp_async_wait<1>();
    block.sync();

    // Prepare the next register (smem_pipe_read has moved to the next stage)
    {
      A[0].load(
          as[smem_pipe_read],
          as[smem_pipe_read].base_addr(),
          offset_m + lane_idx % 16,
          lane_idx / 16 * 8);
      B[0].load(
          bs[smem_pipe_read],
          bs[smem_pipe_read].base_addr(),
          offset_n + lane_idx % 16,
          lane_idx / 16 * 8);
    }

    // Do the current gemm
    mma_t(C, A[1], B[1]);
  }

  // Wait and clear
  cp_async_wait_all();
  block.sync();

  C.store_global(d, params.ldd, offset_m, offset_n);
}

} // namespace cu

void dispatch_steel_gemm(
    const Stream& s,
    cu::CommandEncoder& encoder,
    const array& a,
    const array& b,
    array& d,
    int M,
    int N,
    int K,
    int lda,
    int ldb,
    int ldd,
    bool a_transposed,
    bool b_transposed) {
  using DataType = cuda_type_t<float16_t>;

  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(d);

  constexpr int BM = 128;
  constexpr int BN = 128;
  constexpr int BK = 32;

  constexpr int WM = 2;
  constexpr int WN = 2;

  constexpr int SL = 0;
  constexpr int Nstages = 3;

  constexpr uint32_t smem_bytes = BK * (BM + BN) * Nstages * sizeof(DataType);

  const int NblockM = (M + BM - 1) / BM;
  const int NblockN = (N + BN - 1) / BN;
  const int NblockK = (K + BK - 1) / BK;

  cu::GemmParams params{
      /* int M = */ M,
      /* int N = */ N,
      /* int K = */ K,
      /* int lda = */ lda,
      /* int ldb = */ ldb,
      /* int ldd = */ ldd,

      /* int NblockM = */ NblockM,
      /* int NblockN = */ NblockN,
      /* int NblockK = */ NblockK,
  };

  // Prepare launch grid params
  int tile = 1 << SL;
  int tm = (NblockM + tile - 1) / tile;
  int tn = NblockN * tile;

  dim3 grid_dim(tn, tm, 1);
  dim3 block_dim(32 * WM * WN, 1, 1);

  dispatch_bool(a_transposed, [&](auto ta_) {
    dispatch_bool(b_transposed, [&](auto tb_) {
      constexpr bool ta = ta_.value;
      constexpr bool tb = tb_.value;

      auto kernel = cu::ab_t_aligned<DataType, BM, BN, BK>;
      cudaFuncSetAttribute(
          kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);

      encoder.add_kernel_node(
          kernel,
          grid_dim,
          block_dim,
          smem_bytes,
          a.data<DataType>(),
          b.data<DataType>(),
          d.data<DataType>(),
          N,
          K);

      // auto kernel = cu::kernel_steel_gemm<DataType, BM, BN, BK, WM, WN, ta,
      // tb, SL, Nstages>;

      // cudaFuncSetAttribute(kernel,
      // cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);

      // encoder.add_kernel_node(
      //     kernel,
      //     grid_dim,
      //     block_dim,
      //     smem_bytes,
      //     a.data<DataType>(),
      //     b.data<DataType>(),
      //     d.data<DataType>(),
      //     params);
    });
  });
}

} // namespace mlx::core
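For reference, here is a minimal host-side sketch (not part of the diff) that replays the block-swizzle index mapping from the deleted kernel_steel_gemm prologue and the grid shape computed in dispatch_steel_gemm. It assumes only the arithmetic shown above; the helper name print_swizzle_map is hypothetical, for illustration.

```cpp
// Replays the swizzled blockIdx -> output-tile mapping used by the deleted
// steel gemm kernel, so the grid_dim(tn, tm) launch shape can be checked on
// the host. print_swizzle_map is a hypothetical helper name.
#include <cstdio>

void print_swizzle_map(int NblockM, int NblockN, int SL) {
  const int tile = 1 << SL; // row tiles grouped per swizzle step
  const int tm = (NblockM + tile - 1) / tile; // grid_dim.y
  const int tn = NblockN * tile; // grid_dim.x
  for (int by = 0; by < tm; ++by) {
    for (int bx = 0; bx < tn; ++bx) {
      // Same arithmetic as the kernel prologue.
      const int bM_idx = (by << SL) + (bx & ((1 << SL) - 1));
      const int bN_idx = bx >> SL;
      if (bM_idx < NblockM && bN_idx < NblockN) { // kernel's early-exit guard
        std::printf("block (%d, %d) -> tile (%d, %d)\n", bx, by, bM_idx, bN_idx);
      }
    }
  }
}

int main() {
  // dispatch_steel_gemm uses SL = 0, which makes the mapping the identity:
  // bM_idx = blockIdx.y and bN_idx = blockIdx.x. SL = 1 interleaves pairs of
  // row tiles across consecutive blockIdx.x values.
  print_swizzle_map(/* NblockM = */ 3, /* NblockN = */ 2, /* SL = */ 1);
  return 0;
}
```

With the defaults in dispatch_steel_gemm (BM = BN = 128, BK = 32, Nstages = 3, fp16), smem_bytes works out to 32 * 256 * 3 * 2 = 49152 bytes, i.e. 48 KB of dynamic shared memory per block.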
@@ -1,27 +0,0 @@
#pragma once

#include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/primitives.h"

#include <nvtx3/nvtx3.hpp>
#include <numeric>

namespace mlx::core {

void dispatch_steel_gemm(
    const Stream& s,
    cu::CommandEncoder& encoder,
    const array& a,
    const array& b,
    array& d,
    int M,
    int N,
    int K,
    int lda,
    int ldb,
    int ldd,
    bool a_transposed,
    bool b_transposed);

} // namespace mlx::core
@@ -7,8 +7,6 @@
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/primitives.h"
 
-#include "mlx/backend/cuda/gemms/steel_gemm.h"
-
 #include <nvtx3/nvtx3.hpp>
 #include <numeric>
 
@@ -97,24 +95,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
-  if (out.dtype() == float16 && batch_count == 1 && !a_transposed &&
-      b_transposed) {
-    return dispatch_steel_gemm(
-        /* const Stream& s = */ s,
-        /* cu::CommandEncoder& encoder = */ encoder,
-        /* const array& a = */ a,
-        /* const array& b = */ b,
-        /* array& d = */ out,
-        /* int M = */ M,
-        /* int N = */ N,
-        /* int K = */ K,
-        /* int lda = */ lda,
-        /* int ldb = */ ldb,
-        /* int ldd = */ N,
-        /* bool a_transposed = */ a_transposed,
-        /* bool b_transposed = */ b_transposed);
-  }
-
   /////////////////////////////////////////////////////////////////////////////
   // Invoke cublasLt
   CublasGemm gemm(
@@ -143,87 +143,85 @@ struct Tile16x16 {
   }
 };
 
-// /**
-//  * A simple container of multiple Tile16x16.
-//  *
-//  * Provides utility functions for loading and manipulating collections of
-//  basic
-//  * tiles.
-//  */
-// template <typename T, int ROWS_, int COLS_>
-// struct RegisterTile {
-//   static constexpr int ROWS = ROWS_;
-//   static constexpr int COLS = COLS_;
-//   static constexpr int TILES_X = COLS / 16;
-//   static constexpr int TILES_Y = ROWS / 16;
+/**
+ * A simple container of multiple Tile16x16.
+ *
+ * Provides utility functions for loading and manipulating collections of basic
+ * tiles.
+ */
+template <typename T, int ROWS_, int COLS_>
+struct RegisterTile {
+  static constexpr int ROWS = ROWS_;
+  static constexpr int COLS = COLS_;
+  static constexpr int TILES_X = COLS / 16;
+  static constexpr int TILES_Y = ROWS / 16;
 
-//   Tile16x16<T> data[TILES_X * TILES_Y];
+  Tile16x16<T> data[TILES_X * TILES_Y];
 
-//   __device__ inline void fill(T v) {
-//     MLX_UNROLL
-//     for (int i = 0; i < TILES_Y; i++) {
-//       MLX_UNROLL
-//       for (int j = 0; j < TILES_X; j++) {
-//         data[i * TILES_X + j].fill(v);
-//       }
-//     }
-//   }
+  __device__ inline void fill(T v) {
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        data[i * TILES_X + j].fill(v);
+      }
+    }
+  }
 
-//   template <typename Tile>
-//   __device__ __forceinline__ void
-//   load(Tile& tile, uint32_t base_address, int row, int col) {
-//     MLX_UNROLL
-//     for (int i = 0; i < TILES_Y; i++) {
-//       MLX_UNROLL
-//       for (int j = 0; j < TILES_X; j++) {
-//         data[i * TILES_X + j].load(
-//             tile.loc(base_address, row + i * 16, col + j * 16));
-//       }
-//     }
-//   }
+  template <typename Tile>
+  __device__ __forceinline__ void
+  load(Tile& tile, uint32_t base_address, int row, int col) {
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        data[i * TILES_X + j].load(
+            tile.loc(base_address, row + i * 16, col + j * 16));
+      }
+    }
+  }
 
-//   template <typename Tile, typename F>
-//   __device__ __forceinline__ void
-//   load(Tile& tile, F f, uint32_t base_address, int row, int col) {
-//     MLX_UNROLL
-//     for (int i = 0; i < TILES_Y; i++) {
-//       MLX_UNROLL
-//       for (int j = 0; j < TILES_X; j++) {
-//         f(data[i * TILES_X + j],
-//           tile,
-//           base_address,
-//           row + i * 16,
-//           col + j * 16);
-//       }
-//     }
-//   }
+  template <typename Tile, typename F>
+  __device__ __forceinline__ void
+  load(Tile& tile, F f, uint32_t base_address, int row, int col) {
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        f(data[i * TILES_X + j],
+          tile,
+          base_address,
+          row + i * 16,
+          col + j * 16);
+      }
+    }
+  }
 
-//   template <typename U>
-//   __device__ inline void store_global(U* x, int N, int row, int col) {
-//     MLX_UNROLL
-//     for (int i = 0; i < TILES_Y; i++) {
-//       MLX_UNROLL
-//       for (int j = 0; j < TILES_X; j++) {
-//         data[i * TILES_X + j].store_global(
-//             x + (row + i * 16) * N + col + j * 16, N);
-//       }
-//     }
-//   }
+  template <typename U>
+  __device__ inline void store_global(U* x, int N, int row, int col) {
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        data[i * TILES_X + j].store_global(
+            x + (row + i * 16) * N + col + j * 16, N);
+      }
+    }
+  }
 
-//   template <typename U>
-//   __device__ inline void
-//   store_global_safe(U* x, int N, int row, int col, int max_rows) {
-//     MLX_UNROLL
-//     for (int i = 0; i < TILES_Y; i++) {
-//       MLX_UNROLL
-//       for (int j = 0; j < TILES_X; j++) {
-//         data[i * TILES_X + j].store_global_safe(
-//             x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i *
-//             16);
-//       }
-//     }
-//   }
-// };
+  template <typename U>
+  __device__ inline void
+  store_global_safe(U* x, int N, int row, int col, int max_rows) {
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        data[i * TILES_X + j].store_global_safe(
+            x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
+      }
+    }
+  }
+};
 
 /**
  * A simple container of multiple Tile16x16.
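For reference, a small host-compilable sketch (not part of the diff) of how the TILES_X / TILES_Y decomposition uncommented above splits a RegisterTile into 16x16 MMA fragments. RegisterTileDims is a hypothetical stand-in so the arithmetic can be checked without CUDA headers.

```cpp
// Mirrors the constexpr arithmetic of RegisterTile above: a ROWS x COLS
// register tile is stored as TILES_Y x TILES_X fragments of Tile16x16.
template <int ROWS_, int COLS_>
struct RegisterTileDims {
  static constexpr int ROWS = ROWS_;
  static constexpr int COLS = COLS_;
  static constexpr int TILES_X = COLS / 16;
  static constexpr int TILES_Y = ROWS / 16;
  static constexpr int FRAGMENTS = TILES_X * TILES_Y; // length of data[]
};

// The accumulator in the deleted kernel_steel_gemm is
// RegisterTile<float, SM, SN> with SM = BM / WM = 64 and SN = BN / WN = 64,
// i.e. a 4 x 4 grid of Tile16x16 fragments per warp.
static_assert(RegisterTileDims<64, 64>::TILES_Y == 4, "4 tiles down");
static_assert(RegisterTileDims<64, 64>::TILES_X == 4, "4 tiles across");
static_assert(RegisterTileDims<64, 64>::FRAGMENTS == 16, "16 fragments");
```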
@@ -37,19 +37,35 @@ __global__ void unary_v(const In* in, Out* out, IdxT size) {
   }
 }
 
-template <typename Op, typename In, typename Out, typename IdxT>
+template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void unary_g(
     const In* in,
     Out* out,
     IdxT size,
+    IdxT size_rest,
     const __grid_constant__ Shape shape,
     const __grid_constant__ Strides strides,
     int ndim) {
-  IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    auto idx = elem_to_loc(index, shape.data(), strides.data(), ndim);
-    out[index] = Op{}(in[idx]);
+  auto block = cg::this_thread_block();
+  auto grid = cg::this_grid();
+  IdxT index_rest =
+      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
+  if (index_rest >= size_rest) {
+    return;
   }
+
+  auto shape_x = shape[ndim - 1];
+  auto stride_x = strides[ndim - 1];
+  IdxT index_x =
+      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
+  auto idx = elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
+  auto in_vec = load_vector<N_READS>(
+      in + idx, index_x, shape_x, stride_x, In(0));
+  AlignedVector<Out, N_READS> out_vec;
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    out_vec[i] = Op{}(in_vec[i]);
+  }
+  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
 }
 
 template <typename Op, typename In, typename Out>
@@ -127,8 +143,7 @@ void unary_op_gpu_inplace(
     using OutType = cuda_type_t<CTYPE_OUT>;
     if (contig) {
       using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
-      // TODO: Choose optimized value based on type size.
-      constexpr int N_READS = 4;
+      constexpr int N_READS = 16 / sizeof(OutType);
       auto [num_blocks, block_dims] = get_launch_args(
           out.data_size(), out.shape(), out.strides(), large, N_READS);
       encoder.add_kernel_node(
@@ -142,18 +157,30 @@
     } else {
       using IdxT = std::conditional_t<large(), int64_t, int32_t>;
       auto [shape, strides] = collapse_contiguous_dims(in);
-      auto [num_blocks, block_dims] = get_launch_args(out, large);
+      auto ndim = shape.size();
+      int work_per_thread = 1;
+      auto kernel = cu::unary_g<Op, InType, OutType, IdxT, 1>;
+      auto dim0 = ndim > 0 ? shape.back() : 1;
+      auto rest = out.size() / dim0;
+      if (dim0 >= 4) {
+        kernel = cu::unary_g<Op, InType, OutType, IdxT, 4>;
+        work_per_thread = 4;
+      }
+      dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
+      auto block_dims = get_block_dims(dim0, rest, 1);
+      uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
+      uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
       encoder.add_kernel_node(
-          cu::unary_g<Op, InType, OutType, IdxT>,
-          num_blocks,
+          kernel,
+          {num_blocks_x, num_blocks_y},
           block_dims,
           0,
           in.data<InType>(),
           out.data<OutType>(),
           out.data_size(),
+          rest,
          const_param(shape),
           const_param(strides),
-          shape.size());
+          ndim);
     }
   });
 } else {
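For reference, a short sketch (not part of the diff) of the vector widths implied by the two launch paths above: the contiguous path reads 16 bytes per thread, while the strided path picks the <..., 1> or <..., 4> instantiation of unary_g based on the innermost dimension. The vector_width_* helpers are hypothetical, for illustration only.

```cpp
// Works out the per-thread vector widths used by the unary launch logic.
#include <cstdint>
#include <cstdio>

template <typename T>
constexpr int vector_width_contig() {
  return 16 / static_cast<int>(sizeof(T)); // N_READS = 16 / sizeof(OutType)
}

constexpr int vector_width_strided(int dim0) {
  return dim0 >= 4 ? 4 : 1; // matches the if (dim0 >= 4) kernel selection
}

int main() {
  std::printf(
      "contig: 4-byte=%d, 2-byte=%d, 1-byte=%d\n",
      vector_width_contig<float>(), // 4
      vector_width_contig<uint16_t>(), // 8
      vector_width_contig<int8_t>()); // 16
  std::printf(
      "strided: dim0=3 -> %d, dim0=1024 -> %d\n",
      vector_width_strided(3),
      vector_width_strided(1024));
  return 0;
}
```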