mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	[CUDA] Always use batched matmul (#2404)
* cuda batched mm * addmm as well * comment
This commit is contained in:
		| @@ -21,7 +21,8 @@ target_sources( | |||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp |           ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp | ||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/event.cu |           ${CMAKE_CURRENT_SOURCE_DIR}/event.cu | ||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp |           ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp | ||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/gemv.cu |           ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu | ||||||
|  |           ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp | ||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp |           ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp | ||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp |           ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp | ||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu |           ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu | ||||||
| @@ -47,6 +48,14 @@ target_sources( | |||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu |           ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu | ||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) |           ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) | ||||||
|  |  | ||||||
|  | if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0) | ||||||
|  |   target_sources( | ||||||
|  |     mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu) | ||||||
|  | else() | ||||||
|  |   target_sources( | ||||||
|  |     mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp) | ||||||
|  | endif() | ||||||
|  |  | ||||||
| target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) | target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) | ||||||
|  |  | ||||||
| # Embed kernel sources in binary for JIT compilation. | # Embed kernel sources in binary for JIT compilation. | ||||||
|   | |||||||
							
								
								
									
										73
									
								
								mlx/backend/cuda/gemms/cublas_batched_gemm_12_0.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										73
									
								
								mlx/backend/cuda/gemms/cublas_batched_gemm_12_0.cpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,73 @@ | |||||||
