mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Simple gemm example
This commit is contained in:
		| @@ -26,6 +26,7 @@ target_sources( | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/event.cu | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simple_gemm.cu | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp | ||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp | ||||
|   | ||||
							
								
								
									
										47
									
								
								mlx/backend/cuda/gemms/simple_gemm.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								mlx/backend/cuda/gemms/simple_gemm.cu
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,47 @@ | ||||
| // Copyright © 2025 Apple Inc. | ||||
|  | ||||
| #include "mlx/backend/cuda/device.h" | ||||
| #include "mlx/backend/cuda/kernel_utils.cuh" | ||||
| #include "mlx/backend/cuda/steel/gemm.cuh" | ||||
| #include "mlx/dtype_utils.h" | ||||
|  | ||||
| namespace mlx::core::cu { | ||||
|  | ||||
| void simple_gemm( | ||||
|     const array& a, | ||||
|     const array& b, | ||||
|     array& out, | ||||
|     int M, | ||||
|     int N, | ||||
|     int K, | ||||
|     cu::CommandEncoder& enc) { | ||||
|   enc.set_input_array(a); | ||||
|   enc.set_input_array(b); | ||||
|   enc.set_output_array(out); | ||||
|   dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) { | ||||
|     using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; | ||||
|     constexpr int BM = 128; | ||||
|     constexpr int BN = 128; | ||||
|     constexpr int BK = 64; | ||||
|  | ||||
|     auto kernel = ab_t_aligned<DataType, BM, BN, BK>; | ||||
|     cudaFuncSetAttribute( | ||||
|         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 98304); | ||||
|     cudaFuncSetAttribute( | ||||
|         kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100); | ||||
|  | ||||
|     dim3 grid(N / BN, M / BM); | ||||
|     enc.add_kernel_node( | ||||
|         kernel, | ||||
|         grid, | ||||
|         4 * WARP_SIZE, | ||||
|         2 * sizeof(DataType) * (BM * BK + BN * BK), | ||||
|         a.data<DataType>(), | ||||
|         b.data<DataType>(), | ||||
|         out.data<DataType>(), | ||||
|         N, | ||||
|         K); | ||||
|   }); | ||||
| } | ||||
|  | ||||
| } // namespace mlx::core::cu | ||||
							
								
								
									
										18
									
								
								mlx/backend/cuda/gemms/simple_gemm.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								mlx/backend/cuda/gemms/simple_gemm.h
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,18 @@ | ||||
| // Copyright © 2025 Apple Inc. | ||||
|  | ||||
| #pragma once | ||||
|  | ||||
| #include "mlx/backend/cuda/device.h" | ||||
|  | ||||
| namespace mlx::core::cu { | ||||
|  | ||||
| void simple_gemm( | ||||
|     const array& a, | ||||
|     const array& b, | ||||
|     array& out, | ||||
|     int M, | ||||
|     int N, | ||||
|     int K, | ||||
|     cu::CommandEncoder& enc); | ||||
|  | ||||
| } | ||||
| @@ -4,6 +4,7 @@ | ||||
| #include "mlx/backend/cuda/device.h" | ||||
| #include "mlx/backend/cuda/gemms/cublas_gemm.h" | ||||
| #include "mlx/backend/cuda/gemms/gemv.h" | ||||
| #include "mlx/backend/cuda/gemms/simple_gemm.h" | ||||
| #include "mlx/backend/gpu/copy.h" | ||||
| #include "mlx/primitives.h" | ||||
|  | ||||
| @@ -11,6 +12,7 @@ | ||||
| #include <numeric> | ||||
|  | ||||
| namespace mlx::core { | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| std::tuple<bool, int64_t, array> | ||||
| @@ -95,6 +97,13 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed && | ||||
|       b_transposed && batch_count == 1 && | ||||
|       env::get_var("MLX_ENABLE_TEST_GEMM", 0) == 1) { | ||||
|     cu::simple_gemm(a, b, out, M, N, K, encoder); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   ///////////////////////////////////////////////////////////////////////////// | ||||
|   // Invoke cublasLt | ||||
|   CublasGemm gemm( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Angelos Katharopoulos
					Angelos Katharopoulos