mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
add a half simd gemm fallback (#2046)
* add a half simd gemm fallback * nit
This commit is contained in:
45
mlx/backend/cpu/gemms/simd_fp16.cpp
Normal file
45
mlx/backend/cpu/gemms/simd_fp16.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
#include "mlx/backend/cpu/gemms/simd_gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <>
|
||||
void matmul<float16_t>(
|
||||
const float16_t* a,
|
||||
const float16_t* b,
|
||||
float16_t* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
float alpha,
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides) {
|
||||
auto ndim = a_shape.size();
|
||||
size_t M = a_shape[ndim - 2];
|
||||
size_t N = b_shape[ndim - 1];
|
||||
size_t K = a_shape[ndim - 1];
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
simd_gemm<float16_t, float>(
|
||||
a + elem_to_loc(M * K * i, a_shape, a_strides),
|
||||
b + elem_to_loc(K * N * i, b_shape, b_strides),
|
||||
out + M * N * i,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
Reference in New Issue
Block a user