|  | // Copyright © 2025 Apple Inc. | ||||||
|  |  | ||||||
|  | #include "mlx/backend/common/utils.h" | ||||||
|  | #include "mlx/backend/cuda/device.h" | ||||||
|  | #include "mlx/backend/cuda/gemms/cublas_gemm.h" | ||||||
|  |  | ||||||
|  | namespace mlx::core::cu { | ||||||
|  |  | ||||||
|  | void Matmul::run_batched( | ||||||
|  |     cu::CommandEncoder& encoder, | ||||||
|  |     array& out, | ||||||
|  |     const array& a, | ||||||
|  |     const array& b, | ||||||
|  |     const mlx::core::Shape& batch_shape, | ||||||
|  |     const mlx::core::Strides& a_batch_strides, | ||||||
|  |     const mlx::core::Strides& b_batch_strides) { | ||||||
|  |   encoder.set_input_array(a); | ||||||
|  |   encoder.set_input_array(b); | ||||||
|  |   encoder.set_output_array(out); | ||||||
|  |   auto nbatch = out.size() / (M_ * N_ * batch_shape.back()); | ||||||
|  |   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); | ||||||
|  |   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); | ||||||
|  |   auto concurrent = encoder.concurrent_context(); | ||||||
|  |   for (size_t i = 0; i < nbatch; ++i) { | ||||||
|  |     run_impl( | ||||||
|  |         encoder, | ||||||
|  |         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_, | ||||||
|  |         a.data<int8_t>() + a.itemsize() * a_it.loc, | ||||||
|  |         b.data<int8_t>() + b.itemsize() * b_it.loc, | ||||||
|  |         nullptr); | ||||||
|  |     a_it.step(); | ||||||
|  |     b_it.step(); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void Matmul::run_batched( | ||||||
|  |     cu::CommandEncoder& encoder, | ||||||
|  |     array& out, | ||||||
|  |     const array& a, | ||||||
|  |     const array& b, | ||||||
|  |     const array& c, | ||||||
|  |     const mlx::core::Shape& batch_shape, | ||||||
|  |     const mlx::core::Strides& a_batch_strides, | ||||||
|  |     const mlx::core::Strides& b_batch_strides, | ||||||
|  |     const mlx::core::Strides& c_batch_strides, | ||||||
|  |     float alpha, | ||||||
|  |     float beta) { | ||||||
|  |   encoder.set_input_array(a); | ||||||
|  |   encoder.set_input_array(b); | ||||||
|  |   encoder.set_input_array(c); | ||||||
|  |   encoder.set_output_array(out); | ||||||
|  |  | ||||||
|  |   auto nbatch = out.size() / (M_ * N_ * batch_shape.back()); | ||||||
|  |   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); | ||||||
|  |   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); | ||||||
|  |   ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); | ||||||
|  |   auto concurrent = encoder.concurrent_context(); | ||||||
|  |   for (size_t i = 0; i < nbatch; ++i) { | ||||||
|  |     run_impl( | ||||||
|  |         encoder, | ||||||
|  |         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_, | ||||||
|  |         a.data<int8_t>() + a.itemsize() * a_it.loc, | ||||||
|  |         b.data<int8_t>() + b.itemsize() * b_it.loc, | ||||||
|  |         c.data<int8_t>() + c.itemsize() * c_it.loc, | ||||||
|  |         alpha, | ||||||
|  |         beta); | ||||||
|  |     a_it.step(); | ||||||
|  |     b_it.step(); | ||||||
|  |     c_it.step(); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace mlx::core::cu | ||||||
							
								
								
									
										206
									
								
								mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										206
									
								
								mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,206 @@ | |||||||
|  | // Copyright © 2025 Apple Inc. | ||||||
|  |  | ||||||
|  | #include "mlx/backend/cuda/device.h" | ||||||
|  | #include "mlx/backend/cuda/gemms/cublas_gemm.h" | ||||||
|  | #include "mlx/backend/cuda/kernel_utils.cuh" | ||||||
|  |  | ||||||
|  | #include <cooperative_groups.h> | ||||||
|  |  | ||||||
|  | namespace mlx::core::cu { | ||||||
|  |  | ||||||
|  | namespace cg = cooperative_groups; | ||||||
|  |  | ||||||
|  | __global__ void set_mm_device_pointers( | ||||||
|  |     int8_t** pointers, | ||||||
|  |     int8_t* a_start, | ||||||
|  |     int8_t* b_start, | ||||||
|  |     int8_t* out_start, | ||||||
|  |     int item_size, | ||||||
|  |     const __grid_constant__ Shape batch_shape, | ||||||
|  |     const __grid_constant__ Strides a_batch_strides, | ||||||
|  |     const __grid_constant__ Strides b_batch_strides, | ||||||
|  |     int64_t batch_stride, | ||||||
|  |     int batch_ndim, | ||||||
|  |     int batch_count) { | ||||||
|  |   auto index = cg::this_grid().thread_rank(); | ||||||
|  |   if (index >= batch_count) { | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  |   auto [a_offset, b_offset] = elem_to_loc( | ||||||
|  |       index, | ||||||
|  |       batch_shape.data(), | ||||||
|  |       a_batch_strides.data(), | ||||||
|  |       b_batch_strides.data(), | ||||||
|  |       batch_ndim); | ||||||
|  |   pointers[index] = a_start + item_size * a_offset; | ||||||
|  |   pointers[index + batch_count] = b_start + item_size * b_offset; | ||||||
|  |   pointers[index + 2 * batch_count] = | ||||||
|  |       out_start + item_size * index * batch_stride; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | __global__ void set_addmm_device_pointers( | ||||||
|  |     int8_t** pointers, | ||||||
|  |     int8_t* a_start, | ||||||
|  |     int8_t* b_start, | ||||||
|  |     int8_t* c_start, | ||||||
|  |     int8_t* out_start, | ||||||
|  |     int item_size, | ||||||
|  |     const __grid_constant__ Shape batch_shape, | ||||||
|  |     const __grid_constant__ Strides a_batch_strides, | ||||||
|  |     const __grid_constant__ Strides b_batch_strides, | ||||||
|  |     const __grid_constant__ Strides c_batch_strides, | ||||||
|  |     int64_t batch_stride, | ||||||
|  |     int batch_ndim, | ||||||
|  |     int batch_count) { | ||||||
|  |   auto index = cg::this_grid().thread_rank(); | ||||||
|  |   if (index >= batch_count) { | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  |   auto [a_offset, b_offset, c_offset] = elem_to_loc( | ||||||
|  |       index, | ||||||
|  |       batch_shape.data(), | ||||||
|  |       a_batch_strides.data(), | ||||||
|  |       b_batch_strides.data(), | ||||||
|  |       c_batch_strides.data(), | ||||||
|  |       batch_ndim); | ||||||
|  |   pointers[index] = a_start + item_size * a_offset; | ||||||
|  |   pointers[index + batch_count] = b_start + item_size * b_offset; | ||||||
|  |   pointers[index + 2 * batch_count] = c_start + item_size * c_offset; | ||||||
|  |   pointers[index + 3 * batch_count] = | ||||||
|  |       out_start + item_size * index * batch_stride; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) { | ||||||
|  |   auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY; | ||||||
|  |   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( | ||||||
|  |       desc, | ||||||
|  |       CUBLASLT_MATRIX_LAYOUT_BATCH_MODE, | ||||||
|  |       &batch_mode, | ||||||
|  |       sizeof(batch_mode))); | ||||||
|  |   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( | ||||||
|  |       desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t))); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void Matmul::run_batched( | ||||||
|  |     cu::CommandEncoder& encoder, | ||||||
|  |     array& out, | ||||||
|  |     const array& a, | ||||||
|  |     const array& b, | ||||||
|  |     const mlx::core::Shape& batch_shape, | ||||||
|  |     const mlx::core::Strides& a_batch_strides, | ||||||
|  |     const mlx::core::Strides& b_batch_strides) { | ||||||
|  |   auto batch_count = out.size() / (M_ * N_); | ||||||
|  |   set_pointer_mode(a_desc_, batch_count); | ||||||
|  |   set_pointer_mode(b_desc_, batch_count); | ||||||
|  |   set_pointer_mode(out_desc_, batch_count); | ||||||
|  |  | ||||||
|  |   // Launch kernel to set device offsets | ||||||
|  |   auto pointers = array( | ||||||
|  |       allocator::malloc(batch_count * sizeof(uint64_t) * 3), | ||||||
|  |       {static_cast<int>(batch_count * 3)}, | ||||||
|  |       uint64); | ||||||
|  |  | ||||||
|  |   encoder.add_temporary(pointers); | ||||||
|  |   int block_size = 512; | ||||||
|  |   encoder.set_output_array(pointers); | ||||||
|  |  | ||||||
|  |   encoder.add_kernel_node( | ||||||
|  |       cu::set_mm_device_pointers, | ||||||
|  |       cuda::ceil_div(pointers.size(), block_size), | ||||||
|  |       block_size, | ||||||
|  |       pointers.data<int8_t*>(), | ||||||
|  |       a.data<int8_t>(), | ||||||
|  |       b.data<int8_t>(), | ||||||
|  |       out.data<int8_t>(), | ||||||
|  |       static_cast<int>(out.dtype().size()), | ||||||
|  |       const_param(batch_shape), | ||||||
|  |       const_param(a_batch_strides), | ||||||
|  |       const_param(b_batch_strides), | ||||||
|  |       static_cast<int64_t>(M_) * N_, | ||||||
|  |       static_cast<int>(batch_shape.size()), | ||||||
|  |       batch_count); | ||||||
|  |  | ||||||
|  |   // Run matmul | ||||||
|  |   encoder.set_input_array(pointers); | ||||||
|  |   encoder.set_input_array(a); | ||||||
|  |   encoder.set_input_array(b); | ||||||
|  |   encoder.set_output_array(out); | ||||||
|  |  | ||||||
|  |   auto a_pointers = pointers.data<int8_t*>(); | ||||||
|  |   auto b_pointers = a_pointers + batch_count; | ||||||
|  |   auto out_pointers = b_pointers + batch_count; | ||||||
|  |   run_impl( | ||||||
|  |       encoder, | ||||||
|  |       reinterpret_cast<void*>(out_pointers), | ||||||
|  |       reinterpret_cast<void*>(a_pointers), | ||||||
|  |       reinterpret_cast<void*>(b_pointers), | ||||||
|  |       nullptr); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void Matmul::run_batched( | ||||||
|  |     cu::CommandEncoder& encoder, | ||||||
|  |     array& out, | ||||||
|  |     const array& a, | ||||||
|  |     const array& b, | ||||||
|  |     const array& c, | ||||||
|  |     const mlx::core::Shape& batch_shape, | ||||||
|  |     const mlx::core::Strides& a_batch_strides, | ||||||
|  |     const mlx::core::Strides& b_batch_strides, | ||||||
|  |     const mlx::core::Strides& c_batch_strides, | ||||||
|  |     float alpha, | ||||||
|  |     float beta) { | ||||||
|  |   auto batch_count = out.size() / (M_ * N_); | ||||||
|  |   set_pointer_mode(a_desc_, batch_count); | ||||||
|  |   set_pointer_mode(b_desc_, batch_count); | ||||||
|  |   set_pointer_mode(c_desc_, batch_count); | ||||||
|  |   set_pointer_mode(out_desc_, batch_count); | ||||||
|  |  | ||||||
|  |   // Launch kernel to set device offsets | ||||||
|  |   auto pointers = array( | ||||||
|  |       allocator::malloc(batch_count * sizeof(uint64_t) * 4), | ||||||
|  |       {static_cast<int>(batch_count * 4)}, | ||||||
|  |       uint64); | ||||||
|  |  | ||||||
|  |   encoder.add_temporary(pointers); | ||||||
|  |   int block_size = 512; | ||||||
|  |   encoder.set_output_array(pointers); | ||||||
|  |   encoder.add_kernel_node( | ||||||
|  |       cu::set_addmm_device_pointers, | ||||||
|  |       cuda::ceil_div(pointers.size(), block_size), | ||||||
|  |       block_size, | ||||||
|  |       pointers.data<int8_t*>(), | ||||||
|  |       a.data<int8_t>(), | ||||||
|  |       b.data<int8_t>(), | ||||||
|  |       c.data<int8_t>(), | ||||||
|  |       out.data<int8_t>(), | ||||||
|  |       static_cast<int>(out.dtype().size()), | ||||||
|  |       const_param(batch_shape), | ||||||
|  |       const_param(a_batch_strides), | ||||||
|  |       const_param(b_batch_strides), | ||||||
|  |       const_param(c_batch_strides), | ||||||
|  |       static_cast<int64_t>(M_) * N_, | ||||||
|  |       static_cast<int>(batch_shape.size()), | ||||||
|  |       batch_count); | ||||||
|  |  | ||||||
|  |   // Run matmul | ||||||
|  |   encoder.set_input_array(pointers); | ||||||
|  |   encoder.set_input_array(a); | ||||||
|  |   encoder.set_input_array(b); | ||||||
|  |   encoder.set_input_array(c); | ||||||
|  |   encoder.set_output_array(out); | ||||||
|  |  | ||||||
|  |   auto a_pointers = pointers.data<int8_t*>(); | ||||||
|  |   auto b_pointers = a_pointers + batch_count; | ||||||
|  |   auto c_pointers = b_pointers + batch_count; | ||||||
|  |   auto out_pointers = c_pointers + batch_count; | ||||||
|  |   run_impl( | ||||||
|  |       encoder, | ||||||
|  |       reinterpret_cast<void*>(out_pointers), | ||||||
|  |       reinterpret_cast<void*>(a_pointers), | ||||||
|  |       reinterpret_cast<void*>(b_pointers), | ||||||
|  |       reinterpret_cast<void*>(c_pointers), | ||||||
|  |       alpha, | ||||||
|  |       beta); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace mlx::core::cu | ||||||
							
								
								
									
										282
									
								
								mlx/backend/cuda/gemms/cublas_gemm.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										282
									
								
								mlx/backend/cuda/gemms/cublas_gemm.cpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,282 @@ | |||||||
|  | // Copyright © 2025 Apple Inc. | ||||||
|  |  | ||||||
|  | #include "mlx/backend/cuda/gemms/cublas_gemm.h" | ||||||
|  | #include "mlx/backend/cuda/device.h" | ||||||
|  | #include "mlx/dtype_utils.h" | ||||||
|  | #include "mlx/utils.h" | ||||||
|  |  | ||||||
|  | #include <fmt/format.h> | ||||||
|  |  | ||||||
|  | namespace mlx::core::cu { | ||||||
|  |  | ||||||
|  | struct CublasPreference { | ||||||
|  |   CublasPreference(Device& device) { | ||||||
|  |     // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB | ||||||
|  |     // for Hopper+: | ||||||
|  |     // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace | ||||||
|  |     uint64_t MiB = 1024 * 1024; | ||||||
|  |     uint64_t workspace_size = | ||||||
|  |         device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB; | ||||||
|  |  | ||||||
|  |     CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_)); | ||||||
|  |     CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute( | ||||||
|  |         pref_, | ||||||
|  |         CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, | ||||||
|  |         &workspace_size, | ||||||
|  |         sizeof(uint64_t))); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   ~CublasPreference() { | ||||||
|  |     CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_)); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   cublasLtMatmulPreference_t pref_{nullptr}; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | cublasLtMatmulPreference_t cublas_preference(Device& device) { | ||||||
|  |   static CublasPreference pref(device); | ||||||
|  |   return pref.pref_; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | cublasComputeType_t dtype_to_compute_type(Dtype dtype) { | ||||||
|  |   switch (dtype) { | ||||||
|  |     case float16: | ||||||
|  |       return CUBLAS_COMPUTE_32F; | ||||||
|  |     case bfloat16: | ||||||
|  |       return CUBLAS_COMPUTE_32F; | ||||||
|  |     case float32: | ||||||
|  |       return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32 | ||||||
|  |                                            : CUBLAS_COMPUTE_32F; | ||||||
|  |     case float64: | ||||||
|  |     case complex64: | ||||||
|  |       return CUBLAS_COMPUTE_64F; | ||||||
|  |     default: | ||||||
|  |       throw std::runtime_error(fmt::format( | ||||||
|  |           "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype))); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | cudaDataType_t dtype_to_cublas_type(Dtype dtype) { | ||||||
|  |   switch (dtype) { | ||||||
|  |     case float16: | ||||||
|  |       return CUDA_R_16F; | ||||||
|  |     case bfloat16: | ||||||
|  |       return CUDA_R_16BF; | ||||||
|  |     case float32: | ||||||
|  |       return CUDA_R_32F; | ||||||
|  |     case float64: | ||||||
|  |       return CUDA_R_64F; | ||||||
|  |     case complex64: | ||||||
|  |       return CUDA_C_32F; | ||||||
|  |     default: | ||||||
|  |       throw std::runtime_error(fmt::format( | ||||||
|  |           "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype))); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | cublasLtMatrixLayout_t create_matrix_layout( | ||||||
|  |     cudaDataType_t type, | ||||||
|  |     uint64_t rows, | ||||||
|  |     uint64_t cols, | ||||||
|  |     bool transposed, | ||||||
|  |     int64_t ld, | ||||||
|  |     int32_t batch_count, | ||||||
|  |     int64_t batch_stride) { | ||||||
|  |   cublasLtMatrixLayout_t desc; | ||||||
|  |   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld)); | ||||||
|  |   cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW; | ||||||
|  |   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( | ||||||
|  |       desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t))); | ||||||
|  |   if (batch_count > 1) { | ||||||
|  |     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( | ||||||
|  |         desc, | ||||||
|  |         CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, | ||||||
|  |         &batch_count, | ||||||
|  |         sizeof(int32_t))); | ||||||
|  |     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( | ||||||
|  |         desc, | ||||||
|  |         CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, | ||||||
|  |         &batch_stride, | ||||||
|  |         sizeof(int64_t))); | ||||||
|  |   } | ||||||
|  |   return desc; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | Matmul::Matmul( | ||||||
|  |     Device& device, | ||||||
|  |     Dtype dtype, | ||||||
|  |     bool a_transposed, | ||||||
|  |     uint64_t a_rows, | ||||||
|  |     uint64_t a_cols, | ||||||
|  |     int64_t lda, | ||||||
|  |     bool b_transposed, | ||||||
|  |     uint64_t b_rows, | ||||||
|  |     uint64_t b_cols, | ||||||
|  |     int64_t ldb, | ||||||
|  |     int32_t batch_count, | ||||||
|  |     int64_t a_batch_stride, | ||||||
|  |     int64_t b_batch_stride) | ||||||
|  |     : handle_(device.lt_handle()), | ||||||
|  |       pref_(cublas_preference(device)), | ||||||
|  |       M_(a_rows), | ||||||
|  |       N_(b_cols) { | ||||||
|  |   heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; | ||||||
|  |  | ||||||
|  |   auto scale_type = dtype_to_cublas_type(dtype); | ||||||
|  |   if (dtype == bfloat16 || dtype == float16) { | ||||||
|  |     scale_type = CUDA_R_32F; | ||||||
|  |   } | ||||||
|  |   CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate( | ||||||
|  |       &matmul_desc_, dtype_to_compute_type(dtype), scale_type)); | ||||||
|  |   int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; | ||||||
|  |   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( | ||||||
|  |       matmul_desc_, | ||||||
|  |       CUBLASLT_MATMUL_DESC_POINTER_MODE, | ||||||
|  |       &pointer_mode, | ||||||
|  |       sizeof(int32_t))); | ||||||
|  |   cublasOperation_t op = CUBLAS_OP_N; | ||||||
|  |   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( | ||||||
|  |       matmul_desc_, | ||||||
|  |       CUBLASLT_MATMUL_DESC_TRANSA, | ||||||
|  |       &op, | ||||||
|  |       sizeof(cublasOperation_t))); | ||||||
|  |   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( | ||||||
|  |       matmul_desc_, | ||||||
|  |       CUBLASLT_MATMUL_DESC_TRANSB, | ||||||
|  |       &op, | ||||||
|  |       sizeof(cublasOperation_t))); | ||||||
|  |  | ||||||
|  |   auto type = dtype_to_cublas_type(dtype); | ||||||
|  |   a_desc_ = create_matrix_layout( | ||||||
|  |       type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride); | ||||||
|  |   b_desc_ = create_matrix_layout( | ||||||
|  |       type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride); | ||||||
|  |   out_desc_ = create_matrix_layout( | ||||||
|  |       type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | Matmul::Matmul( | ||||||
|  |     Device& device, | ||||||
|  |     Dtype dtype, | ||||||
|  |     bool a_transposed, | ||||||
|  |     uint64_t a_rows, | ||||||
|  |     uint64_t a_cols, | ||||||
|  |     int64_t lda, | ||||||
|  |     bool b_transposed, | ||||||
|  |     uint64_t b_rows, | ||||||
|  |     uint64_t b_cols, | ||||||
|  |     int64_t ldb, | ||||||
|  |     int64_t ldc, | ||||||
|  |     int32_t batch_count, | ||||||
|  |     int64_t a_batch_stride, | ||||||
|  |     int64_t b_batch_stride, | ||||||
|  |     int64_t c_batch_stride) | ||||||
|  |     : Matmul( | ||||||
|  |           device, | ||||||
|  |           dtype, | ||||||
|  |           a_transposed, | ||||||
|  |           a_rows, | ||||||
|  |           a_cols, | ||||||
|  |           lda, | ||||||
|  |           b_transposed, | ||||||
|  |           b_rows, | ||||||
|  |           b_cols, | ||||||
|  |           ldb, | ||||||
|  |           batch_count, | ||||||
|  |           a_batch_stride, | ||||||
|  |           b_batch_stride) { | ||||||
|  |   auto type = dtype_to_cublas_type(dtype); | ||||||
|  |   c_desc_ = create_matrix_layout( | ||||||
|  |       type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | Matmul::~Matmul() { | ||||||
|  |   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_)); | ||||||
|  |   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_)); | ||||||
|  |   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_)); | ||||||
|  |   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_)); | ||||||
|  |   CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void Matmul::run_impl( | ||||||
|  |     cu::CommandEncoder& encoder, | ||||||
|  |     void* out, | ||||||
|  |     const void* a, | ||||||
|  |     const void* b, | ||||||
|  |     const void* c, | ||||||
|  |     float alpha /* = 1 */, | ||||||
|  |     float beta /* = 0 */) { | ||||||
|  |   if (heuristic_.state != CUBLAS_STATUS_SUCCESS) { | ||||||
|  |     int ret = 0; | ||||||
|  |     CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic( | ||||||
|  |         handle_, | ||||||
|  |         matmul_desc_, | ||||||
|  |         a_desc_, | ||||||
|  |         b_desc_, | ||||||
|  |         out_desc_, // TODO should that be c_desc is it's set? | ||||||
|  |         out_desc_, | ||||||
|  |         pref_, | ||||||
|  |         1, | ||||||
|  |         &heuristic_, | ||||||
|  |         &ret)); | ||||||
|  |     if (ret == 0) { | ||||||
|  |       throw std::runtime_error("Can not find algorithm for matmul."); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   void* workspace_ptr = nullptr; | ||||||
|  |   if (heuristic_.workspaceSize > 0) { | ||||||
|  |     array workspace( | ||||||
|  |         allocator::malloc(heuristic_.workspaceSize), | ||||||
|  |         {static_cast<int>(heuristic_.workspaceSize)}, | ||||||
|  |         int8); | ||||||
|  |     encoder.add_temporary(workspace); | ||||||
|  |     workspace_ptr = workspace.data<void>(); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   auto capture = encoder.capture_context(); | ||||||
|  |   CHECK_CUBLAS_ERROR(cublasLtMatmul( | ||||||
|  |       handle_, | ||||||
|  |       matmul_desc_, | ||||||
|  |       &alpha, | ||||||
|  |       a, | ||||||
|  |       a_desc_, | ||||||
|  |       b, | ||||||
|  |       b_desc_, | ||||||
|  |       &beta, | ||||||
|  |       c ? c : out, | ||||||
|  |       c ? c_desc_ : out_desc_, | ||||||
|  |       out, | ||||||
|  |       out_desc_, | ||||||
|  |       &heuristic_.algo, | ||||||
|  |       workspace_ptr, | ||||||
|  |       heuristic_.workspaceSize, | ||||||
|  |       encoder.stream())); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void Matmul::run( | ||||||
|  |     cu::CommandEncoder& encoder, | ||||||
|  |     array& out, | ||||||
|  |     const array& a, | ||||||
|  |     const array& b, | ||||||
|  |     const std::optional<array>& c /* = std::nullopt */, | ||||||
|  |     float alpha /* = 1 */, | ||||||
|  |     float beta /* = 0 */) { | ||||||
|  |   encoder.set_input_array(a); | ||||||
|  |   encoder.set_input_array(b); | ||||||
|  |   if (c) { | ||||||
|  |     encoder.set_input_array(*c); | ||||||
|  |   } | ||||||
|  |   encoder.set_output_array(out); | ||||||
|  |  | ||||||
|  |   run_impl( | ||||||
|  |       encoder, | ||||||
|  |       out.data<void>(), | ||||||
|  |       a.data<void>(), | ||||||
|  |       b.data<void>(), | ||||||
|  |       c ? c->data<void>() : nullptr, | ||||||
|  |       alpha, | ||||||
|  |       beta); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace mlx::core::cu | ||||||
							
								
								
									
										100
									
								
								mlx/backend/cuda/gemms/cublas_gemm.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								mlx/backend/cuda/gemms/cublas_gemm.h
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,100 @@ | |||||||
|  | // Copyright © 2025 Apple Inc. | ||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include "mlx/array.h" | ||||||
|  | #include "mlx/backend/cuda/device.h" | ||||||
|  |  | ||||||
|  | #include <cublasLt.h> | ||||||
|  | #include <optional> | ||||||
|  |  | ||||||
|  | namespace mlx::core::cu { | ||||||
|  | class Matmul { | ||||||
|  |  public: | ||||||
|  |   Matmul( | ||||||
|  |       Device& device, | ||||||
|  |       Dtype dtype, | ||||||
|  |       bool a_transposed, | ||||||
|  |       uint64_t a_rows, | ||||||
|  |       uint64_t a_cols, | ||||||
|  |       int64_t lda, | ||||||
|  |       bool b_transposed, | ||||||
|  |       uint64_t b_rows, | ||||||
|  |       uint64_t b_cols, | ||||||
|  |       int64_t ldb, | ||||||
|  |       int32_t batch_count, | ||||||
|  |       int64_t a_batch_stride, | ||||||
|  |       int64_t b_batch_stride); | ||||||
|  |  | ||||||
|  |   Matmul( | ||||||
|  |       Device& device, | ||||||
|  |       Dtype dtype, | ||||||
|  |       bool a_transposed, | ||||||
|  |       uint64_t a_rows, | ||||||
|  |       uint64_t a_cols, | ||||||
|  |       int64_t lda, | ||||||
|  |       bool b_transposed, | ||||||
|  |       uint64_t b_rows, | ||||||
|  |       uint64_t b_cols, | ||||||
|  |       int64_t ldb, | ||||||
|  |       int64_t ldc, | ||||||
|  |       int32_t batch_count, | ||||||
|  |       int64_t a_batch_stride, | ||||||
|  |       int64_t b_batch_stride, | ||||||
|  |       int64_t c_batch_stride); | ||||||
|  |  | ||||||
|  |   ~Matmul(); | ||||||
|  |  | ||||||
|  |   void run( | ||||||
|  |       cu::CommandEncoder& encoder, | ||||||
|  |       array& out, | ||||||
|  |       const array& a, | ||||||
|  |       const array& b, | ||||||
|  |       const std::optional<array>& c = std::nullopt, | ||||||
|  |       float alpha = 1, | ||||||
|  |       float beta = 0); | ||||||
|  |  | ||||||
|  |   void run_batched( | ||||||
|  |       cu::CommandEncoder& encoder, | ||||||
|  |       array& out, | ||||||
|  |       const array& a, | ||||||
|  |       const array& b, | ||||||
|  |       const mlx::core::Shape& batch_shape, | ||||||
|  |       const mlx::core::Strides& a_batch_strides, | ||||||
|  |       const mlx::core::Strides& b_batch_strides); | ||||||
|  |  | ||||||
|  |   void run_batched( | ||||||
|  |       cu::CommandEncoder& encoder, | ||||||
|  |       array& out, | ||||||
|  |       const array& a, | ||||||
|  |       const array& b, | ||||||
|  |       const array& c, | ||||||
|  |       const mlx::core::Shape& batch_shape, | ||||||
|  |       const mlx::core::Strides& a_batch_strides, | ||||||
|  |       const mlx::core::Strides& b_batch_strides, | ||||||
|  |       const mlx::core::Strides& c_batch_strides, | ||||||
|  |       float alpha, | ||||||
|  |       float beta); | ||||||
|  |  | ||||||
|  |  private: | ||||||
|  |   void run_impl( | ||||||
|  |       cu::CommandEncoder& encoder, | ||||||
|  |       void* out, | ||||||
|  |       const void* a, | ||||||
|  |       const void* b, | ||||||
|  |       const void* c, | ||||||
|  |       float alpha = 1, | ||||||
|  |       float beta = 0); | ||||||
|  |  | ||||||
|  |   uint64_t M_; | ||||||
|  |   uint64_t N_; | ||||||
|  |   cublasLtMatmulPreference_t pref_{nullptr}; | ||||||
|  |   cublasLtHandle_t handle_{nullptr}; | ||||||
|  |   cublasLtMatmulDesc_t matmul_desc_{nullptr}; | ||||||
|  |   cublasLtMatrixLayout_t a_desc_{nullptr}; | ||||||
|  |   cublasLtMatrixLayout_t b_desc_{nullptr}; | ||||||
|  |   cublasLtMatrixLayout_t c_desc_{nullptr}; | ||||||
|  |   cublasLtMatrixLayout_t out_desc_{nullptr}; | ||||||
|  |   cublasLtMatmulHeuristicResult_t heuristic_; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | } // namespace mlx::core::cu | ||||||
| @@ -1,6 +1,6 @@ | |||||||
| // Copyright © 2025 Apple Inc. | // Copyright © 2025 Apple Inc. | ||||||
| 
 | 
 | ||||||
| #include "mlx/backend/cuda/gemv.h" | #include "mlx/backend/cuda/gemms/gemv.h" | ||||||
| #include "mlx/backend/cuda/kernel_utils.cuh" | #include "mlx/backend/cuda/kernel_utils.cuh" | ||||||
| #include "mlx/dtype_utils.h" | #include "mlx/dtype_utils.h" | ||||||
| 
 | 
 | ||||||
| @@ -2,279 +2,15 @@ | |||||||
|  |  | ||||||
| #include "mlx/backend/common/matmul.h" | #include "mlx/backend/common/matmul.h" | ||||||
| #include "mlx/backend/cuda/device.h" | #include "mlx/backend/cuda/device.h" | ||||||
| #include "mlx/backend/cuda/gemv.h" | #include "mlx/backend/cuda/gemms/cublas_gemm.h" | ||||||
|  | #include "mlx/backend/cuda/gemms/gemv.h" | ||||||
| #include "mlx/backend/gpu/copy.h" | #include "mlx/backend/gpu/copy.h" | ||||||
| #include "mlx/dtype_utils.h" |  | ||||||
| #include "mlx/primitives.h" | #include "mlx/primitives.h" | ||||||
| #include "mlx/utils.h" |  | ||||||
|  |  | ||||||
| #include <fmt/format.h> |  | ||||||
| #include <nvtx3/nvtx3.hpp> | #include <nvtx3/nvtx3.hpp> | ||||||
|  |  | ||||||
| #include <numeric> | #include <numeric> | ||||||
|  |  | ||||||
| namespace mlx::core { | namespace mlx::core { | ||||||
|  |  | ||||||
| namespace cu { |  | ||||||
|  |  | ||||||
| struct CublasPreference { |  | ||||||
|   CublasPreference(Device& device) { |  | ||||||
|     // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB |  | ||||||
|     // for Hopper+: |  | ||||||
|     // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace |  | ||||||
|     uint64_t MiB = 1024 * 1024; |  | ||||||
|     uint64_t workspace_size = |  | ||||||
|         device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB; |  | ||||||
|  |  | ||||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_)); |  | ||||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute( |  | ||||||
|         pref_, |  | ||||||
|         CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, |  | ||||||
|         &workspace_size, |  | ||||||
|         sizeof(uint64_t))); |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   ~CublasPreference() { |  | ||||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_)); |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   cublasLtMatmulPreference_t pref_{nullptr}; |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| cublasLtMatmulPreference_t cublas_preference(Device& device) { |  | ||||||
|   static CublasPreference pref(device); |  | ||||||
|   return pref.pref_; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| class MatMul { |  | ||||||
|  public: |  | ||||||
|   MatMul( |  | ||||||
|       Device& device, |  | ||||||
|       Dtype dtype, |  | ||||||
|       bool a_transposed, |  | ||||||
|       uint64_t a_rows, |  | ||||||
|       uint64_t a_cols, |  | ||||||
|       int64_t lda, |  | ||||||
|       bool b_transposed, |  | ||||||
|       uint64_t b_rows, |  | ||||||
|       uint64_t b_cols, |  | ||||||
|       int64_t ldb, |  | ||||||
|       int32_t batch_count, |  | ||||||
|       int64_t a_batch_stride, |  | ||||||
|       int64_t b_batch_stride) |  | ||||||
|       : handle_(device.lt_handle()), pref_(cublas_preference(device)) { |  | ||||||
|     heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; |  | ||||||
|  |  | ||||||
|     auto scale_type = dtype_to_cuda_type(dtype); |  | ||||||
|     if (dtype == bfloat16 || dtype == float16) { |  | ||||||
|       scale_type = CUDA_R_32F; |  | ||||||
|     } |  | ||||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate( |  | ||||||
|         &matmul_desc_, dtype_to_compute_type(dtype), scale_type)); |  | ||||||
|     int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; |  | ||||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( |  | ||||||
|         matmul_desc_, |  | ||||||
|         CUBLASLT_MATMUL_DESC_POINTER_MODE, |  | ||||||
|         &pointer_mode, |  | ||||||
|         sizeof(int32_t))); |  | ||||||
|     cublasOperation_t op = CUBLAS_OP_N; |  | ||||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( |  | ||||||
|         matmul_desc_, |  | ||||||
|         CUBLASLT_MATMUL_DESC_TRANSA, |  | ||||||
|         &op, |  | ||||||
|         sizeof(cublasOperation_t))); |  | ||||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( |  | ||||||
|         matmul_desc_, |  | ||||||
|         CUBLASLT_MATMUL_DESC_TRANSB, |  | ||||||
|         &op, |  | ||||||
|         sizeof(cublasOperation_t))); |  | ||||||
|  |  | ||||||
|     auto type = dtype_to_cuda_type(dtype); |  | ||||||
|     a_desc_ = create_matrix_layout( |  | ||||||
|         type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride); |  | ||||||
|     b_desc_ = create_matrix_layout( |  | ||||||
|         type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride); |  | ||||||
|     out_desc_ = create_matrix_layout( |  | ||||||
|         type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols); |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   MatMul( |  | ||||||
|       Device& device, |  | ||||||
|       Dtype dtype, |  | ||||||
|       bool a_transposed, |  | ||||||
|       uint64_t a_rows, |  | ||||||
|       uint64_t a_cols, |  | ||||||
|       int64_t lda, |  | ||||||
|       bool b_transposed, |  | ||||||
|       uint64_t b_rows, |  | ||||||
|       uint64_t b_cols, |  | ||||||
|       int64_t ldb, |  | ||||||
|       int64_t ldc, |  | ||||||
|       int32_t batch_count, |  | ||||||
|       int64_t a_batch_stride, |  | ||||||
|       int64_t b_batch_stride, |  | ||||||
|       int64_t c_batch_stride) |  | ||||||
|       : MatMul( |  | ||||||
|             device, |  | ||||||
|             dtype, |  | ||||||
|             a_transposed, |  | ||||||
|             a_rows, |  | ||||||
|             a_cols, |  | ||||||
|             lda, |  | ||||||
|             b_transposed, |  | ||||||
|             b_rows, |  | ||||||
|             b_cols, |  | ||||||
|             ldb, |  | ||||||
|             batch_count, |  | ||||||
|             a_batch_stride, |  | ||||||
|             b_batch_stride) { |  | ||||||
|     auto type = dtype_to_cuda_type(dtype); |  | ||||||
|     c_desc_ = create_matrix_layout( |  | ||||||
|         type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride); |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   ~MatMul() { |  | ||||||
|     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_)); |  | ||||||
|     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_)); |  | ||||||
|     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_)); |  | ||||||
|     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_)); |  | ||||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_)); |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   void run( |  | ||||||
|       cu::CommandEncoder& encoder, |  | ||||||
|       void* out, |  | ||||||
|       void* a, |  | ||||||
|       void* b, |  | ||||||
|       void* c = nullptr, |  | ||||||
|       float alpha = 1, |  | ||||||
|       float beta = 0) { |  | ||||||
|     if (heuristic_.state != CUBLAS_STATUS_SUCCESS) { |  | ||||||
|       int ret = 0; |  | ||||||
|       CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic( |  | ||||||
|           handle_, |  | ||||||
|           matmul_desc_, |  | ||||||
|           a_desc_, |  | ||||||
|           b_desc_, |  | ||||||
|           out_desc_, |  | ||||||
|           out_desc_, |  | ||||||
|           pref_, |  | ||||||
|           1, |  | ||||||
|           &heuristic_, |  | ||||||
|           &ret)); |  | ||||||
|       if (ret == 0) { |  | ||||||
|         throw std::runtime_error("Can not find algorithm for matmul."); |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     void* workspace_ptr = nullptr; |  | ||||||
|     if (heuristic_.workspaceSize > 0) { |  | ||||||
|       array workspace( |  | ||||||
|           allocator::malloc(heuristic_.workspaceSize), |  | ||||||
|           {static_cast<int>(heuristic_.workspaceSize)}, |  | ||||||
|           int8); |  | ||||||
|       encoder.add_temporary(workspace); |  | ||||||
|       workspace_ptr = workspace.data<void>(); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     auto capture = encoder.capture_context(); |  | ||||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmul( |  | ||||||
|         handle_, |  | ||||||
|         matmul_desc_, |  | ||||||
|         &alpha, |  | ||||||
|         a, |  | ||||||
|         a_desc_, |  | ||||||
|         b, |  | ||||||
|         b_desc_, |  | ||||||
|         &beta, |  | ||||||
|         c ? c : out, |  | ||||||
|         c ? c_desc_ : out_desc_, |  | ||||||
|         out, |  | ||||||
|         out_desc_, |  | ||||||
|         &heuristic_.algo, |  | ||||||
|         workspace_ptr, |  | ||||||
|         heuristic_.workspaceSize, |  | ||||||
|         encoder.stream())); |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|  private: |  | ||||||
|   cublasComputeType_t dtype_to_compute_type(Dtype dtype) { |  | ||||||
|     switch (dtype) { |  | ||||||
|       case float16: |  | ||||||
|         return CUBLAS_COMPUTE_32F; |  | ||||||
|       case bfloat16: |  | ||||||
|         return CUBLAS_COMPUTE_32F; |  | ||||||
|       case float32: |  | ||||||
|         return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32 |  | ||||||
|                                              : CUBLAS_COMPUTE_32F; |  | ||||||
|       case float64: |  | ||||||
|       case complex64: |  | ||||||
|         return CUBLAS_COMPUTE_64F; |  | ||||||
|       default: |  | ||||||
|         throw std::runtime_error(fmt::format( |  | ||||||
|             "Unsupported dtype in MatMul: {}.", dtype_to_string(dtype))); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   cudaDataType_t dtype_to_cuda_type(Dtype dtype) { |  | ||||||
|     switch (dtype) { |  | ||||||
|       case float16: |  | ||||||
|         return CUDA_R_16F; |  | ||||||
|       case bfloat16: |  | ||||||
|         return CUDA_R_16BF; |  | ||||||
|       case float32: |  | ||||||
|         return CUDA_R_32F; |  | ||||||
|       case float64: |  | ||||||
|         return CUDA_R_64F; |  | ||||||
|       case complex64: |  | ||||||
|         return CUDA_C_32F; |  | ||||||
|       default: |  | ||||||
|         throw std::runtime_error(fmt::format( |  | ||||||
|             "Unsupported dtype in MatMul: {}.", dtype_to_string(dtype))); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   cublasLtMatrixLayout_t create_matrix_layout( |  | ||||||
|       cudaDataType_t type, |  | ||||||
|       uint64_t rows, |  | ||||||
|       uint64_t cols, |  | ||||||
|       bool transposed, |  | ||||||
|       int64_t ld, |  | ||||||
|       int32_t batch_count, |  | ||||||
|       int64_t batch_stride) { |  | ||||||
|     cublasLtMatrixLayout_t desc; |  | ||||||
|     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld)); |  | ||||||
|     cublasLtOrder_t order = |  | ||||||
|         transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW; |  | ||||||
|     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( |  | ||||||
|         desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t))); |  | ||||||
|     if (batch_count > 1) { |  | ||||||
|       CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( |  | ||||||
|           desc, |  | ||||||
|           CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, |  | ||||||
|           &batch_count, |  | ||||||
|           sizeof(int32_t))); |  | ||||||
|       CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( |  | ||||||
|           desc, |  | ||||||
|           CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, |  | ||||||
|           &batch_stride, |  | ||||||
|           sizeof(int64_t))); |  | ||||||
|     } |  | ||||||
|     return desc; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   cublasLtMatmulPreference_t pref_{nullptr}; |  | ||||||
|   cublasLtHandle_t handle_{nullptr}; |  | ||||||
|   cublasLtMatmulDesc_t matmul_desc_{nullptr}; |  | ||||||
|   cublasLtMatrixLayout_t a_desc_{nullptr}; |  | ||||||
|   cublasLtMatrixLayout_t b_desc_{nullptr}; |  | ||||||
|   cublasLtMatrixLayout_t c_desc_{nullptr}; |  | ||||||
|   cublasLtMatrixLayout_t out_desc_{nullptr}; |  | ||||||
|   cublasLtMatmulHeuristicResult_t heuristic_; |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| } // namespace cu |  | ||||||
|  |  | ||||||
| namespace { | namespace { | ||||||
|  |  | ||||||
| std::tuple<bool, int64_t, array> | std::tuple<bool, int64_t, array> | ||||||
| @@ -361,8 +97,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|  |  | ||||||
|   ///////////////////////////////////////////////////////////////////////////// |   ///////////////////////////////////////////////////////////////////////////// | ||||||
|   // Invoke cublasLt |   // Invoke cublasLt | ||||||
|  |   cu::Matmul matmul( | ||||||
|   cu::MatMul matmul( |  | ||||||
|       cu::device(s.device), |       cu::device(s.device), | ||||||
|       a.dtype(), |       a.dtype(), | ||||||
|       a_transposed, |       a_transposed, | ||||||
| @@ -377,27 +112,13 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|       a_batch_strides.back(), |       a_batch_strides.back(), | ||||||
|       b_batch_strides.back()); |       b_batch_strides.back()); | ||||||
|  |  | ||||||
|   encoder.set_input_array(a); |   if ((batch_count / batch_shape.back()) == 1) { | ||||||
|   encoder.set_input_array(b); |     matmul.run(encoder, out, a, b); | ||||||
|   encoder.set_output_array(out); |  | ||||||
|   auto nbatch = batch_count / batch_shape.back(); |  | ||||||
|   if (nbatch == 1) { |  | ||||||
|     matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>()); |  | ||||||
|     return; |     return; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); |   matmul.run_batched( | ||||||
|   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); |       encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides); | ||||||
|   auto concurrent = encoder.concurrent_context(); |  | ||||||
|   for (size_t i = 0; i < nbatch; ++i) { |  | ||||||
|     matmul.run( |  | ||||||
|         encoder, |  | ||||||
|         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N, |  | ||||||
|         a.data<int8_t>() + a.itemsize() * a_it.loc, |  | ||||||
|         b.data<int8_t>() + b.itemsize() * b_it.loc); |  | ||||||
|     a_it.step(); |  | ||||||
|     b_it.step(); |  | ||||||
|   } |  | ||||||
| } | } | ||||||
|  |  | ||||||
| void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { | void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||||
| @@ -465,7 +186,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|   ///////////////////////////////////////////////////////////////////////////// |   ///////////////////////////////////////////////////////////////////////////// | ||||||
|   // Invoke cublasLt |   // Invoke cublasLt | ||||||
|  |  | ||||||
|   cu::MatMul matmul( |   cu::Matmul matmul( | ||||||
|       cu::device(s.device), |       cu::device(s.device), | ||||||
|       a.dtype(), |       a.dtype(), | ||||||
|       a_transposed, |       a_transposed, | ||||||
| @@ -482,41 +203,22 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|       b_batch_strides.back(), |       b_batch_strides.back(), | ||||||
|       c_batch_strides.back()); |       c_batch_strides.back()); | ||||||
|  |  | ||||||
|   encoder.set_input_array(a); |   if ((batch_count / batch_shape.back()) == 1) { | ||||||
|   encoder.set_input_array(b); |     matmul.run(encoder, out, a, b, c, alpha_, beta_); | ||||||
|   encoder.set_input_array(c); |  | ||||||
|   encoder.set_output_array(out); |  | ||||||
|  |  | ||||||
|   auto nbatch = batch_count / batch_shape.back(); |  | ||||||
|   if (nbatch == 1) { |  | ||||||
|     matmul.run( |  | ||||||
|         encoder, |  | ||||||
|         out.data<int8_t>(), |  | ||||||
|         a.data<int8_t>(), |  | ||||||
|         b.data<int8_t>(), |  | ||||||
|         c.data<int8_t>(), |  | ||||||
|         alpha_, |  | ||||||
|         beta_); |  | ||||||
|     return; |     return; | ||||||
|   } |   } | ||||||
|  |   matmul.run_batched( | ||||||
|   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); |       encoder, | ||||||
|   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); |       out, | ||||||
|   ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); |       a, | ||||||
|   auto concurrent = encoder.concurrent_context(); |       b, | ||||||
|   for (size_t i = 0; i < nbatch; ++i) { |       c, | ||||||
|     matmul.run( |       batch_shape, | ||||||
|         encoder, |       a_batch_strides, | ||||||
|         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N, |       b_batch_strides, | ||||||
|         a.data<int8_t>() + a.itemsize() * a_it.loc, |       c_batch_strides, | ||||||
|         b.data<int8_t>() + b.itemsize() * b_it.loc, |       alpha_, | ||||||
|         c.data<int8_t>() + c.itemsize() * c_it.loc, |       beta_); | ||||||
|         alpha_, |  | ||||||
|         beta_); |  | ||||||
|     a_it.step(); |  | ||||||
|     b_it.step(); |  | ||||||
|     c_it.step(); |  | ||||||
|   } |  | ||||||
| } | } | ||||||
|  |  | ||||||
| } // namespace mlx::core | } // namespace mlx::core | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun