Mirror of https://github.com/ml-explore/mlx.git
[CUDA] Always use batched matmul (#2404)

Awni Hannun

* cuda batched mm
* addmm as well
* comment
		| @@ -21,7 +21,8 @@ target_sources( | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/event.cu | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/gemv.cu | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu | ||||
| @@ -47,6 +48,14 @@ target_sources( | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) | ||||
|  | ||||
| if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0) | ||||
|   target_sources( | ||||
|     mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu) | ||||
| else() | ||||
|   target_sources( | ||||
|     mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp) | ||||
| endif() | ||||
|  | ||||
| target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) | ||||
|  | ||||
| # Embed kernel sources in binary for JIT compilation. | ||||
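The gate mirrors what the toolkit provides: per this commit, cublasLt pointer-array batch mode needs CUDA 12.9+, so newer toolkits compile the device-pointer path (cublas_batched_gemm_12_9.cu) while older ones fall back to a host-side loop (cublas_batched_gemm_12_0.cpp). A hypothetical compile-time equivalent of the same split, for illustration only (CUDART_VERSION encodes 12.9 as 12090):

    // Illustration only; the commit does this selection in CMake, not the
    // preprocessor.
    #include <cuda_runtime_api.h> // defines CUDART_VERSION (12090 for 12.9)

    #if CUDART_VERSION >= 12090
    // pointer-array batched GEMM path
    #else
    // host-side loop fallback
    #endif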
mlx/backend/cuda/gemms/cublas_batched_gemm_12_0.cpp (new file, 73 lines)
							| @@ -0,0 +1,73 @@ | ||||
| // Copyright © 2025 Apple Inc. | ||||
|  | ||||
| #include "mlx/backend/common/utils.h" | ||||
| #include "mlx/backend/cuda/device.h" | ||||
| #include "mlx/backend/cuda/gemms/cublas_gemm.h" | ||||
|  | ||||
| namespace mlx::core::cu { | ||||
|  | ||||
| void Matmul::run_batched( | ||||
|     cu::CommandEncoder& encoder, | ||||
|     array& out, | ||||
|     const array& a, | ||||
|     const array& b, | ||||
|     const mlx::core::Shape& batch_shape, | ||||
|     const mlx::core::Strides& a_batch_strides, | ||||
|     const mlx::core::Strides& b_batch_strides) { | ||||
|   encoder.set_input_array(a); | ||||
|   encoder.set_input_array(b); | ||||
|   encoder.set_output_array(out); | ||||
|   auto nbatch = out.size() / (M_ * N_ * batch_shape.back()); | ||||
|   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); | ||||
|   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); | ||||
|   auto concurrent = encoder.concurrent_context(); | ||||
|   for (size_t i = 0; i < nbatch; ++i) { | ||||
|     run_impl( | ||||
|         encoder, | ||||
|         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_, | ||||
|         a.data<int8_t>() + a.itemsize() * a_it.loc, | ||||
|         b.data<int8_t>() + b.itemsize() * b_it.loc, | ||||
|         nullptr); | ||||
|     a_it.step(); | ||||
|     b_it.step(); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void Matmul::run_batched( | ||||
|     cu::CommandEncoder& encoder, | ||||
|     array& out, | ||||
|     const array& a, | ||||
|     const array& b, | ||||
|     const array& c, | ||||
|     const mlx::core::Shape& batch_shape, | ||||
|     const mlx::core::Strides& a_batch_strides, | ||||
|     const mlx::core::Strides& b_batch_strides, | ||||
|     const mlx::core::Strides& c_batch_strides, | ||||
|     float alpha, | ||||
|     float beta) { | ||||
|   encoder.set_input_array(a); | ||||
|   encoder.set_input_array(b); | ||||
|   encoder.set_input_array(c); | ||||
|   encoder.set_output_array(out); | ||||
|  | ||||
|   auto nbatch = out.size() / (M_ * N_ * batch_shape.back()); | ||||
|   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); | ||||
|   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); | ||||
|   ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); | ||||
|   auto concurrent = encoder.concurrent_context(); | ||||
|   for (size_t i = 0; i < nbatch; ++i) { | ||||
|     run_impl( | ||||
|         encoder, | ||||
|         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_, | ||||
|         a.data<int8_t>() + a.itemsize() * a_it.loc, | ||||
|         b.data<int8_t>() + b.itemsize() * b_it.loc, | ||||
|         c.data<int8_t>() + c.itemsize() * c_it.loc, | ||||
|         alpha, | ||||
|         beta); | ||||
|     a_it.step(); | ||||
|     b_it.step(); | ||||
|     c_it.step(); | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace mlx::core::cu | ||||
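This pre-12.9 fallback walks the broadcasted batch dimensions on the host with ContiguousIterator and issues one GEMM per outer batch element; the iterator stops one dimension short (batch_shape.size() - 1) because the innermost batch dimension is folded into cuBLAS strided batching. The offsets it yields are the standard shape/strides unflattening; a minimal standalone sketch of that mapping (helper name and container types are illustrative, not the MLX iterator API):

    #include <cstdint>
    #include <vector>

    // Map a flat batch index to an element offset for one array, given the
    // broadcasted batch shape and that array's strides (row-major order).
    int64_t batch_offset(
        int64_t index,
        const std::vector<int32_t>& shape,
        const std::vector<int64_t>& strides) {
      int64_t loc = 0;
      for (int d = static_cast<int>(shape.size()) - 1; d >= 0; --d) {
        loc += (index % shape[d]) * strides[d];
        index /= shape[d];
      }
      return loc;
    }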
							
								
								
									
mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu (new file, 206 lines)
							| @@ -0,0 +1,206 @@ | ||||
| // Copyright © 2025 Apple Inc. | ||||
|  | ||||
| #include "mlx/backend/cuda/device.h" | ||||
| #include "mlx/backend/cuda/gemms/cublas_gemm.h" | ||||
| #include "mlx/backend/cuda/kernel_utils.cuh" | ||||
|  | ||||
| #include <cooperative_groups.h> | ||||
|  | ||||
| namespace mlx::core::cu { | ||||
|  | ||||
| namespace cg = cooperative_groups; | ||||
|  | ||||
| __global__ void set_mm_device_pointers( | ||||
|     int8_t** pointers, | ||||
|     int8_t* a_start, | ||||
|     int8_t* b_start, | ||||
|     int8_t* out_start, | ||||
|     int item_size, | ||||
|     const __grid_constant__ Shape batch_shape, | ||||
|     const __grid_constant__ Strides a_batch_strides, | ||||
|     const __grid_constant__ Strides b_batch_strides, | ||||
|     int64_t batch_stride, | ||||
|     int batch_ndim, | ||||
|     int batch_count) { | ||||
|   auto index = cg::this_grid().thread_rank(); | ||||
|   if (index >= batch_count) { | ||||
|     return; | ||||
|   } | ||||
|   auto [a_offset, b_offset] = elem_to_loc( | ||||
|       index, | ||||
|       batch_shape.data(), | ||||
|       a_batch_strides.data(), | ||||
|       b_batch_strides.data(), | ||||
|       batch_ndim); | ||||
|   pointers[index] = a_start + item_size * a_offset; | ||||
|   pointers[index + batch_count] = b_start + item_size * b_offset; | ||||
|   pointers[index + 2 * batch_count] = | ||||
|       out_start + item_size * index * batch_stride; | ||||
| } | ||||
|  | ||||
| __global__ void set_addmm_device_pointers( | ||||
|     int8_t** pointers, | ||||
|     int8_t* a_start, | ||||
|     int8_t* b_start, | ||||
|     int8_t* c_start, | ||||
|     int8_t* out_start, | ||||
|     int item_size, | ||||
|     const __grid_constant__ Shape batch_shape, | ||||
|     const __grid_constant__ Strides a_batch_strides, | ||||
|     const __grid_constant__ Strides b_batch_strides, | ||||
|     const __grid_constant__ Strides c_batch_strides, | ||||
|     int64_t batch_stride, | ||||
|     int batch_ndim, | ||||
|     int batch_count) { | ||||
|   auto index = cg::this_grid().thread_rank(); | ||||
|   if (index >= batch_count) { | ||||
|     return; | ||||
|   } | ||||
|   auto [a_offset, b_offset, c_offset] = elem_to_loc( | ||||
|       index, | ||||
|       batch_shape.data(), | ||||
|       a_batch_strides.data(), | ||||
|       b_batch_strides.data(), | ||||
|       c_batch_strides.data(), | ||||
|       batch_ndim); | ||||
|   pointers[index] = a_start + item_size * a_offset; | ||||
|   pointers[index + batch_count] = b_start + item_size * b_offset; | ||||
|   pointers[index + 2 * batch_count] = c_start + item_size * c_offset; | ||||
|   pointers[index + 3 * batch_count] = | ||||
|       out_start + item_size * index * batch_stride; | ||||
| } | ||||
|  | ||||
| void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) { | ||||
|   auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY; | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( | ||||
|       desc, | ||||
|       CUBLASLT_MATRIX_LAYOUT_BATCH_MODE, | ||||
|       &batch_mode, | ||||
|       sizeof(batch_mode))); | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( | ||||
|       desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t))); | ||||
| } | ||||
|  | ||||
| void Matmul::run_batched( | ||||
|     cu::CommandEncoder& encoder, | ||||
|     array& out, | ||||
|     const array& a, | ||||
|     const array& b, | ||||
|     const mlx::core::Shape& batch_shape, | ||||
|     const mlx::core::Strides& a_batch_strides, | ||||
|     const mlx::core::Strides& b_batch_strides) { | ||||
|   auto batch_count = out.size() / (M_ * N_); | ||||
|   set_pointer_mode(a_desc_, batch_count); | ||||
|   set_pointer_mode(b_desc_, batch_count); | ||||
|   set_pointer_mode(out_desc_, batch_count); | ||||
|  | ||||
|   // Launch kernel to set device offsets | ||||
|   auto pointers = array( | ||||
|       allocator::malloc(batch_count * sizeof(uint64_t) * 3), | ||||
|       {static_cast<int>(batch_count * 3)}, | ||||
|       uint64); | ||||
|  | ||||
|   encoder.add_temporary(pointers); | ||||
|   int block_size = 512; | ||||
|   encoder.set_output_array(pointers); | ||||
|  | ||||
|   encoder.add_kernel_node( | ||||
|       cu::set_mm_device_pointers, | ||||
|       cuda::ceil_div(pointers.size(), block_size), | ||||
|       block_size, | ||||
|       pointers.data<int8_t*>(), | ||||
|       a.data<int8_t>(), | ||||
|       b.data<int8_t>(), | ||||
|       out.data<int8_t>(), | ||||
|       static_cast<int>(out.dtype().size()), | ||||
|       const_param(batch_shape), | ||||
|       const_param(a_batch_strides), | ||||
|       const_param(b_batch_strides), | ||||
|       static_cast<int64_t>(M_) * N_, | ||||
|       static_cast<int>(batch_shape.size()), | ||||
|       batch_count); | ||||
|  | ||||
|   // Run matmul | ||||
|   encoder.set_input_array(pointers); | ||||
|   encoder.set_input_array(a); | ||||
|   encoder.set_input_array(b); | ||||
|   encoder.set_output_array(out); | ||||
|  | ||||
|   auto a_pointers = pointers.data<int8_t*>(); | ||||
|   auto b_pointers = a_pointers + batch_count; | ||||
|   auto out_pointers = b_pointers + batch_count; | ||||
|   run_impl( | ||||
|       encoder, | ||||
|       reinterpret_cast<void*>(out_pointers), | ||||
|       reinterpret_cast<void*>(a_pointers), | ||||
|       reinterpret_cast<void*>(b_pointers), | ||||
|       nullptr); | ||||
| } | ||||
|  | ||||
| void Matmul::run_batched( | ||||
|     cu::CommandEncoder& encoder, | ||||
|     array& out, | ||||
|     const array& a, | ||||
|     const array& b, | ||||
|     const array& c, | ||||
|     const mlx::core::Shape& batch_shape, | ||||
|     const mlx::core::Strides& a_batch_strides, | ||||
|     const mlx::core::Strides& b_batch_strides, | ||||
|     const mlx::core::Strides& c_batch_strides, | ||||
|     float alpha, | ||||
|     float beta) { | ||||
|   auto batch_count = out.size() / (M_ * N_); | ||||
|   set_pointer_mode(a_desc_, batch_count); | ||||
|   set_pointer_mode(b_desc_, batch_count); | ||||
|   set_pointer_mode(c_desc_, batch_count); | ||||
|   set_pointer_mode(out_desc_, batch_count); | ||||
|  | ||||
|   // Launch kernel to set device offsets | ||||
|   auto pointers = array( | ||||
|       allocator::malloc(batch_count * sizeof(uint64_t) * 4), | ||||
|       {static_cast<int>(batch_count * 4)}, | ||||
|       uint64); | ||||
|  | ||||
|   encoder.add_temporary(pointers); | ||||
|   int block_size = 512; | ||||
|   encoder.set_output_array(pointers); | ||||
|   encoder.add_kernel_node( | ||||
|       cu::set_addmm_device_pointers, | ||||
|       cuda::ceil_div(pointers.size(), block_size), | ||||
|       block_size, | ||||
|       pointers.data<int8_t*>(), | ||||
|       a.data<int8_t>(), | ||||
|       b.data<int8_t>(), | ||||
|       c.data<int8_t>(), | ||||
|       out.data<int8_t>(), | ||||
|       static_cast<int>(out.dtype().size()), | ||||
|       const_param(batch_shape), | ||||
|       const_param(a_batch_strides), | ||||
|       const_param(b_batch_strides), | ||||
|       const_param(c_batch_strides), | ||||
|       static_cast<int64_t>(M_) * N_, | ||||
|       static_cast<int>(batch_shape.size()), | ||||
|       batch_count); | ||||
|  | ||||
|   // Run matmul | ||||
|   encoder.set_input_array(pointers); | ||||
|   encoder.set_input_array(a); | ||||
|   encoder.set_input_array(b); | ||||
|   encoder.set_input_array(c); | ||||
|   encoder.set_output_array(out); | ||||
|  | ||||
|   auto a_pointers = pointers.data<int8_t*>(); | ||||
|   auto b_pointers = a_pointers + batch_count; | ||||
|   auto c_pointers = b_pointers + batch_count; | ||||
|   auto out_pointers = c_pointers + batch_count; | ||||
|   run_impl( | ||||
|       encoder, | ||||
|       reinterpret_cast<void*>(out_pointers), | ||||
|       reinterpret_cast<void*>(a_pointers), | ||||
|       reinterpret_cast<void*>(b_pointers), | ||||
|       reinterpret_cast<void*>(c_pointers), | ||||
|       alpha, | ||||
|       beta); | ||||
| } | ||||
|  | ||||
| } // namespace mlx::core::cu | ||||
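On CUDA 12.9+ the batch loop moves onto the device: a single kernel launch writes every per-matrix pointer into one temporary buffer, and cublasLt consumes that buffer in pointer-array batch mode, so the whole batch becomes one cublasLtMatmul call. The layout produced by set_mm_device_pointers, sketched as comments (set_addmm_device_pointers appends a fourth region for c):

    // `pointers` holds 3 * batch_count device pointers, in contiguous regions:
    //
    //   pointers[i]                   = a_start   + item_size * a_offset(i)
    //   pointers[i +     batch_count] = b_start   + item_size * b_offset(i)
    //   pointers[i + 2 * batch_count] = out_start + item_size * i * (M * N)
    //
    // a_offset/b_offset come from elem_to_loc(i, batch_shape, *_batch_strides);
    // out is written contiguously, hence the fixed M * N stride.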
							
								
								
									
mlx/backend/cuda/gemms/cublas_gemm.cpp (new file, 282 lines)
							| @@ -0,0 +1,282 @@ | ||||
| // Copyright © 2025 Apple Inc. | ||||
|  | ||||
| #include "mlx/backend/cuda/gemms/cublas_gemm.h" | ||||
| #include "mlx/backend/cuda/device.h" | ||||
| #include "mlx/dtype_utils.h" | ||||
| #include "mlx/utils.h" | ||||
|  | ||||
| #include <fmt/format.h> | ||||
|  | ||||
| namespace mlx::core::cu { | ||||
|  | ||||
| struct CublasPreference { | ||||
|   CublasPreference(Device& device) { | ||||
|     // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB | ||||
|     // for Hopper+: | ||||
|     // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace | ||||
|     uint64_t MiB = 1024 * 1024; | ||||
|     uint64_t workspace_size = | ||||
|         device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB; | ||||
|  | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_)); | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute( | ||||
|         pref_, | ||||
|         CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, | ||||
|         &workspace_size, | ||||
|         sizeof(uint64_t))); | ||||
|   } | ||||
|  | ||||
|   ~CublasPreference() { | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_)); | ||||
|   } | ||||
|  | ||||
|   cublasLtMatmulPreference_t pref_{nullptr}; | ||||
| }; | ||||
|  | ||||
| cublasLtMatmulPreference_t cublas_preference(Device& device) { | ||||
|   static CublasPreference pref(device); | ||||
|   return pref.pref_; | ||||
| } | ||||
|  | ||||
| cublasComputeType_t dtype_to_compute_type(Dtype dtype) { | ||||
|   switch (dtype) { | ||||
|     case float16: | ||||
|       return CUBLAS_COMPUTE_32F; | ||||
|     case bfloat16: | ||||
|       return CUBLAS_COMPUTE_32F; | ||||
|     case float32: | ||||
|       return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32 | ||||
|                                            : CUBLAS_COMPUTE_32F; | ||||
|     case float64: | ||||
|     case complex64: | ||||
|       return CUBLAS_COMPUTE_64F; | ||||
|     default: | ||||
|       throw std::runtime_error(fmt::format( | ||||
|           "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype))); | ||||
|   } | ||||
| } | ||||
|  | ||||
| cudaDataType_t dtype_to_cublas_type(Dtype dtype) { | ||||
|   switch (dtype) { | ||||
|     case float16: | ||||
|       return CUDA_R_16F; | ||||
|     case bfloat16: | ||||
|       return CUDA_R_16BF; | ||||
|     case float32: | ||||
|       return CUDA_R_32F; | ||||
|     case float64: | ||||
|       return CUDA_R_64F; | ||||
|     case complex64: | ||||
|       return CUDA_C_32F; | ||||
|     default: | ||||
|       throw std::runtime_error(fmt::format( | ||||
|           "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype))); | ||||
|   } | ||||
| } | ||||
|  | ||||
| cublasLtMatrixLayout_t create_matrix_layout( | ||||
|     cudaDataType_t type, | ||||
|     uint64_t rows, | ||||
|     uint64_t cols, | ||||
|     bool transposed, | ||||
|     int64_t ld, | ||||
|     int32_t batch_count, | ||||
|     int64_t batch_stride) { | ||||
|   cublasLtMatrixLayout_t desc; | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld)); | ||||
|   cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW; | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( | ||||
|       desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t))); | ||||
|   if (batch_count > 1) { | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( | ||||
|         desc, | ||||
|         CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, | ||||
|         &batch_count, | ||||
|         sizeof(int32_t))); | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( | ||||
|         desc, | ||||
|         CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, | ||||
|         &batch_stride, | ||||
|         sizeof(int64_t))); | ||||
|   } | ||||
|   return desc; | ||||
| } | ||||
|  | ||||
| Matmul::Matmul( | ||||
|     Device& device, | ||||
|     Dtype dtype, | ||||
|     bool a_transposed, | ||||
|     uint64_t a_rows, | ||||
|     uint64_t a_cols, | ||||
|     int64_t lda, | ||||
|     bool b_transposed, | ||||
|     uint64_t b_rows, | ||||
|     uint64_t b_cols, | ||||
|     int64_t ldb, | ||||
|     int32_t batch_count, | ||||
|     int64_t a_batch_stride, | ||||
|     int64_t b_batch_stride) | ||||
|     : handle_(device.lt_handle()), | ||||
|       pref_(cublas_preference(device)), | ||||
|       M_(a_rows), | ||||
|       N_(b_cols) { | ||||
|   heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; | ||||
|  | ||||
|   auto scale_type = dtype_to_cublas_type(dtype); | ||||
|   if (dtype == bfloat16 || dtype == float16) { | ||||
|     scale_type = CUDA_R_32F; | ||||
|   } | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate( | ||||
|       &matmul_desc_, dtype_to_compute_type(dtype), scale_type)); | ||||
|   int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( | ||||
|       matmul_desc_, | ||||
|       CUBLASLT_MATMUL_DESC_POINTER_MODE, | ||||
|       &pointer_mode, | ||||
|       sizeof(int32_t))); | ||||
|   cublasOperation_t op = CUBLAS_OP_N; | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( | ||||
|       matmul_desc_, | ||||
|       CUBLASLT_MATMUL_DESC_TRANSA, | ||||
|       &op, | ||||
|       sizeof(cublasOperation_t))); | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( | ||||
|       matmul_desc_, | ||||
|       CUBLASLT_MATMUL_DESC_TRANSB, | ||||
|       &op, | ||||
|       sizeof(cublasOperation_t))); | ||||
|  | ||||
|   auto type = dtype_to_cublas_type(dtype); | ||||
|   a_desc_ = create_matrix_layout( | ||||
|       type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride); | ||||
|   b_desc_ = create_matrix_layout( | ||||
|       type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride); | ||||
|   out_desc_ = create_matrix_layout( | ||||
|       type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols); | ||||
| } | ||||
|  | ||||
| Matmul::Matmul( | ||||
|     Device& device, | ||||
|     Dtype dtype, | ||||
|     bool a_transposed, | ||||
|     uint64_t a_rows, | ||||
|     uint64_t a_cols, | ||||
|     int64_t lda, | ||||
|     bool b_transposed, | ||||
|     uint64_t b_rows, | ||||
|     uint64_t b_cols, | ||||
|     int64_t ldb, | ||||
|     int64_t ldc, | ||||
|     int32_t batch_count, | ||||
|     int64_t a_batch_stride, | ||||
|     int64_t b_batch_stride, | ||||
|     int64_t c_batch_stride) | ||||
|     : Matmul( | ||||
|           device, | ||||
|           dtype, | ||||
|           a_transposed, | ||||
|           a_rows, | ||||
|           a_cols, | ||||
|           lda, | ||||
|           b_transposed, | ||||
|           b_rows, | ||||
|           b_cols, | ||||
|           ldb, | ||||
|           batch_count, | ||||
|           a_batch_stride, | ||||
|           b_batch_stride) { | ||||
|   auto type = dtype_to_cublas_type(dtype); | ||||
|   c_desc_ = create_matrix_layout( | ||||
|       type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride); | ||||
| } | ||||
|  | ||||
| Matmul::~Matmul() { | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_)); | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_)); | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_)); | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_)); | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_)); | ||||
| } | ||||
|  | ||||
| void Matmul::run_impl( | ||||
|     cu::CommandEncoder& encoder, | ||||
|     void* out, | ||||
|     const void* a, | ||||
|     const void* b, | ||||
|     const void* c, | ||||
|     float alpha /* = 1 */, | ||||
|     float beta /* = 0 */) { | ||||
|   if (heuristic_.state != CUBLAS_STATUS_SUCCESS) { | ||||
|     int ret = 0; | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic( | ||||
|         handle_, | ||||
|         matmul_desc_, | ||||
|         a_desc_, | ||||
|         b_desc_, | ||||
|         out_desc_, // TODO: should this be c_desc_ if it's set? | ||||
|         out_desc_, | ||||
|         pref_, | ||||
|         1, | ||||
|         &heuristic_, | ||||
|         &ret)); | ||||
|     if (ret == 0) { | ||||
|       throw std::runtime_error("Can not find algorithm for matmul."); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   void* workspace_ptr = nullptr; | ||||
|   if (heuristic_.workspaceSize > 0) { | ||||
|     array workspace( | ||||
|         allocator::malloc(heuristic_.workspaceSize), | ||||
|         {static_cast<int>(heuristic_.workspaceSize)}, | ||||
|         int8); | ||||
|     encoder.add_temporary(workspace); | ||||
|     workspace_ptr = workspace.data<void>(); | ||||
|   } | ||||
|  | ||||
|   auto capture = encoder.capture_context(); | ||||
|   CHECK_CUBLAS_ERROR(cublasLtMatmul( | ||||
|       handle_, | ||||
|       matmul_desc_, | ||||
|       &alpha, | ||||
|       a, | ||||
|       a_desc_, | ||||
|       b, | ||||
|       b_desc_, | ||||
|       &beta, | ||||
|       c ? c : out, | ||||
|       c ? c_desc_ : out_desc_, | ||||
|       out, | ||||
|       out_desc_, | ||||
|       &heuristic_.algo, | ||||
|       workspace_ptr, | ||||
|       heuristic_.workspaceSize, | ||||
|       encoder.stream())); | ||||
| } | ||||
|  | ||||
| void Matmul::run( | ||||
|     cu::CommandEncoder& encoder, | ||||
|     array& out, | ||||
|     const array& a, | ||||
|     const array& b, | ||||
|     const std::optional<array>& c /* = std::nullopt */, | ||||
|     float alpha /* = 1 */, | ||||
|     float beta /* = 0 */) { | ||||
|   encoder.set_input_array(a); | ||||
|   encoder.set_input_array(b); | ||||
|   if (c) { | ||||
|     encoder.set_input_array(*c); | ||||
|   } | ||||
|   encoder.set_output_array(out); | ||||
|  | ||||
|   run_impl( | ||||
|       encoder, | ||||
|       out.data<void>(), | ||||
|       a.data<void>(), | ||||
|       b.data<void>(), | ||||
|       c ? c->data<void>() : nullptr, | ||||
|       alpha, | ||||
|       beta); | ||||
| } | ||||
|  | ||||
| } // namespace mlx::core::cu | ||||
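For reference, the mapping implemented by dtype_to_compute_type and dtype_to_cublas_type above (the scale type is what cublasLtMatmul expects for alpha and beta):

    // dtype      cuBLAS type  compute type                    scale type
    // float16    CUDA_R_16F   CUBLAS_COMPUTE_32F              CUDA_R_32F
    // bfloat16   CUDA_R_16BF  CUBLAS_COMPUTE_32F              CUDA_R_32F
    // float32    CUDA_R_32F   CUBLAS_COMPUTE_32F (or          CUDA_R_32F
    //                         CUBLAS_COMPUTE_32F_FAST_TF32
    //                         when env::enable_tf32() is set)
    // float64    CUDA_R_64F   CUBLAS_COMPUTE_64F              CUDA_R_64F
    // complex64  CUDA_C_32F   CUBLAS_COMPUTE_64F              CUDA_C_32F

Note also that run_impl queries the cublasLt heuristic only once per Matmul object (heuristic_.state is checked before the query), so repeated batched calls reuse the selected algorithm, while the workspace is allocated fresh on each call as an encoder temporary.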
							
								
								
									
mlx/backend/cuda/gemms/cublas_gemm.h (new file, 100 lines)
							| @@ -0,0 +1,100 @@ | ||||
| // Copyright © 2025 Apple Inc. | ||||
| #pragma once | ||||
|  | ||||
| #include "mlx/array.h" | ||||
| #include "mlx/backend/cuda/device.h" | ||||
|  | ||||
| #include <cublasLt.h> | ||||
| #include <optional> | ||||
|  | ||||
| namespace mlx::core::cu { | ||||
| class Matmul { | ||||
|  public: | ||||
|   Matmul( | ||||
|       Device& device, | ||||
|       Dtype dtype, | ||||
|       bool a_transposed, | ||||
|       uint64_t a_rows, | ||||
|       uint64_t a_cols, | ||||
|       int64_t lda, | ||||
|       bool b_transposed, | ||||
|       uint64_t b_rows, | ||||
|       uint64_t b_cols, | ||||
|       int64_t ldb, | ||||
|       int32_t batch_count, | ||||
|       int64_t a_batch_stride, | ||||
|       int64_t b_batch_stride); | ||||
|  | ||||
|   Matmul( | ||||
|       Device& device, | ||||
|       Dtype dtype, | ||||
|       bool a_transposed, | ||||
|       uint64_t a_rows, | ||||
|       uint64_t a_cols, | ||||
|       int64_t lda, | ||||
|       bool b_transposed, | ||||
|       uint64_t b_rows, | ||||
|       uint64_t b_cols, | ||||
|       int64_t ldb, | ||||
|       int64_t ldc, | ||||
|       int32_t batch_count, | ||||
|       int64_t a_batch_stride, | ||||
|       int64_t b_batch_stride, | ||||
|       int64_t c_batch_stride); | ||||
|  | ||||
|   ~Matmul(); | ||||
|  | ||||
|   void run( | ||||
|       cu::CommandEncoder& encoder, | ||||
|       array& out, | ||||
|       const array& a, | ||||
|       const array& b, | ||||
|       const std::optional<array>& c = std::nullopt, | ||||
|       float alpha = 1, | ||||
|       float beta = 0); | ||||
|  | ||||
|   void run_batched( | ||||
|       cu::CommandEncoder& encoder, | ||||
|       array& out, | ||||
|       const array& a, | ||||
|       const array& b, | ||||
|       const mlx::core::Shape& batch_shape, | ||||
|       const mlx::core::Strides& a_batch_strides, | ||||
|       const mlx::core::Strides& b_batch_strides); | ||||
|  | ||||
|   void run_batched( | ||||
|       cu::CommandEncoder& encoder, | ||||
|       array& out, | ||||
|       const array& a, | ||||
|       const array& b, | ||||
|       const array& c, | ||||
|       const mlx::core::Shape& batch_shape, | ||||
|       const mlx::core::Strides& a_batch_strides, | ||||
|       const mlx::core::Strides& b_batch_strides, | ||||
|       const mlx::core::Strides& c_batch_strides, | ||||
|       float alpha, | ||||
|       float beta); | ||||
|  | ||||
|  private: | ||||
|   void run_impl( | ||||
|       cu::CommandEncoder& encoder, | ||||
|       void* out, | ||||
|       const void* a, | ||||
|       const void* b, | ||||
|       const void* c, | ||||
|       float alpha = 1, | ||||
|       float beta = 0); | ||||
|  | ||||
|   uint64_t M_; | ||||
|   uint64_t N_; | ||||
|   cublasLtMatmulPreference_t pref_{nullptr}; | ||||
|   cublasLtHandle_t handle_{nullptr}; | ||||
|   cublasLtMatmulDesc_t matmul_desc_{nullptr}; | ||||
|   cublasLtMatrixLayout_t a_desc_{nullptr}; | ||||
|   cublasLtMatrixLayout_t b_desc_{nullptr}; | ||||
|   cublasLtMatrixLayout_t c_desc_{nullptr}; | ||||
|   cublasLtMatrixLayout_t out_desc_{nullptr}; | ||||
|   cublasLtMatmulHeuristicResult_t heuristic_; | ||||
| }; | ||||
|  | ||||
| } // namespace mlx::core::cu | ||||
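For orientation, a minimal usage sketch of the class above: a single non-batched float32 GEMM, assuming contiguous row-major inputs a: [M, K] and b: [K, N], and assuming device, encoder, and the arrays come from the surrounding backend code (this mirrors what Matmul::eval_gpu does in matmul.cpp below):

    // Sketch only, under the assumptions stated above.
    cu::Matmul matmul(
        device, float32,
        /* a_transposed */ false, /* a_rows */ M, /* a_cols */ K, /* lda */ K,
        /* b_transposed */ false, /* b_rows */ K, /* b_cols */ N, /* ldb */ N,
        /* batch_count */ 1,
        /* a_batch_stride */ 0, /* b_batch_stride */ 0);
    matmul.run(encoder, out, a, b); // out = a @ b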
mlx/backend/cuda/gemms/gemv.cu (renamed from mlx/backend/cuda/gemv.cu)
| @@ -1,6 +1,6 @@ | ||||
| // Copyright © 2025 Apple Inc. | ||||
|  | ||||
| #include "mlx/backend/cuda/gemv.h" | ||||
| #include "mlx/backend/cuda/gemms/gemv.h" | ||||
| #include "mlx/backend/cuda/kernel_utils.cuh" | ||||
| #include "mlx/dtype_utils.h" | ||||
|  | ||||
mlx/backend/cuda/matmul.cpp
| @@ -2,279 +2,15 @@ | ||||
|  | ||||
| #include "mlx/backend/common/matmul.h" | ||||
| #include "mlx/backend/cuda/device.h" | ||||
| #include "mlx/backend/cuda/gemv.h" | ||||
| #include "mlx/backend/cuda/gemms/cublas_gemm.h" | ||||
| #include "mlx/backend/cuda/gemms/gemv.h" | ||||
| #include "mlx/backend/gpu/copy.h" | ||||
| #include "mlx/dtype_utils.h" | ||||
| #include "mlx/primitives.h" | ||||
| #include "mlx/utils.h" | ||||
|  | ||||
| #include <fmt/format.h> | ||||
| #include <nvtx3/nvtx3.hpp> | ||||
|  | ||||
| #include <numeric> | ||||
|  | ||||
| namespace mlx::core { | ||||
|  | ||||
| namespace cu { | ||||
|  | ||||
| struct CublasPreference { | ||||
|   CublasPreference(Device& device) { | ||||
|     // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB | ||||
|     // for Hopper+: | ||||
|     // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace | ||||
|     uint64_t MiB = 1024 * 1024; | ||||
|     uint64_t workspace_size = | ||||
|         device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB; | ||||
|  | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_)); | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute( | ||||
|         pref_, | ||||
|         CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, | ||||
|         &workspace_size, | ||||
|         sizeof(uint64_t))); | ||||
|   } | ||||
|  | ||||
|   ~CublasPreference() { | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_)); | ||||
|   } | ||||
|  | ||||
|   cublasLtMatmulPreference_t pref_{nullptr}; | ||||
| }; | ||||
|  | ||||
| cublasLtMatmulPreference_t cublas_preference(Device& device) { | ||||
|   static CublasPreference pref(device); | ||||
|   return pref.pref_; | ||||
| } | ||||
|  | ||||
| class MatMul { | ||||
|  public: | ||||
|   MatMul( | ||||
|       Device& device, | ||||
|       Dtype dtype, | ||||
|       bool a_transposed, | ||||
|       uint64_t a_rows, | ||||
|       uint64_t a_cols, | ||||
|       int64_t lda, | ||||
|       bool b_transposed, | ||||
|       uint64_t b_rows, | ||||
|       uint64_t b_cols, | ||||
|       int64_t ldb, | ||||
|       int32_t batch_count, | ||||
|       int64_t a_batch_stride, | ||||
|       int64_t b_batch_stride) | ||||
|       : handle_(device.lt_handle()), pref_(cublas_preference(device)) { | ||||
|     heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; | ||||
|  | ||||
|     auto scale_type = dtype_to_cuda_type(dtype); | ||||
|     if (dtype == bfloat16 || dtype == float16) { | ||||
|       scale_type = CUDA_R_32F; | ||||
|     } | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate( | ||||
|         &matmul_desc_, dtype_to_compute_type(dtype), scale_type)); | ||||
|     int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( | ||||
|         matmul_desc_, | ||||
|         CUBLASLT_MATMUL_DESC_POINTER_MODE, | ||||
|         &pointer_mode, | ||||
|         sizeof(int32_t))); | ||||
|     cublasOperation_t op = CUBLAS_OP_N; | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( | ||||
|         matmul_desc_, | ||||
|         CUBLASLT_MATMUL_DESC_TRANSA, | ||||
|         &op, | ||||
|         sizeof(cublasOperation_t))); | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( | ||||
|         matmul_desc_, | ||||
|         CUBLASLT_MATMUL_DESC_TRANSB, | ||||
|         &op, | ||||
|         sizeof(cublasOperation_t))); | ||||
|  | ||||
|     auto type = dtype_to_cuda_type(dtype); | ||||
|     a_desc_ = create_matrix_layout( | ||||
|         type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride); | ||||
|     b_desc_ = create_matrix_layout( | ||||
|         type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride); | ||||
|     out_desc_ = create_matrix_layout( | ||||
|         type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols); | ||||
|   } | ||||
|  | ||||
|   MatMul( | ||||
|       Device& device, | ||||
|       Dtype dtype, | ||||
|       bool a_transposed, | ||||
|       uint64_t a_rows, | ||||
|       uint64_t a_cols, | ||||
|       int64_t lda, | ||||
|       bool b_transposed, | ||||
|       uint64_t b_rows, | ||||
|       uint64_t b_cols, | ||||
|       int64_t ldb, | ||||
|       int64_t ldc, | ||||
|       int32_t batch_count, | ||||
|       int64_t a_batch_stride, | ||||
|       int64_t b_batch_stride, | ||||
|       int64_t c_batch_stride) | ||||
|       : MatMul( | ||||
|             device, | ||||
|             dtype, | ||||
|             a_transposed, | ||||
|             a_rows, | ||||
|             a_cols, | ||||
|             lda, | ||||
|             b_transposed, | ||||
|             b_rows, | ||||
|             b_cols, | ||||
|             ldb, | ||||
|             batch_count, | ||||
|             a_batch_stride, | ||||
|             b_batch_stride) { | ||||
|     auto type = dtype_to_cuda_type(dtype); | ||||
|     c_desc_ = create_matrix_layout( | ||||
|         type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride); | ||||
|   } | ||||
|  | ||||
|   ~MatMul() { | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_)); | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_)); | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_)); | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_)); | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_)); | ||||
|   } | ||||
|  | ||||
|   void run( | ||||
|       cu::CommandEncoder& encoder, | ||||
|       void* out, | ||||
|       void* a, | ||||
|       void* b, | ||||
|       void* c = nullptr, | ||||
|       float alpha = 1, | ||||
|       float beta = 0) { | ||||
|     if (heuristic_.state != CUBLAS_STATUS_SUCCESS) { | ||||
|       int ret = 0; | ||||
|       CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic( | ||||
|           handle_, | ||||
|           matmul_desc_, | ||||
|           a_desc_, | ||||
|           b_desc_, | ||||
|           out_desc_, | ||||
|           out_desc_, | ||||
|           pref_, | ||||
|           1, | ||||
|           &heuristic_, | ||||
|           &ret)); | ||||
|       if (ret == 0) { | ||||
|         throw std::runtime_error("Can not find algorithm for matmul."); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     void* workspace_ptr = nullptr; | ||||
|     if (heuristic_.workspaceSize > 0) { | ||||
|       array workspace( | ||||
|           allocator::malloc(heuristic_.workspaceSize), | ||||
|           {static_cast<int>(heuristic_.workspaceSize)}, | ||||
|           int8); | ||||
|       encoder.add_temporary(workspace); | ||||
|       workspace_ptr = workspace.data<void>(); | ||||
|     } | ||||
|  | ||||
|     auto capture = encoder.capture_context(); | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatmul( | ||||
|         handle_, | ||||
|         matmul_desc_, | ||||
|         &alpha, | ||||
|         a, | ||||
|         a_desc_, | ||||
|         b, | ||||
|         b_desc_, | ||||
|         &beta, | ||||
|         c ? c : out, | ||||
|         c ? c_desc_ : out_desc_, | ||||
|         out, | ||||
|         out_desc_, | ||||
|         &heuristic_.algo, | ||||
|         workspace_ptr, | ||||
|         heuristic_.workspaceSize, | ||||
|         encoder.stream())); | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   cublasComputeType_t dtype_to_compute_type(Dtype dtype) { | ||||
|     switch (dtype) { | ||||
|       case float16: | ||||
|         return CUBLAS_COMPUTE_32F; | ||||
|       case bfloat16: | ||||
|         return CUBLAS_COMPUTE_32F; | ||||
|       case float32: | ||||
|         return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32 | ||||
|                                              : CUBLAS_COMPUTE_32F; | ||||
|       case float64: | ||||
|       case complex64: | ||||
|         return CUBLAS_COMPUTE_64F; | ||||
|       default: | ||||
|         throw std::runtime_error(fmt::format( | ||||
|             "Unsupported dtype in MatMul: {}.", dtype_to_string(dtype))); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   cudaDataType_t dtype_to_cuda_type(Dtype dtype) { | ||||
|     switch (dtype) { | ||||
|       case float16: | ||||
|         return CUDA_R_16F; | ||||
|       case bfloat16: | ||||
|         return CUDA_R_16BF; | ||||
|       case float32: | ||||
|         return CUDA_R_32F; | ||||
|       case float64: | ||||
|         return CUDA_R_64F; | ||||
|       case complex64: | ||||
|         return CUDA_C_32F; | ||||
|       default: | ||||
|         throw std::runtime_error(fmt::format( | ||||
|             "Unsupported dtype in MatMul: {}.", dtype_to_string(dtype))); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   cublasLtMatrixLayout_t create_matrix_layout( | ||||
|       cudaDataType_t type, | ||||
|       uint64_t rows, | ||||
|       uint64_t cols, | ||||
|       bool transposed, | ||||
|       int64_t ld, | ||||
|       int32_t batch_count, | ||||
|       int64_t batch_stride) { | ||||
|     cublasLtMatrixLayout_t desc; | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld)); | ||||
|     cublasLtOrder_t order = | ||||
|         transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW; | ||||
|     CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( | ||||
|         desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t))); | ||||
|     if (batch_count > 1) { | ||||
|       CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( | ||||
|           desc, | ||||
|           CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, | ||||
|           &batch_count, | ||||
|           sizeof(int32_t))); | ||||
|       CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( | ||||
|           desc, | ||||
|           CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, | ||||
|           &batch_stride, | ||||
|           sizeof(int64_t))); | ||||
|     } | ||||
|     return desc; | ||||
|   } | ||||
|  | ||||
|   cublasLtMatmulPreference_t pref_{nullptr}; | ||||
|   cublasLtHandle_t handle_{nullptr}; | ||||
|   cublasLtMatmulDesc_t matmul_desc_{nullptr}; | ||||
|   cublasLtMatrixLayout_t a_desc_{nullptr}; | ||||
|   cublasLtMatrixLayout_t b_desc_{nullptr}; | ||||
|   cublasLtMatrixLayout_t c_desc_{nullptr}; | ||||
|   cublasLtMatrixLayout_t out_desc_{nullptr}; | ||||
|   cublasLtMatmulHeuristicResult_t heuristic_; | ||||
| }; | ||||
|  | ||||
| } // namespace cu | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| std::tuple<bool, int64_t, array> | ||||
| @@ -361,8 +97,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|  | ||||
|   ///////////////////////////////////////////////////////////////////////////// | ||||
|   // Invoke cublasLt | ||||
|  | ||||
|   cu::MatMul matmul( | ||||
|   cu::Matmul matmul( | ||||
|       cu::device(s.device), | ||||
|       a.dtype(), | ||||
|       a_transposed, | ||||
| @@ -377,27 +112,13 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|       a_batch_strides.back(), | ||||
|       b_batch_strides.back()); | ||||
|  | ||||
|   encoder.set_input_array(a); | ||||
|   encoder.set_input_array(b); | ||||
|   encoder.set_output_array(out); | ||||
|   auto nbatch = batch_count / batch_shape.back(); | ||||
|   if (nbatch == 1) { | ||||
|     matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>()); | ||||
|   if ((batch_count / batch_shape.back()) == 1) { | ||||
|     matmul.run(encoder, out, a, b); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); | ||||
|   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); | ||||
|   auto concurrent = encoder.concurrent_context(); | ||||
|   for (size_t i = 0; i < nbatch; ++i) { | ||||
|     matmul.run( | ||||
|         encoder, | ||||
|         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N, | ||||
|         a.data<int8_t>() + a.itemsize() * a_it.loc, | ||||
|         b.data<int8_t>() + b.itemsize() * b_it.loc); | ||||
|     a_it.step(); | ||||
|     b_it.step(); | ||||
|   } | ||||
|   matmul.run_batched( | ||||
|       encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides); | ||||
| } | ||||
|  | ||||
| void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
| @@ -465,7 +186,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   ///////////////////////////////////////////////////////////////////////////// | ||||
|   // Invoke cublasLt | ||||
|  | ||||
|   cu::MatMul matmul( | ||||
|   cu::Matmul matmul( | ||||
|       cu::device(s.device), | ||||
|       a.dtype(), | ||||
|       a_transposed, | ||||
| @@ -482,41 +203,22 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|       b_batch_strides.back(), | ||||
|       c_batch_strides.back()); | ||||
|  | ||||
|   encoder.set_input_array(a); | ||||
|   encoder.set_input_array(b); | ||||
|   encoder.set_input_array(c); | ||||
|   encoder.set_output_array(out); | ||||
|  | ||||
|   auto nbatch = batch_count / batch_shape.back(); | ||||
|   if (nbatch == 1) { | ||||
|     matmul.run( | ||||
|         encoder, | ||||
|         out.data<int8_t>(), | ||||
|         a.data<int8_t>(), | ||||
|         b.data<int8_t>(), | ||||
|         c.data<int8_t>(), | ||||
|         alpha_, | ||||
|         beta_); | ||||
|   if ((batch_count / batch_shape.back()) == 1) { | ||||
|     matmul.run(encoder, out, a, b, c, alpha_, beta_); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); | ||||
|   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); | ||||
|   ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); | ||||
|   auto concurrent = encoder.concurrent_context(); | ||||
|   for (size_t i = 0; i < nbatch; ++i) { | ||||
|     matmul.run( | ||||
|         encoder, | ||||
|         out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N, | ||||
|         a.data<int8_t>() + a.itemsize() * a_it.loc, | ||||
|         b.data<int8_t>() + b.itemsize() * b_it.loc, | ||||
|         c.data<int8_t>() + c.itemsize() * c_it.loc, | ||||
|         alpha_, | ||||
|         beta_); | ||||
|     a_it.step(); | ||||
|     b_it.step(); | ||||
|     c_it.step(); | ||||
|   } | ||||
|   matmul.run_batched( | ||||
|       encoder, | ||||
|       out, | ||||
|       a, | ||||
|       b, | ||||
|       c, | ||||
|       batch_shape, | ||||
|       a_batch_strides, | ||||
|       b_batch_strides, | ||||
|       c_batch_strides, | ||||
|       alpha_, | ||||
|       beta_); | ||||
| } | ||||
|  | ||||
| } // namespace mlx::core | ||||
|   | ||||
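A note on the dispatch above: the Matmul object is constructed with batch_shape.back() as its batch_count and the trailing batch strides, so the innermost batch dimension is always handled by cuBLAS strided batching inside a single call; run_batched (or, before CUDA 12.9, the host loop) only covers the remaining outer dimensions. A worked example of the arithmetic:

    // batch_shape = {2, 3}  =>  batch_count = 6, batch_shape.back() = 3
    // Each Matmul call covers 3 matrices via strided batching, and
    // batch_count / batch_shape.back() = 2 outer elements remain, so
    // run_batched iterates twice (and eval_gpu skips it entirely when
    // that quotient is 1).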