mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 07:03:10 +08:00
Simple gemm example
This commit is contained in:
parent
0c5fc63a36
commit
f70c62d69c
@ -26,6 +26,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/gemms/simple_gemm.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||||
|
47
mlx/backend/cuda/gemms/simple_gemm.cu
Normal file
47
mlx/backend/cuda/gemms/simple_gemm.cu
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/backend/cuda/steel/gemm.cuh"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
void simple_gemm(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
cu::CommandEncoder& enc) {
|
||||||
|
enc.set_input_array(a);
|
||||||
|
enc.set_input_array(b);
|
||||||
|
enc.set_output_array(out);
|
||||||
|
dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
|
||||||
|
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||||
|
constexpr int BM = 128;
|
||||||
|
constexpr int BN = 128;
|
||||||
|
constexpr int BK = 64;
|
||||||
|
|
||||||
|
auto kernel = ab_t_aligned<DataType, BM, BN, BK>;
|
||||||
|
cudaFuncSetAttribute(
|
||||||
|
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 98304);
|
||||||
|
cudaFuncSetAttribute(
|
||||||
|
kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||||
|
|
||||||
|
dim3 grid(N / BN, M / BM);
|
||||||
|
enc.add_kernel_node(
|
||||||
|
kernel,
|
||||||
|
grid,
|
||||||
|
4 * WARP_SIZE,
|
||||||
|
2 * sizeof(DataType) * (BM * BK + BN * BK),
|
||||||
|
a.data<DataType>(),
|
||||||
|
b.data<DataType>(),
|
||||||
|
out.data<DataType>(),
|
||||||
|
N,
|
||||||
|
K);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu
|
18
mlx/backend/cuda/gemms/simple_gemm.h
Normal file
18
mlx/backend/cuda/gemms/simple_gemm.h
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
|
||||||
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
void simple_gemm(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
cu::CommandEncoder& enc);
|
||||||
|
|
||||||
|
}
|
@ -4,6 +4,7 @@
|
|||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||||
#include "mlx/backend/cuda/gemms/gemv.h"
|
#include "mlx/backend/cuda/gemms/gemv.h"
|
||||||
|
#include "mlx/backend/cuda/gemms/simple_gemm.h"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
@ -11,6 +12,7 @@
|
|||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
std::tuple<bool, int64_t, array>
|
std::tuple<bool, int64_t, array>
|
||||||
@ -95,6 +97,13 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
|
||||||
|
b_transposed && batch_count == 1 &&
|
||||||
|
env::get_var("MLX_ENABLE_TEST_GEMM", 0) == 1) {
|
||||||
|
cu::simple_gemm(a, b, out, M, N, K, encoder);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Invoke cublasLt
|
// Invoke cublasLt
|
||||||
CublasGemm gemm(
|
CublasGemm gemm(
|
||||||
|
Loading…
Reference in New Issue
Block a user