mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Simple gemm example
This commit is contained in:
		| @@ -26,6 +26,7 @@ target_sources( | |||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/event.cu |           ${CMAKE_CURRENT_SOURCE_DIR}/event.cu | ||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp |           ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp | ||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu |           ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu | ||||||
|  |           ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simple_gemm.cu | ||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp |           ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp | ||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp |           ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp | ||||||
|           ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp |           ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp | ||||||
|   | |||||||
							
								
								
									
										47
									
								
								mlx/backend/cuda/gemms/simple_gemm.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								mlx/backend/cuda/gemms/simple_gemm.cu
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,47 @@ | |||||||
|  | // Copyright © 2025 Apple Inc. | ||||||
|  |  | ||||||
|  | #include "mlx/backend/cuda/device.h" | ||||||
|  | #include "mlx/backend/cuda/kernel_utils.cuh" | ||||||
|  | #include "mlx/backend/cuda/steel/gemm.cuh" | ||||||
|  | #include "mlx/dtype_utils.h" | ||||||
|  |  | ||||||
|  | namespace mlx::core::cu { | ||||||
|  |  | ||||||
|  | void simple_gemm( | ||||||
|  |     const array& a, | ||||||
|  |     const array& b, | ||||||
|  |     array& out, | ||||||
|  |     int M, | ||||||
|  |     int N, | ||||||
|  |     int K, | ||||||
|  |     cu::CommandEncoder& enc) { | ||||||
|  |   enc.set_input_array(a); | ||||||
|  |   enc.set_input_array(b); | ||||||
|  |   enc.set_output_array(out); | ||||||
|  |   dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) { | ||||||
|  |     using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; | ||||||
|  |     constexpr int BM = 128; | ||||||
|  |     constexpr int BN = 128; | ||||||
|  |     constexpr int BK = 64; | ||||||
|  |  | ||||||
|  |     auto kernel = ab_t_aligned<DataType, BM, BN, BK>; | ||||||
|  |     cudaFuncSetAttribute( | ||||||
|  |         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 98304); | ||||||
|  |     cudaFuncSetAttribute( | ||||||
|  |         kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100); | ||||||
|  |  | ||||||
|  |     dim3 grid(N / BN, M / BM); | ||||||
|  |     enc.add_kernel_node( | ||||||
|  |         kernel, | ||||||
|  |         grid, | ||||||
|  |         4 * WARP_SIZE, | ||||||
|  |         2 * sizeof(DataType) * (BM * BK + BN * BK), | ||||||
|  |         a.data<DataType>(), | ||||||
|  |         b.data<DataType>(), | ||||||
|  |         out.data<DataType>(), | ||||||
|  |         N, | ||||||
|  |         K); | ||||||
|  |   }); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace mlx::core::cu | ||||||
							
								
								
									
										18
									
								
								mlx/backend/cuda/gemms/simple_gemm.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								mlx/backend/cuda/gemms/simple_gemm.h
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,18 @@ | |||||||
|  | // Copyright © 2025 Apple Inc. | ||||||
|  |  | ||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include "mlx/backend/cuda/device.h" | ||||||
|  |  | ||||||
|  | namespace mlx::core::cu { | ||||||
|  |  | ||||||
|  | void simple_gemm( | ||||||
|  |     const array& a, | ||||||
|  |     const array& b, | ||||||
|  |     array& out, | ||||||
|  |     int M, | ||||||
|  |     int N, | ||||||
|  |     int K, | ||||||
|  |     cu::CommandEncoder& enc); | ||||||
|  |  | ||||||
|  | } | ||||||
| @@ -4,6 +4,7 @@ | |||||||
| #include "mlx/backend/cuda/device.h" | #include "mlx/backend/cuda/device.h" | ||||||
| #include "mlx/backend/cuda/gemms/cublas_gemm.h" | #include "mlx/backend/cuda/gemms/cublas_gemm.h" | ||||||
| #include "mlx/backend/cuda/gemms/gemv.h" | #include "mlx/backend/cuda/gemms/gemv.h" | ||||||
|  | #include "mlx/backend/cuda/gemms/simple_gemm.h" | ||||||
| #include "mlx/backend/gpu/copy.h" | #include "mlx/backend/gpu/copy.h" | ||||||
| #include "mlx/primitives.h" | #include "mlx/primitives.h" | ||||||
|  |  | ||||||
| @@ -11,6 +12,7 @@ | |||||||
| #include <numeric> | #include <numeric> | ||||||
|  |  | ||||||
| namespace mlx::core { | namespace mlx::core { | ||||||
|  |  | ||||||
| namespace { | namespace { | ||||||
|  |  | ||||||
| std::tuple<bool, int64_t, array> | std::tuple<bool, int64_t, array> | ||||||
| @@ -95,6 +97,13 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|     return; |     return; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed && | ||||||
|  |       b_transposed && batch_count == 1 && | ||||||
|  |       env::get_var("MLX_ENABLE_TEST_GEMM", 0) == 1) { | ||||||
|  |     cu::simple_gemm(a, b, out, M, N, K, encoder); | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  |  | ||||||
|   ///////////////////////////////////////////////////////////////////////////// |   ///////////////////////////////////////////////////////////////////////////// | ||||||
|   // Invoke cublasLt |   // Invoke cublasLt | ||||||
|   CublasGemm gemm( |   CublasGemm gemm( | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Angelos Katharopoulos
					Angelos Katharopoulos