Fp64 on the CPU (#1843)

* add fp64 data type

* clean build

* update docs

* fix bug
This commit is contained in:
Awni Hannun
2025-02-07 15:52:22 -08:00
committed by GitHub
parent 1a1b2108ec
commit 1c0c118f7c
32 changed files with 438 additions and 65 deletions

View File

@@ -41,4 +41,39 @@ void matmul<float>(
}
}
// Double-precision specialization of the batched CPU matmul.
//
// Issues one row-major cblas_dgemm call per batch entry, i.e. for each batch
// computes out = alpha * op(a) * op(b) + beta * out on an (M x K) by (K x N)
// pair, where op() optionally transposes its operand.
//
//   a, b          Inputs. M and K are read from the last two dims of `a`, N
//                 from the last dim of `b`. The start of each batch entry is
//                 resolved with elem_to_loc using the array's shape/strides.
//   out           Destination. Indexed densely: batch i starts at element
//                 M * N * i, with out.shape(-1) used as ldc.
//   a_transposed  Selects CblasTrans vs CblasNoTrans for `a`.
//   b_transposed  Selects CblasTrans vs CblasNoTrans for `b`.
//   lda, ldb      Leading dimensions forwarded unchanged to BLAS.
//   alpha, beta   Scaling factors; `float` in the shared interface, widened
//                 to double for the dgemm call.
template <>
void matmul<double>(
    const array& a,
    const array& b,
    array& out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    float alpha,
    float beta) {
  const size_t M = a.shape(-2);
  const size_t N = b.shape(-1);
  const size_t K = a.shape(-1);

  // Number of (M x K) matrices stacked in `a`. Computed once instead of on
  // every loop-condition check.
  // NOTE(review): this divides by zero when M or K is 0; presumably callers
  // filter out empty inputs before dispatching here -- confirm.
  const size_t batch_count = a.size() / (M * K);

  // size_t index: matches the unsigned bound and the size_t offset arithmetic
  // below, avoiding a signed/unsigned comparison and int overflow on very
  // large batch counts.
  for (size_t i = 0; i < batch_count; ++i) {
    cblas_dgemm(
        CblasRowMajor,
        a_transposed ? CblasTrans : CblasNoTrans, // transA
        b_transposed ? CblasTrans : CblasNoTrans, // transB
        M,
        N,
        K,
        static_cast<double>(alpha), // alpha
        a.data<double>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
        lda,
        b.data<double>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
        ldb,
        static_cast<double>(beta), // beta
        out.data<double>() + M * N * i,
        out.shape(-1) // ldc
    );
  }
}
} // namespace mlx::core