Compare commits

..

3 Commits

Author SHA1 Message Date
Awni Hannun
641be9463b Add more CUDA architectures for PyPi package (#2427)
* add cuda sm 90

* add more archs
2025-07-28 12:35:15 -07:00
Awni Hannun
ab0e608862 [CUDA] More sizes for gemv (#2429)
* route more to gemv

* route more sizes to custom gemv
2025-07-28 12:35:01 -07:00
Awni Hannun
1588659062 no occupancy query for launch params (#2426) 2025-07-28 09:09:41 -07:00
6 changed files with 63 additions and 74 deletions

View File

@@ -105,11 +105,11 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif()
# Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES
"70;80"
CACHE STRING "CUDA architectures")
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
# managed memory.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
set(MLX_CUDA_ARCHITECTURES "native")
endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
"${MLX_CUDA_ARCHITECTURES}")

View File

@@ -97,23 +97,6 @@ CommandEncoder::CaptureContext::~CaptureContext() {
if (discard) {
return;
}
// Extract and add as single kernel node when possible.
size_t num_nodes;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
if (num_nodes == 1) {
cudaGraphNode_t captured_node;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
cudaGraphNodeType type;
CHECK_CUDA_ERROR(cudaGraphNodeGetType(captured_node, &type));
if (type == cudaGraphNodeTypeKernel) {
CUDA_KERNEL_NODE_PARAMS params;
CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
enc.add_kernel_node(params);
return;
}
}
// Otherwise add the captured graph as subgraph.
enc.add_graph_node(graph);
}

View File

@@ -49,11 +49,7 @@ inline __device__ void atomic_add(__half* out, __half val) {
}
inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
#if __CUDA_ARCH__ < 900
atomic_add_general(out, val);
#else
atomicAdd(out, val);
#endif
}
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {

View File

@@ -11,7 +11,6 @@ namespace mlx::core::cu {
namespace cg = cooperative_groups;
static constexpr int n_per_thread = 4;
static constexpr int rows_per_block = 8;
template <typename T, int rows_per_block, int n_per_thread>
@@ -74,8 +73,23 @@ __global__ void gemv_batched(
}
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
return K % (WARP_SIZE * n_per_thread) == 0 &&
((M == 1 && b_transposed) || (N == 1 && !a_transposed));
bool is_multiple = K % 32 == 0 || K % 64 == 0 || K % 128 == 0;
return is_multiple && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
}
template <typename F>
void dispatch_n_per_thread(int n_per_thread, F&& f) {
switch (n_per_thread) {
case 1:
f(std::integral_constant<int, 1>{});
break;
case 2:
f(std::integral_constant<int, 2>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
}
}
void gemv(
@@ -114,33 +128,39 @@ void gemv(
rows = M;
}
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
if (batch_count == 1) {
auto kernel = gemv_single<DataType, rows_per_block, n_per_thread>;
encoder.add_kernel_node(
kernel,
num_blocks_x,
block_dims,
mat,
vec,
out.data<DataType>(),
rows,
cols);
} else {
auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread>;
encoder.add_kernel_node(
kernel,
dim3{num_blocks_x, batch_count},
block_dims,
mat,
vec,
out.data<DataType>(),
rows,
cols,
const_param(batch_shape),
mat_strides,
vec_strides,
batch_shape.size());
int n_per_t = 4;
while (K % (n_per_t * WARP_SIZE) != 0) {
n_per_t >>= 1;
}
dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
if (batch_count == 1) {
auto kernel = gemv_single<DataType, rows_per_block, n_per_thread()>;
encoder.add_kernel_node(
kernel,
num_blocks_x,
block_dims,
mat,
vec,
out.data<DataType>(),
rows,
cols);
} else {
auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread()>;
encoder.add_kernel_node(
kernel,
dim3{num_blocks_x, batch_count},
block_dims,
mat,
vec,
out.data<DataType>(),
rows,
cols,
const_param(batch_shape),
mat_strides,
vec_strides,
batch_shape.size());
}
});
});
}

View File

@@ -120,20 +120,6 @@ dim3 get_2d_grid_dims(
size_t divisor);
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
// Return a block size that achieves maximum potential occupancy for kernel.
template <typename T>
inline uint max_occupancy_block_dim(T kernel) {
int _, block_dim;
if constexpr (std::is_same_v<T, CUfunction>) {
CHECK_CUDA_ERROR(
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
} else {
CHECK_CUDA_ERROR(
cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
}
return block_dim;
}
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
// assuming each thread handles |work_per_thread| elements of |arr|.
template <typename T>
@@ -145,7 +131,7 @@ inline std::tuple<dim3, uint> get_launch_args(
bool large,
int work_per_thread = 1) {
size_t nthreads = cuda::ceil_div(size, work_per_thread);
uint block_dim = max_occupancy_block_dim(kernel);
uint block_dim = 1024;
if (block_dim > nthreads) {
block_dim = nthreads;
}

View File

@@ -44,6 +44,8 @@ def get_version():
build_stage = int(os.environ.get("MLX_BUILD_STAGE", 0))
build_macos = platform.system() == "Darwin"
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
# A CMakeExtension needs a sourcedir instead of a file list.
@@ -85,6 +87,11 @@ class CMakeBuild(build_ext):
"-DMLX_BUILD_EXAMPLES=OFF",
f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}",
]
if build_stage == 2 and build_cuda:
# Last arch is always real and virtual for forward-compatibility
cuda_archs = ";".join(("70-real", "80-real", "90-real", "100-real", "120"))
cmake_args += [f"-DMLX_CUDA_ARCHITECTURES={cuda_archs}"]
# Some generators require explcitly passing config when building.
build_args = ["--config", cfg]
# Adding CMake arguments set as environment variable
@@ -95,7 +102,7 @@ class CMakeBuild(build_ext):
# Pass version to C++
cmake_args += [f"-DMLX_VERSION={self.distribution.get_version()}"] # type: ignore[attr-defined]
if platform.system() == "Darwin":
if build_macos:
# Cross-compile support for macOS - respect ARCHFLAGS if set
archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
if archs:
@@ -202,9 +209,6 @@ if __name__ == "__main__":
],
)
build_macos = platform.system() == "Darwin"
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
version = get_version()
_setup = partial(