mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 17:12:49 +08:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8269c9d02d | ||
|
|
903b40627c | ||
|
|
700f7dcf01 | ||
|
|
6c60bd1cbf | ||
|
|
a64cc02a0c | ||
|
|
346ae5fdb5 |
@@ -22,7 +22,7 @@ project(
|
||||
|
||||
# ----------------------------- Setup -----------------------------
|
||||
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_INSTALL_MESSAGE NEVER)
|
||||
|
||||
@@ -42,7 +42,9 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||
|
||||
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
||||
@@ -90,7 +92,7 @@ target_compile_options(
|
||||
# Compute capability 7 is required for synchronization between CPU/GPU with
|
||||
# managed memory. TODO: Add more architectures for potential performance gain.
|
||||
set(MLX_CUDA_ARCHITECTURES
|
||||
"70;80"
|
||||
"80"
|
||||
CACHE STRING "CUDA architectures")
|
||||
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
||||
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
||||
@@ -130,3 +132,12 @@ target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||
# Install CCCL headers for JIT.
|
||||
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
|
||||
|
||||
# Make Thunderkittens available
|
||||
FetchContent_Declare(
|
||||
kittens
|
||||
GIT_REPOSITORY https://github.com/HazyResearch/ThunderKittens.git
|
||||
GIT_TAG aaab847f430ed313ed466e64b25b9177babd1db8
|
||||
GIT_SHALLOW TRUE)
|
||||
FetchContent_MakeAvailable(kittens)
|
||||
target_include_directories(mlx BEFORE PRIVATE "${kittens_SOURCE_DIR}/include")
|
||||
|
||||
@@ -166,6 +166,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dim(),
|
||||
0,
|
||||
in.data<T>(),
|
||||
out.data<uint32_t>(),
|
||||
out.size(),
|
||||
|
||||
@@ -219,6 +219,7 @@ void binary_op_gpu_inplace(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
@@ -235,6 +236,7 @@ void binary_op_gpu_inplace(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
@@ -269,6 +271,7 @@ void binary_op_gpu_inplace(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
|
||||
@@ -239,6 +239,7 @@ void binary_two_op_gpu_inplace(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
@@ -256,6 +257,7 @@ void binary_two_op_gpu_inplace(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
@@ -291,6 +293,7 @@ void binary_two_op_gpu_inplace(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
|
||||
@@ -295,7 +295,7 @@ void Compiled::eval_gpu(
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, outputs[0], large, work_per_thread);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -82,6 +82,7 @@ void copy_contiguous(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in.data<InType>() + in_offset,
|
||||
out.data<OutType>() + out_offset,
|
||||
out.data_size());
|
||||
|
||||
@@ -79,6 +79,7 @@ void copy_general(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
data_size,
|
||||
@@ -94,6 +95,7 @@ void copy_general(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
data_size,
|
||||
|
||||
@@ -82,6 +82,7 @@ void copy_general_dynamic(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
@@ -99,6 +100,7 @@ void copy_general_dynamic(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
|
||||
@@ -71,6 +71,7 @@ void copy_general_input(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
@@ -85,6 +86,7 @@ void copy_general_input(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
|
||||
@@ -215,12 +215,14 @@ void CommandEncoder::add_kernel_node(
|
||||
void* func,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
uint32_t smem_bytes,
|
||||
void** params) {
|
||||
cudaKernelNodeParams kernel_params = {0};
|
||||
kernel_params.func = func;
|
||||
kernel_params.gridDim = grid_dim;
|
||||
kernel_params.blockDim = block_dim;
|
||||
kernel_params.kernelParams = params;
|
||||
kernel_params.sharedMemBytes = smem_bytes;
|
||||
cudaGraphNode_t node;
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
|
||||
@@ -231,6 +233,7 @@ void CommandEncoder::add_kernel_node(
|
||||
CUfunction func,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
uint32_t smem_bytes,
|
||||
void** params) {
|
||||
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
||||
kernel_params.func = func;
|
||||
@@ -241,6 +244,7 @@ void CommandEncoder::add_kernel_node(
|
||||
kernel_params.blockDimY = block_dim.y;
|
||||
kernel_params.blockDimZ = block_dim.z;
|
||||
kernel_params.kernelParams = params;
|
||||
kernel_params.sharedMemBytes = smem_bytes;
|
||||
CUgraphNode node;
|
||||
CHECK_CUDA_ERROR(
|
||||
cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
|
||||
|
||||
@@ -45,25 +45,34 @@ class CommandEncoder {
|
||||
void set_output_array(const array& arr);
|
||||
|
||||
template <typename F, typename... Params>
|
||||
void
|
||||
add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) {
|
||||
void add_kernel_node(
|
||||
F* func,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
uint32_t smem_bytes,
|
||||
Params&&... params) {
|
||||
constexpr size_t num = sizeof...(Params);
|
||||
void* ptrs[num];
|
||||
size_t i = 0;
|
||||
([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
|
||||
std::forward<Params>(params)),
|
||||
...);
|
||||
add_kernel_node((void*)func, grid_dim, block_dim, ptrs);
|
||||
add_kernel_node((void*)func, grid_dim, block_dim, smem_bytes, ptrs);
|
||||
}
|
||||
|
||||
void add_kernel_node(
|
||||
CUfunction func,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
uint32_t smem_bytes,
|
||||
void** params);
|
||||
|
||||
void
|
||||
add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
|
||||
void add_kernel_node(
|
||||
void* func,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
uint32_t smem_bytes,
|
||||
void** params);
|
||||
|
||||
void add_temporary(const array& arr) {
|
||||
temporaries_.push_back(arr.data_shared_ptr());
|
||||
|
||||
@@ -129,7 +129,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||
}
|
||||
|
||||
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -230,7 +230,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.set_output_array(out);
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||
}
|
||||
|
||||
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -318,7 +318,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.set_output_array(out);
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||
}
|
||||
|
||||
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -422,7 +422,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.set_output_array(out);
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -266,6 +266,7 @@ void LayerNorm::eval_gpu(
|
||||
kernel,
|
||||
n_rows,
|
||||
block_dim(),
|
||||
0,
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
b.data<DataType>(),
|
||||
@@ -378,6 +379,7 @@ void LayerNormVJP::eval_gpu(
|
||||
kernel,
|
||||
n_rows,
|
||||
block_dim(),
|
||||
0,
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
|
||||
@@ -151,6 +151,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
kernel,
|
||||
n_rows,
|
||||
block_dim(),
|
||||
0,
|
||||
in.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
axis_size);
|
||||
|
||||
108
mlx/backend/cuda/matmul/mma.cuh
Normal file
108
mlx/backend/cuda/matmul/mma.cuh
Normal file
@@ -0,0 +1,108 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/matmul/tiles.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
template <typename U, typename T>
|
||||
__device__ inline void
|
||||
mma_t(Tile16x16<U>& C, Tile16x16<T>& A, Tile16x16<T>& B) {}
|
||||
|
||||
/**
|
||||
* Multiply the 16x16 bfloat16 tiles and accumulate the result in one 16x16
|
||||
* float tile.
|
||||
*
|
||||
* We actually perform C += A @ B.T
|
||||
*/
|
||||
__device__ inline void mma_t(
|
||||
Tile16x16<float>& C,
|
||||
Tile16x16<__nv_bfloat16>& A,
|
||||
Tile16x16<__nv_bfloat16>& B) {
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0, %1, %2, %3}, "
|
||||
"{%4, %5, %6, %7}, "
|
||||
"{%8, %9}, "
|
||||
"{%10, %11, %12, %13};"
|
||||
|
||||
// D matrix
|
||||
: "+f"(C.values[0].x),
|
||||
"+f"(C.values[0].y),
|
||||
"+f"(C.values[1].x),
|
||||
"+f"(C.values[1].y)
|
||||
|
||||
// A matrix
|
||||
: "r"(*(uint32_t*)(&A.values[0])),
|
||||
"r"(*(uint32_t*)(&A.values[1])),
|
||||
"r"(*(uint32_t*)(&A.values[2])),
|
||||
"r"(*(uint32_t*)(&A.values[3])),
|
||||
|
||||
// B matrix
|
||||
"r"(*(uint32_t*)(&B.values[0])),
|
||||
"r"(*(uint32_t*)(&B.values[2])),
|
||||
|
||||
// C matrix
|
||||
"f"(C.values[0].x),
|
||||
"f"(C.values[0].y),
|
||||
"f"(C.values[1].x),
|
||||
"f"(C.values[1].y));
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0, %1, %2, %3}, "
|
||||
"{%4, %5, %6, %7}, "
|
||||
"{%8, %9}, "
|
||||
"{%10, %11, %12, %13};"
|
||||
|
||||
// D matrix
|
||||
: "+f"(C.values[2].x),
|
||||
"+f"(C.values[2].y),
|
||||
"+f"(C.values[3].x),
|
||||
"+f"(C.values[3].y)
|
||||
|
||||
// A matrix
|
||||
: "r"(*(uint32_t*)(&A.values[0])),
|
||||
"r"(*(uint32_t*)(&A.values[1])),
|
||||
"r"(*(uint32_t*)(&A.values[2])),
|
||||
"r"(*(uint32_t*)(&A.values[3])),
|
||||
|
||||
// B matrix
|
||||
"r"(*(uint32_t*)(&B.values[1])),
|
||||
"r"(*(uint32_t*)(&B.values[3])),
|
||||
|
||||
// C matrix
|
||||
"f"(C.values[2].x),
|
||||
"f"(C.values[2].y),
|
||||
"f"(C.values[3].x),
|
||||
"f"(C.values[3].y));
|
||||
}
|
||||
|
||||
/**
|
||||
* Multiply larger register tiles by delegating to mma_t.
|
||||
*/
|
||||
template <typename U, typename T, int M, int N, int K>
|
||||
__device__ inline void mma_t(
|
||||
RegisterTile<U, M, N>& C,
|
||||
RegisterTile<T, M, K>& A,
|
||||
RegisterTile<T, N, K>& B) {
|
||||
constexpr int TILES_M = RegisterTile<T, M, K>::TILES_Y;
|
||||
constexpr int TILES_K = RegisterTile<T, M, K>::TILES_X;
|
||||
constexpr int TILES_N = RegisterTile<T, N, K>::TILES_Y;
|
||||
|
||||
MLX_UNROLL
|
||||
for (int k = 0; k < TILES_K; k++) {
|
||||
MLX_UNROLL
|
||||
for (int m = 0; m < TILES_M; m++) {
|
||||
MLX_UNROLL
|
||||
for (int n = 0; n < TILES_N; n++) {
|
||||
mma_t(
|
||||
C.data[m * TILES_N + n],
|
||||
A.data[m * TILES_K + k],
|
||||
B.data[n * TILES_K + k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
419
mlx/backend/cuda/matmul/tiles.cuh
Normal file
419
mlx/backend/cuda/matmul/tiles.cuh
Normal file
@@ -0,0 +1,419 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#define MLX_UNROLL _Pragma("unroll")
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Map types to their vector of 2 type float -> float2, double -> double2 etc
|
||||
template <typename T>
|
||||
struct Vector2;
|
||||
template <>
|
||||
struct Vector2<double> {
|
||||
using type = double2;
|
||||
};
|
||||
template <>
|
||||
struct Vector2<float> {
|
||||
using type = float2;
|
||||
};
|
||||
template <>
|
||||
struct Vector2<__half> {
|
||||
using type = __half2;
|
||||
};
|
||||
template <>
|
||||
struct Vector2<__nv_bfloat16> {
|
||||
using type = __nv_bfloat162;
|
||||
};
|
||||
template <typename T>
|
||||
using Vector2_t = typename Vector2<T>::type;
|
||||
|
||||
/**
|
||||
* The basic building block for Ampere mmas. A 16x16 tile distributed across
|
||||
* the warp.
|
||||
*
|
||||
* Each thread holds 8 values. They are distributed according to
|
||||
* https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
|
||||
*
|
||||
* For use instructions see the individual methods eg load().
|
||||
*/
|
||||
template <typename T>
|
||||
struct Tile16x16 {
|
||||
using T2 = Vector2_t<T>;
|
||||
|
||||
T2 values[4];
|
||||
|
||||
__device__ inline void fill(T v) {
|
||||
T2 v2 = {v, v};
|
||||
for (int i = 0; i < 4; i++) {
|
||||
values[i] = v2;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load a 16x16 tile from shared memory.
|
||||
*
|
||||
* The instruction is a bit weird in the sense that the address provided by
|
||||
* each thread and the elements loaded are not the same.
|
||||
*
|
||||
* We load 4 8x8 tiles. The tile rows are stored contiguously in memory. As a
|
||||
* result the warp provides 4*8 = 32 addresses one per row.
|
||||
*
|
||||
* Threads 0-7 provide the addresses for the first tile, 8-15 for the second
|
||||
* and so on. For instance to load a non swizzled tile we would do
|
||||
*
|
||||
* base_addr + (laneid % 16) * BK + (laneid / 2) * 8
|
||||
*
|
||||
* See
|
||||
* https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
|
||||
*/
|
||||
__device__ inline void load(uint32_t row_address) {
|
||||
if constexpr (
|
||||
std::is_same_v<T2, __nv_bfloat162> || std::is_same_v<T2, __half2>) {
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(*(uint32_t*)&(values[0])),
|
||||
"=r"(*(uint32_t*)&(values[1])),
|
||||
"=r"(*(uint32_t*)&(values[2])),
|
||||
"=r"(*(uint32_t*)&(values[3]))
|
||||
: "r"(row_address));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Store the tile to the address pointed to by `x`.
|
||||
*
|
||||
* The provided pointer is a generic pointer but this is meant to be used to
|
||||
* store to global memory. For storing to shared memory we should use
|
||||
* `stmatrix`.
|
||||
*
|
||||
* This also showcases the format of the tile quite nicely. Each register is
|
||||
* holding to adjacent values. The indices are
|
||||
*
|
||||
* row + 0, col + 0
|
||||
* row + 8, col + 0
|
||||
* row + 0, col + 8
|
||||
* row + 8, col + 8
|
||||
*
|
||||
* Given that we are dealing with Vector2_t<U> the column offsets are 4
|
||||
* instead of 8.
|
||||
*/
|
||||
template <typename U>
|
||||
__device__ inline void store_global(U* x, int N) {
|
||||
using U2 = Vector2_t<U>;
|
||||
U2* x2 = reinterpret_cast<U2*>(x);
|
||||
const int laneid = threadIdx.x % 32;
|
||||
const int row = laneid / 4;
|
||||
const int col = laneid % 4;
|
||||
if constexpr (std::is_same_v<U2, T2>) {
|
||||
x2[(row + 0) * (N / 2) + col + 0] = values[0];
|
||||
x2[(row + 0) * (N / 2) + col + 4] = values[2];
|
||||
x2[(row + 8) * (N / 2) + col + 0] = values[1];
|
||||
x2[(row + 8) * (N / 2) + col + 4] = values[3];
|
||||
} else if constexpr (
|
||||
std::is_same_v<T2, float2> && std::is_same_v<U, __nv_bfloat16>) {
|
||||
x2[(row + 0) * (N / 2) + col + 0] =
|
||||
__floats2bfloat162_rn(values[0].x, values[0].y);
|
||||
x2[(row + 0) * (N / 2) + col + 4] =
|
||||
__floats2bfloat162_rn(values[2].x, values[2].y);
|
||||
x2[(row + 8) * (N / 2) + col + 0] =
|
||||
__floats2bfloat162_rn(values[1].x, values[1].y);
|
||||
x2[(row + 8) * (N / 2) + col + 4] =
|
||||
__floats2bfloat162_rn(values[3].x, values[3].y);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
__device__ inline void store_global_safe(U* x, int N, int max_rows) {
|
||||
const int laneid = threadIdx.x % 32;
|
||||
const int row = laneid / 4;
|
||||
const int col = laneid % 4;
|
||||
if (row < max_rows) {
|
||||
x[(row + 0) * N + 2 * col + 0] = static_cast<U>(values[0].x);
|
||||
x[(row + 0) * N + 2 * col + 1] = static_cast<U>(values[0].y);
|
||||
x[(row + 0) * N + 2 * col + 8] = static_cast<U>(values[2].x);
|
||||
x[(row + 0) * N + 2 * col + 9] = static_cast<U>(values[2].y);
|
||||
}
|
||||
if (row + 8 < max_rows) {
|
||||
x[(row + 8) * N + 2 * col + 0] = static_cast<U>(values[1].x);
|
||||
x[(row + 8) * N + 2 * col + 1] = static_cast<U>(values[1].y);
|
||||
x[(row + 8) * N + 2 * col + 8] = static_cast<U>(values[3].x);
|
||||
x[(row + 8) * N + 2 * col + 9] = static_cast<U>(values[3].y);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* A simple container of multiple Tile16x16.
|
||||
*
|
||||
* Provides utility functions for loading and manipulating collections of basic
|
||||
* tiles.
|
||||
*/
|
||||
template <typename T, int ROWS_, int COLS_>
|
||||
struct RegisterTile {
|
||||
static constexpr int ROWS = ROWS_;
|
||||
static constexpr int COLS = COLS_;
|
||||
static constexpr int TILES_X = COLS / 16;
|
||||
static constexpr int TILES_Y = ROWS / 16;
|
||||
|
||||
Tile16x16<T> data[TILES_X * TILES_Y];
|
||||
|
||||
__device__ inline void fill(T v) {
|
||||
MLX_UNROLL
|
||||
for (int i = 0; i < TILES_Y; i++) {
|
||||
MLX_UNROLL
|
||||
for (int j = 0; j < TILES_X; j++) {
|
||||
data[i * TILES_X + j].fill(v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Tile>
|
||||
__device__ inline void
|
||||
load(Tile& tile, uint32_t base_address, int row, int col) {
|
||||
MLX_UNROLL
|
||||
for (int i = 0; i < TILES_Y; i++) {
|
||||
MLX_UNROLL
|
||||
for (int j = 0; j < TILES_X; j++) {
|
||||
data[i * TILES_X + j].load(
|
||||
tile.loc(base_address, row + i * 16, col + j * 16));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
__device__ inline void store_global(U* x, int N, int row, int col) {
|
||||
MLX_UNROLL
|
||||
for (int i = 0; i < TILES_Y; i++) {
|
||||
MLX_UNROLL
|
||||
for (int j = 0; j < TILES_X; j++) {
|
||||
data[i * TILES_X + j].store_global(
|
||||
x + (row + i * 16) * N + col + j * 16, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
__device__ inline void
|
||||
store_global_safe(U* x, int N, int row, int col, int max_rows) {
|
||||
MLX_UNROLL
|
||||
for (int i = 0; i < TILES_Y; i++) {
|
||||
MLX_UNROLL
|
||||
for (int j = 0; j < TILES_X; j++) {
|
||||
data[i * TILES_X + j].store_global_safe(
|
||||
x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int ROWS_, int COLS_>
|
||||
struct SharedTile {
|
||||
static constexpr int ROWS = ROWS_;
|
||||
static constexpr int COLS = COLS_;
|
||||
static constexpr int TILES_X = COLS / 16;
|
||||
static constexpr int TILES_Y = ROWS / 16;
|
||||
static constexpr int NUMEL = ROWS * COLS;
|
||||
|
||||
// Swizzle taken from ThunderKittens.
|
||||
//
|
||||
// See inludes/types/shared/st.cuh
|
||||
//
|
||||
// I do feel that it is too math heavy and can be improved. Also the math is
|
||||
// done every time although the addresses don't change from load to load. I
|
||||
// guess we are expecting the compiler to figure that out.
|
||||
static constexpr int swizzle_bytes =
|
||||
(sizeof(T) == 2 ? (TILES_X % 4 == 0 ? 128 : (TILES_X % 2 == 0 ? 64 : 32))
|
||||
: (sizeof(T) == 4 ? (TILES_X % 2 == 0 ? 128 : 64) : 0));
|
||||
|
||||
T data[ROWS * COLS];
|
||||
|
||||
// Return a pointer to the element at (row, col) using the swizzle.
|
||||
__device__ static inline T* ptr(T* ptr, int row, int col) {
|
||||
if constexpr (swizzle_bytes > 0) {
|
||||
static constexpr int swizzle_repeat = swizzle_bytes * 8;
|
||||
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
|
||||
const int outer_idx = col / subtile_cols;
|
||||
const uint64_t addr =
|
||||
(uint64_t)(&ptr
|
||||
[outer_idx * ROWS * subtile_cols + row * subtile_cols +
|
||||
col % subtile_cols]);
|
||||
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
|
||||
return (T*)(addr ^ swizzle);
|
||||
} else {
|
||||
return ptr + row * COLS + col;
|
||||
}
|
||||
}
|
||||
|
||||
// Return the location of the element at (row, col) using the swizzle.
|
||||
__device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
|
||||
if constexpr (swizzle_bytes > 0) {
|
||||
static constexpr int swizzle_repeat = swizzle_bytes * 8;
|
||||
static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
|
||||
const int outer_idx = col / subtile_cols;
|
||||
const uint32_t addr = ptr +
|
||||
sizeof(T) *
|
||||
(outer_idx * ROWS * subtile_cols + row * subtile_cols +
|
||||
col % subtile_cols);
|
||||
const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
|
||||
return (addr ^ swizzle);
|
||||
} else {
|
||||
return ptr + sizeof(T) * (row * COLS + col);
|
||||
}
|
||||
}
|
||||
|
||||
// Convenience functions to edit elements going through the swizzle.
|
||||
__device__ inline T& operator()(int row, int col) {
|
||||
return *ptr(data, row, col);
|
||||
}
|
||||
__device__ inline void store(float4& v, int row, int col) {
|
||||
*(reinterpret_cast<float4*>(ptr(data, row, col))) = v;
|
||||
}
|
||||
__device__ inline void store(float2& v, int row, int col) {
|
||||
*(reinterpret_cast<float2*>(ptr(data, row, col))) = v;
|
||||
}
|
||||
__device__ inline void store(float& v, int row, int col) {
|
||||
*(reinterpret_cast<float*>(ptr(data, row, col))) = v;
|
||||
}
|
||||
template <int N>
|
||||
__device__ inline void store(T (&v)[N], int row, int col) {
|
||||
if constexpr (sizeof(T) * N == 4) {
|
||||
store(*(reinterpret_cast<float*>(&v[0])), row, col);
|
||||
} else if constexpr (sizeof(T) * N == 8) {
|
||||
store(*(reinterpret_cast<float2*>(&v[0])), row, col);
|
||||
} else if constexpr (sizeof(T) * N == 16) {
|
||||
store(*(reinterpret_cast<float4*>(&v[0])), row, col);
|
||||
} else {
|
||||
MLX_UNROLL
|
||||
for (int i = 0; i < N; i++) {
|
||||
*ptr(data, row, col + i) = v[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Load the tile from global memory by loading 16 bytes at a time and storing
|
||||
* them immediately.
|
||||
*/
|
||||
template <int NUM_WARPS, typename T, typename Tile>
|
||||
__device__ inline void load(Tile& tile, const T* x, int N) {
|
||||
constexpr int NUM_THREADS = NUM_WARPS * 32;
|
||||
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
|
||||
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
|
||||
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
|
||||
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
|
||||
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
|
||||
|
||||
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
|
||||
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
|
||||
|
||||
x += row * N + col * ELEMENTS_PER_LOAD;
|
||||
|
||||
MLX_UNROLL
|
||||
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
|
||||
float4 tmp;
|
||||
tmp = *(reinterpret_cast<const float4*>(&x[i * STEP_ROWS * N]));
|
||||
tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Copy 16 bytes from the globale memory address pointed to by x to the smem
|
||||
* address pointed to by row_address.
|
||||
*
|
||||
* A simple wrapper over the PTX.
|
||||
*/
|
||||
template <typename T>
|
||||
__device__ inline void cp_async_16(uint32_t row_address, const T* x) {
|
||||
asm volatile(
|
||||
"cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
|
||||
"l"(reinterpret_cast<const int4*>(x)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Submit all the previous async copies to be executed.
|
||||
*/
|
||||
__device__ inline void cp_async_commit() {
|
||||
asm volatile("cp.async.commit_group;\n" ::);
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for all the async copies to finish.
|
||||
*/
|
||||
__device__ inline void cp_async_wait_all() {
|
||||
asm volatile("cp.async.wait_all;\n" ::);
|
||||
}
|
||||
|
||||
/**
|
||||
* The asynchronous equivalent of load.
|
||||
*
|
||||
* Loads the tile from global memory by submitting a bunch of async copy
|
||||
* instructions. The copy won't start until commit is called and we don't have
|
||||
* a guarantee it will finish until wait is called.
|
||||
*
|
||||
* It should be used as follows
|
||||
*
|
||||
* load(...)
|
||||
* load(...)
|
||||
* cp_async_commit()
|
||||
* do_other_stuff()
|
||||
* cp_async_wait_all()
|
||||
* do_stuff_with_shmem()
|
||||
*/
|
||||
template <int NUM_WARPS, typename T, typename Tile>
|
||||
__device__ inline void
|
||||
load_async(Tile& tile, uint32_t base_address, const T* x, int N) {
|
||||
constexpr int NUM_THREADS = NUM_WARPS * 32;
|
||||
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
|
||||
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
|
||||
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
|
||||
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
|
||||
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
|
||||
|
||||
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
|
||||
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
|
||||
|
||||
x += row * N + col * ELEMENTS_PER_LOAD;
|
||||
|
||||
MLX_UNROLL
|
||||
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
|
||||
cp_async_16(
|
||||
tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
|
||||
x + i * STEP_ROWS * N);
|
||||
}
|
||||
}
|
||||
|
||||
template <int NUM_WARPS, typename T, typename Tile>
|
||||
__device__ inline void load_async_safe(
|
||||
Tile& tile,
|
||||
uint32_t base_address,
|
||||
const T* x,
|
||||
int N,
|
||||
int max_rows) {
|
||||
constexpr int NUM_THREADS = NUM_WARPS * 32;
|
||||
constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
|
||||
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
|
||||
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
|
||||
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
|
||||
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
|
||||
|
||||
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
|
||||
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
|
||||
|
||||
x += row * N + col * ELEMENTS_PER_LOAD;
|
||||
|
||||
MLX_UNROLL
|
||||
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
|
||||
if (row + i * STEP_ROWS < max_rows) {
|
||||
cp_async_16(
|
||||
tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
|
||||
x + i * STEP_ROWS * N);
|
||||
} else {
|
||||
float4 tmp = {0, 0, 0, 0};
|
||||
tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
@@ -81,7 +81,6 @@ NO_GPU(Hadamard)
|
||||
NO_GPU(Load)
|
||||
NO_GPU_MULTI(LUF)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(SegmentedMM)
|
||||
NO_GPU_MULTI(SVD)
|
||||
NO_GPU(Inverse)
|
||||
|
||||
@@ -2,30 +2,17 @@
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <int bits, int wsize = 8>
|
||||
inline constexpr __device__ short get_pack_factor() {
|
||||
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
|
||||
}
|
||||
|
||||
template <int bits, int wsize = 8>
|
||||
inline constexpr __device__ short get_bytes_per_pack() {
|
||||
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
__global__ void
|
||||
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
|
||||
@@ -240,144 +227,102 @@ __global__ void affine_dequantize(
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
namespace {
|
||||
|
||||
inline array ensure_row_contiguous(
|
||||
const array& x,
|
||||
void affine_quantize(
|
||||
const array& w,
|
||||
array& wq,
|
||||
array& scales,
|
||||
array& biases,
|
||||
int group_size_,
|
||||
int bits_,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s) {
|
||||
if (!x.flags().row_contiguous) {
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
enc.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
} else {
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename F>
|
||||
void dispatch_groups(int group_size, F&& f) {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
f(std::integral_constant<int, 32>{});
|
||||
break;
|
||||
case 64:
|
||||
f(std::integral_constant<int, 64>{});
|
||||
break;
|
||||
case 128:
|
||||
f(std::integral_constant<int, 128>{});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void dispatch_bits(int bits, F&& f) {
|
||||
switch (bits) {
|
||||
case 2:
|
||||
f(std::integral_constant<int, 2>{});
|
||||
break;
|
||||
case 3:
|
||||
f(std::integral_constant<int, 3>{});
|
||||
break;
|
||||
case 4:
|
||||
f(std::integral_constant<int, 4>{});
|
||||
break;
|
||||
case 5:
|
||||
f(std::integral_constant<int, 5>{});
|
||||
break;
|
||||
case 6:
|
||||
f(std::integral_constant<int, 6>{});
|
||||
break;
|
||||
case 8:
|
||||
f(std::integral_constant<int, 8>{});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void fast::AffineQuantize::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& w_pre = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = cu::device(s.device);
|
||||
auto& enc = d.get_command_encoder(s);
|
||||
|
||||
auto w = ensure_row_contiguous(w_pre, enc, s);
|
||||
enc.set_input_array(w);
|
||||
if (dequantize_) {
|
||||
auto scales = ensure_row_contiguous(inputs[1], enc, s);
|
||||
auto biases = ensure_row_contiguous(inputs[2], enc, s);
|
||||
enc.set_input_array(scales);
|
||||
enc.set_input_array(biases);
|
||||
enc.set_output_array(out);
|
||||
} else {
|
||||
auto& scales = outputs[1];
|
||||
auto& biases = outputs[2];
|
||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
enc.set_output_array(out);
|
||||
enc.set_output_array(scales);
|
||||
enc.set_output_array(biases);
|
||||
}
|
||||
|
||||
auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype();
|
||||
|
||||
// Treat uint32 as uint8 in kernel
|
||||
int uint8_per_uint32 = 4;
|
||||
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
|
||||
: bits_ == 6 ? 4
|
||||
: 8 / bits_;
|
||||
int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE;
|
||||
size_t size =
|
||||
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
|
||||
// Calculate the number of elements per thread
|
||||
int per_thread = group_size_ / WARP_SIZE;
|
||||
size_t size = w.size() / per_thread;
|
||||
|
||||
// Calculate the thread grid that we need to launch
|
||||
bool large = size > UINT_MAX;
|
||||
auto grid_shape = w.shape();
|
||||
grid_shape.back() /= per_thread;
|
||||
|
||||
if (dequantize_) {
|
||||
grid_shape.back() *= uint8_per_uint32;
|
||||
} else {
|
||||
grid_shape.back() /= per_thread;
|
||||
}
|
||||
|
||||
dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) {
|
||||
enc.set_input_array(w);
|
||||
enc.set_output_array(wq);
|
||||
enc.set_output_array(scales);
|
||||
enc.set_output_array(biases);
|
||||
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
|
||||
dispatch_groups(group_size_, [&](auto group_size) {
|
||||
dispatch_bits(bits_, [&](auto bits) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
if (dequantize_) {
|
||||
auto kernel =
|
||||
cu::affine_dequantize<DataType, group_size.value, bits.value>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, size, grid_shape, w.strides(), large);
|
||||
enc.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
w.data<uint8_t>(),
|
||||
inputs[1].data<DataType>(),
|
||||
inputs[2].data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
out.size());
|
||||
} else {
|
||||
auto kernel =
|
||||
cu::affine_quantize<DataType, group_size.value, bits.value>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, size, grid_shape, w.strides(), large);
|
||||
enc.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
w.data<DataType>(),
|
||||
out.data<uint8_t>(),
|
||||
outputs[1].data<DataType>(),
|
||||
outputs[2].data<DataType>(),
|
||||
w.size());
|
||||
}
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::affine_quantize<T, group_size.value, bits.value>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, size, grid_shape, w.strides(), large);
|
||||
enc.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
w.data<T>(),
|
||||
wq.data<uint8_t>(),
|
||||
scales.data<T>(),
|
||||
biases.data<T>(),
|
||||
w.size());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void affine_dequantize(
|
||||
const array& wq,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
array& w,
|
||||
int group_size_,
|
||||
int bits_,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s) {
|
||||
// Calculate how many numbers we pack together. For 2, 4, 8 bits we pack in
|
||||
// one uint8, for 3, 6 in 3 uint8 and for 5 in 5 uint8.
|
||||
constexpr int uint8_per_uint32 = 4;
|
||||
int packs_per_int;
|
||||
switch (bits_) {
|
||||
case 3:
|
||||
case 5:
|
||||
packs_per_int = 8;
|
||||
break;
|
||||
case 6:
|
||||
packs_per_int = 4;
|
||||
break;
|
||||
default:
|
||||
packs_per_int = 8 / bits_;
|
||||
}
|
||||
|
||||
size_t size = w.size() / packs_per_int;
|
||||
bool large = size > UINT_MAX;
|
||||
auto grid_shape = w.shape();
|
||||
grid_shape.back() *= uint8_per_uint32;
|
||||
|
||||
enc.set_input_array(wq);
|
||||
enc.set_input_array(scales);
|
||||
enc.set_input_array(biases);
|
||||
enc.set_output_array(w);
|
||||
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
|
||||
dispatch_groups(group_size_, [&](auto group_size) {
|
||||
dispatch_bits(bits_, [&](auto bits) {
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
auto kernel = cu::affine_dequantize<T, group_size.value, bits.value>;
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(kernel, size, grid_shape, w.strides(), large);
|
||||
enc.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
wq.data<uint8_t>(),
|
||||
scales.data<T>(),
|
||||
biases.data<T>(),
|
||||
w.data<T>(),
|
||||
w.size());
|
||||
});
|
||||
});
|
||||
});
|
||||
228
mlx/backend/cuda/quantized/qmm.cu
Normal file
228
mlx/backend/cuda/quantized/qmm.cu
Normal file
@@ -0,0 +1,228 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/matmul/mma.cuh"
|
||||
#include "mlx/backend/cuda/matmul/tiles.cuh"
|
||||
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
template <int NUM_WARPS, int group_size, int bits, typename T, typename Tile>
|
||||
__device__ inline void load_quantized(
|
||||
Tile& tile,
|
||||
const uint8_t* x,
|
||||
const T* scales,
|
||||
const T* biases,
|
||||
int N) {
|
||||
constexpr int NUM_THREADS = NUM_WARPS * 32;
|
||||
constexpr int ELEMENTS_PER_LOAD = sizeof(uint32_t) * get_pack_factor<bits>();
|
||||
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
|
||||
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
|
||||
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
|
||||
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
|
||||
constexpr int MASK = (1 << bits) - 1;
|
||||
|
||||
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
|
||||
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
|
||||
|
||||
const int Nx = N / get_pack_factor<bits>();
|
||||
const int Ng = N / group_size;
|
||||
|
||||
x += row * Nx + col * (ELEMENTS_PER_LOAD / get_pack_factor<bits>());
|
||||
scales += row * Ng + col * ELEMENTS_PER_LOAD / group_size;
|
||||
biases += row * Ng + col * ELEMENTS_PER_LOAD / group_size;
|
||||
|
||||
MLX_UNROLL
|
||||
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
|
||||
T vs[ELEMENTS_PER_LOAD];
|
||||
uint32_t w = *reinterpret_cast<const uint32_t*>(x + i * STEP_ROWS * Nx);
|
||||
T s = scales[i * STEP_ROWS * Ng];
|
||||
T b = biases[i * STEP_ROWS * Ng];
|
||||
MLX_UNROLL
|
||||
for (int j = 0; j < ELEMENTS_PER_LOAD; j++) {
|
||||
vs[j] = static_cast<T>((w >> (j * bits)) & MASK) * s + b;
|
||||
}
|
||||
tile.store(vs, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int group_size,
|
||||
int bits,
|
||||
bool aligned_M>
|
||||
__global__ void qmm_t(
|
||||
const T* x,
|
||||
const uint8_t* w,
|
||||
const T* scales,
|
||||
const T* biases,
|
||||
T* y,
|
||||
int M,
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int WARPS_M = 2;
|
||||
constexpr int WARPS_N = 4;
|
||||
constexpr int NUM_WARPS = WARPS_M * WARPS_N;
|
||||
constexpr int WARP_STEP_M = BM / WARPS_M;
|
||||
constexpr int WARP_STEP_N = BN / WARPS_N;
|
||||
|
||||
const int warpid = threadIdx.x / 32;
|
||||
const int laneid = threadIdx.x % 32;
|
||||
const int wm = warpid / WARPS_N;
|
||||
const int wn = warpid % WARPS_N;
|
||||
const int offset_m = wm * WARP_STEP_M;
|
||||
const int offset_n = wn * WARP_STEP_N;
|
||||
|
||||
extern __shared__ char shmem[];
|
||||
SharedTile<T, BM, BK>(&xs)[1] = *(SharedTile<T, BM, BK>(*)[1])(&shmem[0]);
|
||||
SharedTile<T, BN, BK>(&ws)[1] =
|
||||
*(SharedTile<T, BN, BK>(*)[1])(&shmem[1 * sizeof(T) * BM * BK]);
|
||||
|
||||
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
|
||||
RegisterTile<T, BM / WARPS_M, 16> A;
|
||||
RegisterTile<T, BN / WARPS_N, 16> B;
|
||||
|
||||
const int max_rows = M - blockIdx.y * BM;
|
||||
|
||||
x += blockIdx.y * BM * K;
|
||||
w += blockIdx.x * BN * K / get_pack_factor<bits>();
|
||||
scales += blockIdx.x * BN * K / group_size;
|
||||
biases += blockIdx.x * BN * K / group_size;
|
||||
y += blockIdx.y * BM * N + blockIdx.x * BN;
|
||||
|
||||
C.fill(0);
|
||||
|
||||
int tic = 0;
|
||||
uint32_t base_addr_xs[1], base_addr_ws[1];
|
||||
base_addr_xs[0] = __cvta_generic_to_shared(&xs[0].data[0]);
|
||||
base_addr_ws[0] = __cvta_generic_to_shared(&ws[0].data[0]);
|
||||
|
||||
if (aligned_M || max_rows >= BM) {
|
||||
for (int k_block = 0; k_block < K; k_block += BK) {
|
||||
load_async<NUM_WARPS>(xs[tic], base_addr_xs[tic], x + k_block, K);
|
||||
cp_async_commit();
|
||||
load_quantized<NUM_WARPS, group_size, bits>(
|
||||
ws[tic],
|
||||
w + k_block / get_pack_factor<bits>(),
|
||||
scales + k_block / group_size,
|
||||
biases + k_block / group_size,
|
||||
K);
|
||||
cp_async_wait_all();
|
||||
__syncthreads();
|
||||
|
||||
MLX_UNROLL
|
||||
for (int k = 0; k < BK / 16; k++) {
|
||||
A.load(
|
||||
xs[tic],
|
||||
base_addr_xs[tic],
|
||||
offset_m + laneid % 16,
|
||||
k * 16 + laneid / 16 * 8);
|
||||
B.load(
|
||||
ws[tic],
|
||||
base_addr_ws[tic],
|
||||
offset_n + laneid % 16,
|
||||
k * 16 + laneid / 16 * 8);
|
||||
mma_t(C, A, B);
|
||||
}
|
||||
}
|
||||
C.store_global(y, N, offset_m, offset_n);
|
||||
} else {
|
||||
for (int k_block = 0; k_block < K; k_block += BK) {
|
||||
load_async_safe<NUM_WARPS>(
|
||||
xs[tic], base_addr_xs[tic], x + k_block, K, max_rows);
|
||||
cp_async_commit();
|
||||
load_quantized<NUM_WARPS, group_size, bits>(
|
||||
ws[tic],
|
||||
w + k_block / get_pack_factor<bits>(),
|
||||
scales + k_block / group_size,
|
||||
biases + k_block / group_size,
|
||||
K);
|
||||
cp_async_wait_all();
|
||||
__syncthreads();
|
||||
|
||||
MLX_UNROLL
|
||||
for (int k = 0; k < BK / 16; k++) {
|
||||
A.load(
|
||||
xs[tic],
|
||||
base_addr_xs[tic],
|
||||
offset_m + laneid % 16,
|
||||
k * 16 + laneid / 16 * 8);
|
||||
B.load(
|
||||
ws[tic],
|
||||
base_addr_ws[tic],
|
||||
offset_n + laneid % 16,
|
||||
k * 16 + laneid / 16 * 8);
|
||||
mma_t(C, A, B);
|
||||
}
|
||||
}
|
||||
C.store_global_safe(y, N, offset_m, offset_n, max_rows);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void qmm(
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
array& out,
|
||||
bool transpose_,
|
||||
int group_size_,
|
||||
int bits_,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s) {
|
||||
if (x.dtype() != bfloat16) {
|
||||
throw std::invalid_argument("[qmm] Only bfloat16 is supported for now");
|
||||
}
|
||||
if (!transpose_) {
|
||||
throw std::invalid_argument(
|
||||
"[qmm] Only transposed matmul is supported for now");
|
||||
}
|
||||
|
||||
dispatch_float_types(x.dtype(), "qmm", [&](auto type_tag) {
|
||||
dispatch_groups(group_size_, [&](auto group_size) {
|
||||
dispatch_bits(bits_, [&](auto bits) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
|
||||
constexpr int BM = 128;
|
||||
constexpr int BN = 128;
|
||||
constexpr int BK = 32;
|
||||
auto kernel =
|
||||
cu::qmm_t<DataType, BM, BN, BK, group_size.value, bits.value, true>;
|
||||
if (M % BM != 0) {
|
||||
kernel = cu::
|
||||
qmm_t<DataType, BM, BN, BK, group_size.value, bits.value, false>;
|
||||
}
|
||||
|
||||
dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
|
||||
|
||||
enc.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
2 * 4 * 32,
|
||||
1 * sizeof(DataType) * (BM * BK + BN * BK),
|
||||
x.data<DataType>(),
|
||||
w.data<uint8_t>(),
|
||||
scales.data<DataType>(),
|
||||
biases.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
113
mlx/backend/cuda/quantized/quantized.cu
Normal file
113
mlx/backend/cuda/quantized/quantized.cu
Normal file
@@ -0,0 +1,113 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/quantized/quantized.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
inline array ensure_row_contiguous(
|
||||
const array& x,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s) {
|
||||
if (!x.flags().row_contiguous) {
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
enc.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
} else {
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
inline array ensure_row_contiguous_matrix(
|
||||
const array& x,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s) {
|
||||
auto stride_0 = x.strides()[x.ndim() - 2];
|
||||
auto stride_1 = x.strides()[x.ndim() - 1];
|
||||
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
enc.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = cu::device(s.device);
|
||||
auto& enc = d.get_command_encoder(s);
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
// Make sure the last two dims of x and w, s, b are contiguous. This should
|
||||
// be relaxed for x.
|
||||
array x = ensure_row_contiguous_matrix(inputs[0], enc, s);
|
||||
array w = ensure_row_contiguous_matrix(inputs[1], enc, s);
|
||||
array scales = ensure_row_contiguous_matrix(inputs[2], enc, s);
|
||||
array biases = ensure_row_contiguous_matrix(inputs[3], enc, s);
|
||||
|
||||
// Extract the matmul shapes
|
||||
bool non_batched = w.ndim() == 2 && x.flags().row_contiguous;
|
||||
int K = x.shape(-1);
|
||||
int M = non_batched ? x.size() / K : x.shape(-2);
|
||||
int N = out.shape(-1);
|
||||
|
||||
qmm(x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
out,
|
||||
transpose_,
|
||||
group_size_,
|
||||
bits_,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
enc,
|
||||
s);
|
||||
}
|
||||
|
||||
void fast::AffineQuantize::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& d = cu::device(s.device);
|
||||
auto& enc = d.get_command_encoder(s);
|
||||
|
||||
if (dequantize_) {
|
||||
auto wq = ensure_row_contiguous(inputs[0], enc, s);
|
||||
auto scales = ensure_row_contiguous(inputs[1], enc, s);
|
||||
auto biases = ensure_row_contiguous(inputs[2], enc, s);
|
||||
auto& w = outputs[0];
|
||||
|
||||
w.set_data(allocator::malloc(w.nbytes()));
|
||||
|
||||
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
|
||||
} else {
|
||||
auto w = ensure_row_contiguous(inputs[0], enc, s);
|
||||
auto& wq = outputs[0];
|
||||
auto& scales = outputs[1];
|
||||
auto& biases = outputs[2];
|
||||
|
||||
wq.set_data(allocator::malloc(wq.nbytes()));
|
||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
|
||||
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
42
mlx/backend/cuda/quantized/quantized.cuh
Normal file
42
mlx/backend/cuda/quantized/quantized.cuh
Normal file
@@ -0,0 +1,42 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void affine_quantize(
|
||||
const array& w,
|
||||
array& wq,
|
||||
array& scales,
|
||||
array& biases,
|
||||
int group_size_,
|
||||
int bits_,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s);
|
||||
|
||||
void affine_dequantize(
|
||||
const array& wq,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
array& w,
|
||||
int group_size_,
|
||||
int bits_,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s);
|
||||
|
||||
void qmm(
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
array& out,
|
||||
bool transpose_,
|
||||
int group_size_,
|
||||
int bits_,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
||||
59
mlx/backend/cuda/quantized/quantized_utils.cuh
Normal file
59
mlx/backend/cuda/quantized/quantized_utils.cuh
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
template <int bits, int wsize = 8>
|
||||
inline constexpr __device__ short get_pack_factor() {
|
||||
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
|
||||
}
|
||||
|
||||
template <int bits, int wsize = 8>
|
||||
inline constexpr __device__ short get_bytes_per_pack() {
|
||||
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename F>
|
||||
void dispatch_groups(int group_size, F&& f) {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
f(std::integral_constant<int, 32>{});
|
||||
break;
|
||||
case 64:
|
||||
f(std::integral_constant<int, 64>{});
|
||||
break;
|
||||
case 128:
|
||||
f(std::integral_constant<int, 128>{});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void dispatch_bits(int bits, F&& f) {
|
||||
switch (bits) {
|
||||
case 2:
|
||||
f(std::integral_constant<int, 2>{});
|
||||
break;
|
||||
case 3:
|
||||
f(std::integral_constant<int, 3>{});
|
||||
break;
|
||||
case 4:
|
||||
f(std::integral_constant<int, 4>{});
|
||||
break;
|
||||
case 5:
|
||||
f(std::integral_constant<int, 5>{});
|
||||
break;
|
||||
case 6:
|
||||
f(std::integral_constant<int, 6>{});
|
||||
break;
|
||||
case 8:
|
||||
f(std::integral_constant<int, 8>{});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -170,6 +170,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
cu::rbitsc,
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
keys.data<uint32_t>(),
|
||||
out.data<uint8_t>(),
|
||||
grid_dims,
|
||||
@@ -180,6 +181,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
cu::rbits,
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
keys.data<uint32_t>(),
|
||||
out.data<uint8_t>(),
|
||||
grid_dims,
|
||||
|
||||
@@ -120,6 +120,7 @@ void all_reduce(
|
||||
kernel,
|
||||
blocks,
|
||||
threads,
|
||||
0,
|
||||
static_cast<T*>(indata),
|
||||
intermediate.data<U>(),
|
||||
block_step,
|
||||
@@ -146,6 +147,7 @@ void all_reduce(
|
||||
kernel,
|
||||
blocks,
|
||||
threads,
|
||||
0,
|
||||
static_cast<T*>(indata),
|
||||
out.data<U>(),
|
||||
block_step,
|
||||
|
||||
@@ -230,7 +230,7 @@ void col_reduce_looped(
|
||||
auto kernel =
|
||||
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, blocks, indata, out.data<U>(), args);
|
||||
kernel, grid, blocks, 0, indata, out.data<U>(), args);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -41,7 +41,8 @@ void init_reduce(
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
|
||||
grid.x = (grid.x + 1023) / 1024;
|
||||
encoder.add_kernel_node(kernel, grid, block, out.data<U>(), out.size());
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, block, 0, out.data<U>(), out.size());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -269,7 +269,7 @@ void row_reduce_simple(
|
||||
|
||||
int size = plan.shape.back();
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, block, indata, out.data<U>(), out.size(), size);
|
||||
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -322,7 +322,7 @@ void row_reduce_looped(
|
||||
});
|
||||
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, block, indata, out.data<U>(), out.size(), args);
|
||||
kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -232,6 +232,7 @@ void RMSNorm::eval_gpu(
|
||||
kernel,
|
||||
n_rows,
|
||||
block_dim(),
|
||||
0,
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
@@ -327,6 +328,7 @@ void RMSNormVJP::eval_gpu(
|
||||
kernel,
|
||||
n_rows,
|
||||
block_dim(),
|
||||
0,
|
||||
x.data<DataType>(),
|
||||
w.data<DataType>(),
|
||||
g.data<DataType>(),
|
||||
|
||||
@@ -325,6 +325,7 @@ void RoPE::eval_gpu(
|
||||
kernel,
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
@@ -341,6 +342,7 @@ void RoPE::eval_gpu(
|
||||
kernel,
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
@@ -360,6 +362,7 @@ void RoPE::eval_gpu(
|
||||
kernel,
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
@@ -381,6 +384,7 @@ void RoPE::eval_gpu(
|
||||
kernel,
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
(donated ? out : in).data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
offset.data<int32_t>(),
|
||||
|
||||
@@ -414,6 +414,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
kernel,
|
||||
in.data_size() / axis_size,
|
||||
block_dim,
|
||||
0,
|
||||
in.data<T>(),
|
||||
out.data<U>(),
|
||||
axis_size);
|
||||
@@ -443,6 +444,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dim,
|
||||
0,
|
||||
in.data<T>(),
|
||||
out.data<U>(),
|
||||
axis_size,
|
||||
|
||||
@@ -152,6 +152,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
kernel,
|
||||
n_rows,
|
||||
block_dim(),
|
||||
0,
|
||||
in.data<DataType>(),
|
||||
out.data<DataType>(),
|
||||
axis_size);
|
||||
|
||||
@@ -133,6 +133,7 @@ void ternary_op_gpu_inplace(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
@@ -151,6 +152,7 @@ void ternary_op_gpu_inplace(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
@@ -180,6 +182,7 @@ void ternary_op_gpu_inplace(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
|
||||
@@ -142,6 +142,7 @@ void unary_op_gpu_inplace(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size());
|
||||
@@ -154,6 +155,7 @@ void unary_op_gpu_inplace(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size(),
|
||||
|
||||
Reference in New Issue
Block a user