Option to JIT steel gemm / conv (#1139)

This commit is contained in:
Awni Hannun
2024-05-23 18:07:34 -07:00
committed by GitHub
parent eab2685c67
commit 7e26fd8032
31 changed files with 2504 additions and 1540 deletions

View File

@@ -1,4 +1,4 @@
function(make_jit_source SRC_NAME)
function(make_jit_source SRC_FILE)
# This function takes a metal header file,
# runs the C preprocessesor on it, and makes
# the processed contents available as a string in a C++ function
@@ -9,6 +9,7 @@ function(make_jit_source SRC_NAME)
#
# Additional arguments to this function are treated as dependencies
# in the Cmake build system.
get_filename_component(SRC_NAME ${SRC_FILE} NAME)
add_custom_command(
OUTPUT jit/${SRC_NAME}.cpp
COMMAND /bin/bash
@@ -16,10 +17,10 @@ function(make_jit_source SRC_NAME)
${CMAKE_CURRENT_BINARY_DIR}/jit
${CMAKE_C_COMPILER}
${PROJECT_SOURCE_DIR}
${SRC_NAME}
${SRC_FILE}
"-D${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh
kernels/${SRC_NAME}.h
kernels/${SRC_FILE}.h
${ARGN}
)
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
@@ -73,6 +74,39 @@ if (MLX_METAL_JIT)
kernels/reduction/reduce_col.h
kernels/reduction/reduce_row.h
)
make_jit_source(
steel/gemm/gemm
kernels/steel/utils.h
kernels/steel/gemm/loader.h
kernels/steel/gemm/mma.h
kernels/steel/gemm/params.h
kernels/steel/gemm/transforms.h
)
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
make_jit_source(
steel/gemm/kernels/steel_gemm_masked
kernels/steel/defines.h
)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
make_jit_source(
steel/conv/conv
kernels/steel/utils.h
kernels/steel/defines.h
kernels/steel/gemm/mma.h
kernels/steel/gemm/transforms.h
kernels/steel/conv/params.h
kernels/steel/conv/loader.h
kernels/steel/conv/loaders/loader_channel_l.h
kernels/steel/conv/loaders/loader_channel_n.h
)
make_jit_source(
steel/conv/kernels/steel_conv
)
make_jit_source(
steel/conv/kernels/steel_conv_general
kernels/steel/defines.h
kernels/steel/conv/loaders/loader_general.h
)
else()
target_sources(
mlx