mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
3 Commits
b9e88fb976
...
641be9463b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
641be9463b | ||
|
|
ab0e608862 | ||
|
|
1588659062 |
@@ -105,11 +105,11 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
|
||||
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
|
||||
endif()
|
||||
|
||||
# Compute capability 7 is required for synchronization between CPU/GPU with
|
||||
# managed memory. TODO: Add more architectures for potential performance gain.
|
||||
set(MLX_CUDA_ARCHITECTURES
|
||||
"70;80"
|
||||
CACHE STRING "CUDA architectures")
|
||||
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
|
||||
# managed memory.
|
||||
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
|
||||
set(MLX_CUDA_ARCHITECTURES "native")
|
||||
endif()
|
||||
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
||||
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
||||
"${MLX_CUDA_ARCHITECTURES}")
|
||||
|
||||
@@ -97,23 +97,6 @@ CommandEncoder::CaptureContext::~CaptureContext() {
|
||||
if (discard) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Extract and add as single kernel node when possible.
|
||||
size_t num_nodes;
|
||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
|
||||
if (num_nodes == 1) {
|
||||
cudaGraphNode_t captured_node;
|
||||
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
|
||||
cudaGraphNodeType type;
|
||||
CHECK_CUDA_ERROR(cudaGraphNodeGetType(captured_node, &type));
|
||||
if (type == cudaGraphNodeTypeKernel) {
|
||||
CUDA_KERNEL_NODE_PARAMS params;
|
||||
CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, ¶ms));
|
||||
enc.add_kernel_node(params);
|
||||
return;
|
||||
}
|
||||
}
|
||||
// Otherwise add the captured graph as subgraph.
|
||||
enc.add_graph_node(graph);
|
||||
}
|
||||
|
||||
|
||||
@@ -49,11 +49,7 @@ inline __device__ void atomic_add(__half* out, __half val) {
|
||||
}
|
||||
|
||||
inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
|
||||
#if __CUDA_ARCH__ < 900
|
||||
atomic_add_general(out, val);
|
||||
#else
|
||||
atomicAdd(out, val);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
|
||||
|
||||
@@ -11,7 +11,6 @@ namespace mlx::core::cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
static constexpr int n_per_thread = 4;
|
||||
static constexpr int rows_per_block = 8;
|
||||
|
||||
template <typename T, int rows_per_block, int n_per_thread>
|
||||
@@ -74,8 +73,23 @@ __global__ void gemv_batched(
|
||||
}
|
||||
|
||||
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
|
||||
return K % (WARP_SIZE * n_per_thread) == 0 &&
|
||||
((M == 1 && b_transposed) || (N == 1 && !a_transposed));
|
||||
bool is_multiple = K % 32 == 0 || K % 64 == 0 || K % 128 == 0;
|
||||
return is_multiple && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void dispatch_n_per_thread(int n_per_thread, F&& f) {
|
||||
switch (n_per_thread) {
|
||||
case 1:
|
||||
f(std::integral_constant<int, 1>{});
|
||||
break;
|
||||
case 2:
|
||||
f(std::integral_constant<int, 2>{});
|
||||
break;
|
||||
case 4:
|
||||
f(std::integral_constant<int, 4>{});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void gemv(
|
||||
@@ -114,33 +128,39 @@ void gemv(
|
||||
rows = M;
|
||||
}
|
||||
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
|
||||
if (batch_count == 1) {
|
||||
auto kernel = gemv_single<DataType, rows_per_block, n_per_thread>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks_x,
|
||||
block_dims,
|
||||
mat,
|
||||
vec,
|
||||
out.data<DataType>(),
|
||||
rows,
|
||||
cols);
|
||||
} else {
|
||||
auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
dim3{num_blocks_x, batch_count},
|
||||
block_dims,
|
||||
mat,
|
||||
vec,
|
||||
out.data<DataType>(),
|
||||
rows,
|
||||
cols,
|
||||
const_param(batch_shape),
|
||||
mat_strides,
|
||||
vec_strides,
|
||||
batch_shape.size());
|
||||
int n_per_t = 4;
|
||||
while (K % (n_per_t * WARP_SIZE) != 0) {
|
||||
n_per_t >>= 1;
|
||||
}
|
||||
dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
|
||||
if (batch_count == 1) {
|
||||
auto kernel = gemv_single<DataType, rows_per_block, n_per_thread()>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks_x,
|
||||
block_dims,
|
||||
mat,
|
||||
vec,
|
||||
out.data<DataType>(),
|
||||
rows,
|
||||
cols);
|
||||
} else {
|
||||
auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread()>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
dim3{num_blocks_x, batch_count},
|
||||
block_dims,
|
||||
mat,
|
||||
vec,
|
||||
out.data<DataType>(),
|
||||
rows,
|
||||
cols,
|
||||
const_param(batch_shape),
|
||||
mat_strides,
|
||||
vec_strides,
|
||||
batch_shape.size());
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -120,20 +120,6 @@ dim3 get_2d_grid_dims(
|
||||
size_t divisor);
|
||||
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
|
||||
|
||||
// Return a block size that achieves maximum potential occupancy for kernel.
|
||||
template <typename T>
|
||||
inline uint max_occupancy_block_dim(T kernel) {
|
||||
int _, block_dim;
|
||||
if constexpr (std::is_same_v<T, CUfunction>) {
|
||||
CHECK_CUDA_ERROR(
|
||||
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
|
||||
} else {
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
|
||||
}
|
||||
return block_dim;
|
||||
}
|
||||
|
||||
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
|
||||
// assuming each thread handles |work_per_thread| elements of |arr|.
|
||||
template <typename T>
|
||||
@@ -145,7 +131,7 @@ inline std::tuple<dim3, uint> get_launch_args(
|
||||
bool large,
|
||||
int work_per_thread = 1) {
|
||||
size_t nthreads = cuda::ceil_div(size, work_per_thread);
|
||||
uint block_dim = max_occupancy_block_dim(kernel);
|
||||
uint block_dim = 1024;
|
||||
if (block_dim > nthreads) {
|
||||
block_dim = nthreads;
|
||||
}
|
||||
|
||||
12
setup.py
12
setup.py
@@ -44,6 +44,8 @@ def get_version():
|
||||
|
||||
|
||||
build_stage = int(os.environ.get("MLX_BUILD_STAGE", 0))
|
||||
build_macos = platform.system() == "Darwin"
|
||||
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
|
||||
|
||||
|
||||
# A CMakeExtension needs a sourcedir instead of a file list.
|
||||
@@ -85,6 +87,11 @@ class CMakeBuild(build_ext):
|
||||
"-DMLX_BUILD_EXAMPLES=OFF",
|
||||
f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}",
|
||||
]
|
||||
if build_stage == 2 and build_cuda:
|
||||
# Last arch is always real and virtual for forward-compatibility
|
||||
cuda_archs = ";".join(("70-real", "80-real", "90-real", "100-real", "120"))
|
||||
cmake_args += [f"-DMLX_CUDA_ARCHITECTURES={cuda_archs}"]
|
||||
|
||||
# Some generators require explcitly passing config when building.
|
||||
build_args = ["--config", cfg]
|
||||
# Adding CMake arguments set as environment variable
|
||||
@@ -95,7 +102,7 @@ class CMakeBuild(build_ext):
|
||||
# Pass version to C++
|
||||
cmake_args += [f"-DMLX_VERSION={self.distribution.get_version()}"] # type: ignore[attr-defined]
|
||||
|
||||
if platform.system() == "Darwin":
|
||||
if build_macos:
|
||||
# Cross-compile support for macOS - respect ARCHFLAGS if set
|
||||
archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
|
||||
if archs:
|
||||
@@ -202,9 +209,6 @@ if __name__ == "__main__":
|
||||
],
|
||||
)
|
||||
|
||||
build_macos = platform.system() == "Darwin"
|
||||
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
|
||||
|
||||
version = get_version()
|
||||
|
||||
_setup = partial(
|
||||
|
||||
Reference in New Issue
Block a user