From f2c85308c1ed542610179f9844da089ea6045d81 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 7 Apr 2025 09:31:29 -0700
Subject: [PATCH] add a half simd gemm fallback (#2046)

* add a half simd gemm fallback

* nit
---
 mlx/backend/cpu/CMakeLists.txt      |   4 +-
 mlx/backend/cpu/gemms/no_bf16.cpp   |  27 ------
 mlx/backend/cpu/gemms/no_fp16.cpp   |  27 ------
 mlx/backend/cpu/gemms/simd_bf16.cpp |  45 +++++++++
 mlx/backend/cpu/gemms/simd_fp16.cpp |  45 +++++++++
 mlx/backend/cpu/gemms/simd_gemm.h   | 139 ++++++++++++++++++++++++++++
 python/tests/test_blas.py           |   2 +-
 7 files changed, 232 insertions(+), 57 deletions(-)
 delete mode 100644 mlx/backend/cpu/gemms/no_bf16.cpp
 delete mode 100644 mlx/backend/cpu/gemms/no_fp16.cpp
 create mode 100644 mlx/backend/cpu/gemms/simd_bf16.cpp
 create mode 100644 mlx/backend/cpu/gemms/simd_fp16.cpp
 create mode 100644 mlx/backend/cpu/gemms/simd_gemm.h

diff --git a/mlx/backend/cpu/CMakeLists.txt b/mlx/backend/cpu/CMakeLists.txt
index 96159dfa8..152f33b17 100644
--- a/mlx/backend/cpu/CMakeLists.txt
+++ b/mlx/backend/cpu/CMakeLists.txt
@@ -74,8 +74,8 @@ target_sources(
 if(MLX_BUILD_ACCELERATE)
   target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
 else()
-  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
-                             ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
+                             ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
 endif()
 
 if(IOS)
diff --git a/mlx/backend/cpu/gemms/no_bf16.cpp b/mlx/backend/cpu/gemms/no_bf16.cpp
deleted file mode 100644
index 157c07f46..000000000
--- a/mlx/backend/cpu/gemms/no_bf16.cpp
+++ /dev/null
@@ -1,27 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#include "mlx/backend/cpu/gemm.h"
-
-namespace mlx::core {
-
-template <>
-void matmul(
-    const bfloat16_t*,
-    const bfloat16_t*,
-    bfloat16_t*,
-    bool,
-    bool,
-    size_t,
-    size_t,
-    size_t,
-    float,
-    float,
-    size_t,
-    const Shape&,
-    const Strides&,
-    const Shape&,
-    const Strides&) {
-  throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
-}
-
-} // namespace mlx::core
diff --git a/mlx/backend/cpu/gemms/no_fp16.cpp b/mlx/backend/cpu/gemms/no_fp16.cpp
deleted file mode 100644
index 3f3f41cc5..000000000
--- a/mlx/backend/cpu/gemms/no_fp16.cpp
+++ /dev/null
@@ -1,27 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#include "mlx/backend/cpu/gemm.h"
-
-namespace mlx::core {
-
-template <>
-void matmul(
-    const float16_t*,
-    const float16_t*,
-    float16_t*,
-    bool,
-    bool,
-    size_t,
-    size_t,
-    size_t,
-    float,
-    float,
-    size_t,
-    const Shape&,
-    const Strides&,
-    const Shape&,
-    const Strides&) {
-  throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
-}
-
-} // namespace mlx::core
diff --git a/mlx/backend/cpu/gemms/simd_bf16.cpp b/mlx/backend/cpu/gemms/simd_bf16.cpp
new file mode 100644
index 000000000..58f5964b6
--- /dev/null
+++ b/mlx/backend/cpu/gemms/simd_bf16.cpp
@@ -0,0 +1,45 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/gemm.h"
+#include "mlx/backend/cpu/gemms/simd_gemm.h"
+
+namespace mlx::core {
+
+template <>
+void matmul(
+    const bfloat16_t* a,
+    const bfloat16_t* b,
+    bfloat16_t* out,
+    bool a_transposed,
+    bool b_transposed,
+    size_t lda,
+    size_t ldb,
+    size_t ldc,
+    float alpha,
+    float beta,
+    size_t batch_size,
+    const Shape& a_shape,
+    const Strides& a_strides,
+    const Shape& b_shape,
+    const Strides& b_strides) {
+  auto ndim = a_shape.size();
+  size_t M = a_shape[ndim - 2];
+  size_t N = b_shape[ndim - 1];
+  size_t K = a_shape[ndim - 1];
+  for (int i = 0; i < batch_size; ++i) {
+    simd_gemm<bfloat16_t, float>(
+        a + elem_to_loc(M * K * i, a_shape, a_strides),
+        b + elem_to_loc(K * N * i, b_shape, b_strides),
+        out + M * N * i,
+        a_transposed,
+        b_transposed,
+        M,
+        N,
+        K,
+        alpha,
+        beta);
+  }
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cpu/gemms/simd_fp16.cpp b/mlx/backend/cpu/gemms/simd_fp16.cpp
new file mode 100644
index 000000000..93467da86
--- /dev/null
+++ b/mlx/backend/cpu/gemms/simd_fp16.cpp
@@ -0,0 +1,45 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/gemm.h"
+#include "mlx/backend/cpu/gemms/simd_gemm.h"
+
+namespace mlx::core {
+
+template <>
+void matmul(
+    const float16_t* a,
+    const float16_t* b,
+    float16_t* out,
+    bool a_transposed,
+    bool b_transposed,
+    size_t lda,
+    size_t ldb,
+    size_t ldc,
+    float alpha,
+    float beta,
+    size_t batch_size,
+    const Shape& a_shape,
+    const Strides& a_strides,
+    const Shape& b_shape,
+    const Strides& b_strides) {
+  auto ndim = a_shape.size();
+  size_t M = a_shape[ndim - 2];
+  size_t N = b_shape[ndim - 1];
+  size_t K = a_shape[ndim - 1];
+  for (int i = 0; i < batch_size; ++i) {
+    simd_gemm<float16_t, float>(
+        a + elem_to_loc(M * K * i, a_shape, a_strides),
+        b + elem_to_loc(K * N * i, b_shape, b_strides),
+        out + M * N * i,
+        a_transposed,
+        b_transposed,
+        M,
+        N,
+        K,
+        alpha,
+        beta);
+  }
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cpu/gemms/simd_gemm.h b/mlx/backend/cpu/gemms/simd_gemm.h
new file mode 100644
index 000000000..a23c7dea3
--- /dev/null
+++ b/mlx/backend/cpu/gemms/simd_gemm.h
@@ -0,0 +1,139 @@
+// Copyright © 2025 Apple Inc.
+#pragma once
+
+#include "mlx/backend/cpu/simd/simd.h"
+
+namespace mlx::core {
+
+inline int ceildiv(int a, int b) {
+  return (a + b - 1) / b;
+}
+
+template <int block_size, typename T, typename AccT>
+void load_block(
+    const T* in,
+    AccT* out,
+    int M,
+    int N,
+    int i,
+    int j,
+    bool transpose) {
+  if (transpose) {
+    for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
+      for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
+        out[jj * block_size + ii] =
+            in[(i * block_size + ii) * N + j * block_size + jj];
+      }
+    }
+  } else {
+    for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
+      for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
+        out[ii * block_size + jj] =
+            in[(i * block_size + ii) * N + j * block_size + jj];
+      }
+    }
+  }
+}
+
+template <typename T, typename AccT>
+void simd_gemm(
+    const T* a,
+    const T* b,
+    T* c,
+    bool a_trans,
+    bool b_trans,
+    int M,
+    int N,
+    int K,
+    float alpha,
+    float beta) {
+  constexpr int block_size = 16;
+  constexpr int simd_size = simd::max_size<AccT>;
+  static_assert(
+      (block_size % simd_size) == 0,
+      "Block size must be divisible by SIMD size");
+
+  int last_k_block_size = K - block_size * (K / block_size);
+  int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
+  for (int i = 0; i < ceildiv(M, block_size); i++) {
+    for (int j = 0; j < ceildiv(N, block_size); j++) {
+      AccT c_block[block_size * block_size] = {0.0};
+      AccT a_block[block_size * block_size];
+      AccT b_block[block_size * block_size];
+
+      int k = 0;
+      for (; k < K / block_size; k++) {
+        // Load a and b blocks
+        if (a_trans) {
+          load_block<block_size>(a, a_block, K, M, k, i, true);
+        } else {
+          load_block<block_size>(a, a_block, M, K, i, k, false);
+        }
+        if (b_trans) {
+          load_block<block_size>(b, b_block, N, K, j, k, false);
+        } else {
+          load_block<block_size>(b, b_block, K, N, k, j, true);
+        }
+
+        // Multiply and accumulate
+        for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
+          for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
+            for (int kk = 0; kk < block_size; kk += simd_size) {
+              auto av =
+                  simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
+              auto bv =
+                  simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
+              c_block[ii * block_size + jj] += simd::sum(av * bv);
+            }
+          }
+        }
+      }
+      if (last_k_block_size) {
+        // Load a and b blocks
+        if (a_trans) {
+          load_block<block_size>(a, a_block, K, M, k, i, true);
+        } else {
+          load_block<block_size>(a, a_block, M, K, i, k, false);
+        }
+        if (b_trans) {
+          load_block<block_size>(b, b_block, N, K, j, k, false);
+        } else {
+          load_block<block_size>(b, b_block, K, N, k, j, true);
+        }
+
+        // Multiply and accumulate
+        for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
+          for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
+            int kk = 0;
+            for (; kk < last_k_simd_block; kk += simd_size) {
+              auto av =
+                  simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
+              auto bv =
+                  simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
+              c_block[ii * block_size + jj] += simd::sum(av * bv);
+            }
+            for (; kk < last_k_block_size; ++kk) {
+              c_block[ii * block_size + jj] +=
+                  a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
+            }
+          }
+        }
+      }
+
+      // Store
+      for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
+        for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
+          auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
+          if (beta != 0) {
+            c[c_idx] = static_cast<T>(
+                alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
+          } else {
+            c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
+          }
+        }
+      }
+    }
+  }
+}
+
+} // namespace mlx::core
diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py
index c8627dbbd..8b7fb462d 100644
--- a/python/tests/test_blas.py
+++ b/python/tests/test_blas.py
@@ -12,7 +12,7 @@ import numpy as np
 class TestBlas(mlx_tests.MLXTestCase):
     @property
     def dtypes(self):
-        return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
+        return ["float32", "float16"]
 
     def __gemm_test(
         self,
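
The test change above only drops the Metal gate on the half-precision dtype list. As an end-to-end sanity check, a short script along the following lines could exercise the new CPU fallback. This is a minimal sketch assuming an MLX build that contains this patch (and, to reach simd_gemm rather than BNNS, a build without Accelerate); the shapes and tolerances are illustrative and not taken from the test suite. The same check applies to bfloat16.

# Sanity-check the half-precision CPU GEMM fallback added in this patch.
# Assumptions: an MLX build containing this change; the odd sizes are chosen
# on purpose so the remainder (last_k_block_size) path in simd_gemm runs.
import mlx.core as mx

a = mx.random.normal((65, 33)).astype(mx.float16)
b = mx.random.normal((33, 17)).astype(mx.float16)

# Force the CPU stream so the CPU backend handles the matmul.
out = mx.matmul(a, b, stream=mx.cpu)

# Compare against a float32 reference computed on the same stream.
ref = mx.matmul(a.astype(mx.float32), b.astype(mx.float32), stream=mx.cpu)
assert mx.allclose(out.astype(mx.float32), ref, rtol=1e-2, atol=1e-2).item()
print(out.dtype, out.shape)  # expect float16 and shape (65, 17)

Because simd_gemm accumulates in float (AccT = float in both specializations), the half-precision result should track the float32 reference to within a fairly loose tolerance.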