diff --git a/mlx/backend/accelerate/CMakeLists.txt b/mlx/backend/accelerate/CMakeLists.txt
index 96afd3107..96add2ae5 100644
--- a/mlx/backend/accelerate/CMakeLists.txt
+++ b/mlx/backend/accelerate/CMakeLists.txt
@@ -1,6 +1,2 @@
-target_sources(
-  mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp)
+target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
+                           ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp)
diff --git a/mlx/backend/accelerate/conv.cpp b/mlx/backend/accelerate/conv.cpp
deleted file mode 100644
index 026813aa2..000000000
--- a/mlx/backend/accelerate/conv.cpp
+++ /dev/null
@@ -1,20 +0,0 @@
-// Copyright © 2023-2024 Apple Inc.
-
-#include <cassert>
-
-#include <Accelerate/Accelerate.h>
-#include <simd/vector.h>
-
-#include "mlx/backend/common/copy.h"
-#include "mlx/primitives.h"
-#include "mlx/utils.h"
-
-namespace mlx::core {
-
-void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
-  eval(inputs, out);
-
-  // TODO: Add accelerate based optimizations for CPU conv
-}
-
-} // namespace mlx::core
diff --git a/mlx/backend/accelerate/matmul.cpp b/mlx/backend/accelerate/matmul.cpp
deleted file mode 100644
index 78ce66e7a..000000000
--- a/mlx/backend/accelerate/matmul.cpp
+++ /dev/null
@@ -1,253 +0,0 @@
-// Copyright © 2023-2024 Apple Inc.
-
-#include <cassert>
-
-#include <Accelerate/Accelerate.h>
-
-#include "mlx/backend/accelerate/utils.h"
-#include "mlx/backend/common/copy.h"
-#include "mlx/primitives.h"
-#include "mlx/utils.h"
-
-namespace mlx::core {
-
-namespace {
-
-std::tuple<bool, size_t, array> check_transpose(const array& arr) {
-  auto stx = arr.strides()[arr.ndim() - 2];
-  auto sty = arr.strides()[arr.ndim() - 1];
-  if (stx == arr.shape(-1) && sty == 1) {
-    return std::make_tuple(false, stx, arr);
-  } else if (stx == 1 && sty == arr.shape(-2)) {
-    return std::make_tuple(true, sty, arr);
-  } else {
-    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-    copy(arr, arr_copy, CopyType::General);
-    size_t stx = arr.shape(-1);
-    return std::make_tuple(false, stx, arr_copy);
-  }
-}
-
-inline void matmul_cblas_general(
-    const array& a_pre,
-    const array& b_pre,
-    array& out,
-    float alpha = 1.0f,
-    float beta = 0.0f) {
-  if (out.dtype() != float32) {
-    throw std::runtime_error(
-        "[matmul_cblas] on CPU currently only supports float32");
-  }
-
-  auto [a_transposed, lda, a] = check_transpose(a_pre);
-  auto [b_transposed, ldb, b] = check_transpose(b_pre);
-  size_t M = a.shape(-2);
-  size_t N = b.shape(-1);
-  size_t K = a.shape(-1);
-
-  if (M == 0 || N == 0) {
-    return;
-  }
-  if (K == 0) {
-    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
-    return;
-  }
-
-  for (int i = 0; i < (a.size() / (M * K)); ++i) {
-    cblas_sgemm(
-        CblasRowMajor,
-        a_transposed ? CblasTrans : CblasNoTrans, // transA
-        b_transposed ? CblasTrans : CblasNoTrans, // transB
-        M,
-        N,
-        K,
-        alpha, // alpha
-        a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
-        lda,
-        b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
-        ldb,
-        beta, // beta
-        out.data<float>() + M * N * i,
-        out.shape(-1) // ldc
-    );
-  }
-}
-
-inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
-  if (out.dtype() != float32) {
-    throw std::runtime_error(
-        "[matmul_cblas] on CPU currently only supports float32");
-  }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  return matmul_cblas_general(a_pre, b_pre, out);
-}
-
-inline void matmul_bnns_general(
-    const array& a_pre,
-    const array& b_pre,
-    array& out,
-    float alpha = 1.0f,
-    float beta = 0.0f) {
-  // TODO: Update to utilize BNNS broadcasting
-
-  auto [a_transposed, lda, a] = check_transpose(a_pre);
-  auto [b_transposed, ldb, b] = check_transpose(b_pre);
-  size_t M = a.shape(-2);
-  size_t N = b.shape(-1);
-  size_t K = a.shape(-1);
-
-  if (M == 0 || N == 0) {
-    return;
-  }
-  if (K == 0) {
-    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
-    return;
-  }
-
-  BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
-
-  const BNNSLayerParametersBroadcastMatMul gemm_params{
-      /* float alpha = */ alpha,
-      /* float beta = */ beta,
-      /* bool transA = */ a_transposed,
-      /* bool transB = */ b_transposed,
-      /* bool quadratic = */ false,
-      /* bool a_is_weights = */ false,
-      /* bool b_is_weights = */ false,
-      /* BNNSNDArrayDescriptor iA_desc = */
-      BNNSNDArrayDescriptor{
-          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
-          /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
-
-          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
-          {lda, (M * K) / lda, 0, 0, 0, 0, 0, 0},
-          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
-          {1, lda, 0, 0, 0, 0, 0, 0},
-
-          /* void * _Nullable data = */ nullptr,
-          /* BNNSDataType data_type = */ bnns_dtype,
-
-          /* void * _Nullable table_data = */ nullptr,
-          /* BNNSDataType table_data_type = */ bnns_dtype,
-
-          /* float data_scale = */ 1.0,
-          /* float data_bias = */ 0.0,
-      },
-      /* BNNSNDArrayDescriptor iB_desc = */
-      BNNSNDArrayDescriptor{
-          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
-          /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
-
-          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
-          {ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0},
-          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
-          {1, ldb, 0, 0, 0, 0, 0, 0},
-
-          /* void * _Nullable data = */ nullptr,
-          /* BNNSDataType data_type = */ bnns_dtype,
-
-          /* void * _Nullable table_data = */ nullptr,
-          /* BNNSDataType table_data_type = */ bnns_dtype,
-
-          /* float data_scale = */ 1.0,
-          /* float data_bias = */ 0.0,
-      },
-      /* BNNSNDArrayDescriptor o_desc = */
-      BNNSNDArrayDescriptor{
-          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
-          /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
-
-          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
-          {N, M, 0, 0, 0, 0, 0, 0},
-          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
-          {1, N, 0, 0, 0, 0, 0, 0},
-
-          /* void * _Nullable data = */ nullptr,
-          /* BNNSDataType data_type = */ bnns_dtype,
-
-          /* void * _Nullable table_data = */ nullptr,
-          /* BNNSDataType table_data_type = */ bnns_dtype,
-
-          /* float data_scale = */ 1.0,
-          /* float data_bias = */ 0.0,
-      },
-  };
-
-  auto bnns_filter =
-      BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
-
-  for (int i = 0; i < (a.size() / (M * K)); ++i) {
-    BNNSFilterApplyTwoInput(
-        bnns_filter,
-        a.data<uint8_t>() +
-            elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
-        b.data<uint8_t>() +
-            elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
-        out.data<uint8_t>() + M * N * i * out.itemsize());
-  }
-
-  BNNSFilterDestroy(bnns_filter);
-}
-
-inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
-  // TODO: Update to utilize BNNS broadcasting
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  return matmul_bnns_general(a_pre, b_pre, out);
-}
-
-template <typename T>
-inline void mask_matrix(
-    T* data,
-    const bool* mask,
-    int tile_size,
-    const int X,
-    const int Y,
-    const size_t X_data_str,
-    const size_t Y_data_str,
-    const size_t X_mask_str,
-    const size_t Y_mask_str) {
-  int tX = (X + tile_size - 1) / tile_size;
-  int tY = (Y + tile_size - 1) / tile_size;
-
-  for (int i = 0; i < tX; i++) {
-    for (int j = 0; j < tY; j++) {
-      bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
-      if (!do_mask) {
-        int loc_x = i * tile_size;
-        int loc_y = j * tile_size;
-        T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
-
-        int size_x = std::min(tile_size, X - loc_x);
-        int size_y = std::min(tile_size, Y - loc_y);
-        for (int ii = 0; ii < size_x; ii++) {
-          for (int jj = 0; jj < size_y; jj++) {
-            data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
-          }
-        }
-      }
-    }
-  }
-}
-
-} // namespace
-
-void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
-  if (out.dtype() == float32) {
-    return matmul_cblas(inputs[0], inputs[1], out);
-  }
-  return matmul_bnns(inputs[0], inputs[1], out);
-}
-
-void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
-  // Fill output with C
-  auto& c = inputs[2];
-  CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
-  copy(c, out, ctype);
-
-  if (out.dtype() == float32) {
-    return matmul_cblas_general(inputs[0], inputs[1], out, alpha_, beta_);
-  }
-  return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
-}
-
-} // namespace mlx::core
diff --git a/mlx/backend/accelerate/utils.h b/mlx/backend/accelerate/utils.h
deleted file mode 100644
index 389099f37..000000000
--- a/mlx/backend/accelerate/utils.h
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright © 2023-2024 Apple Inc.
-
-#pragma once
-
-#include <Accelerate/Accelerate.h>
-#include "mlx/dtype.h"
-
-namespace mlx::core {
-
-BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
-  uint32_t size_bits = size_of(mlx_dtype) * 8;
-  switch (kindof(mlx_dtype)) {
-    case Dtype::Kind::b:
-      return BNNSDataTypeBoolean;
-    case Dtype::Kind::u:
-      return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
-    case Dtype::Kind::i:
-      return BNNSDataType(BNNSDataTypeIntBit | size_bits);
-    case Dtype::Kind::f:
-      return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
-    case Dtype::Kind::V:
-      return BNNSDataTypeBFloat16;
-    case Dtype::Kind::c:
-      throw std::invalid_argument("BNNS does not support complex types");
-  }
-}
-
-} // namespace mlx::core
\ No newline at end of file
diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt
index 97fc48008..bd793621c 100644
--- a/mlx/backend/common/CMakeLists.txt
+++ b/mlx/backend/common/CMakeLists.txt
@@ -50,6 +50,8 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cblas.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
@@ -71,6 +73,13 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
   ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
 
+if(MLX_BUILD_ACCELERATE)
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
+else()
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
+                             ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
+endif()
+
 if(IOS)
   target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
 else()
diff --git a/mlx/backend/common/conv.cpp b/mlx/backend/common/conv.cpp
index 3f6d324fb..b36f73b83 100644
--- a/mlx/backend/common/conv.cpp
+++ b/mlx/backend/common/conv.cpp
@@ -1128,7 +1128,7 @@ void conv_3D_cpu(
 
 } // namespace
 
-void Convolution::eval(const std::vector<array>& inputs, array& out) {
+void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
   auto& in = inputs[0];
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index 899de35cd..21779c35a 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -1,11 +1,6 @@
 // Copyright © 2023-2024 Apple Inc.
 
-#include <cstring>
-
 #include "mlx/array.h"
-#include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/lapack.h"
-#include "mlx/backend/common/utils.h"
 #include "mlx/primitives.h"
 
 #define DEFAULT(primitive)                                                 \
@@ -21,89 +16,7 @@
 
 namespace mlx::core {
 
-DEFAULT(Convolution)
 DEFAULT(Reduce)
 DEFAULT(Scan)
 
-namespace {
-
-inline void matmul_common_general(
-    const array& a_pre,
-    const array& b_pre,
-    array& out,
-    float alpha = 1.0f,
-    float beta = 0.0f) {
-  auto check_transpose = [](const array& arr) {
-    auto stx = arr.strides()[arr.ndim() - 2];
-    auto sty = arr.strides()[arr.ndim() - 1];
-    if (stx == arr.shape(-1) && sty == 1) {
-      return std::make_tuple(false, stx, arr);
-    } else if (stx == 1 && sty == arr.shape(-2)) {
-      return std::make_tuple(true, sty, arr);
-    } else {
-      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-      copy(arr, arr_copy, CopyType::General);
-      stx = arr.shape(-1);
-      return std::make_tuple(false, stx, arr_copy);
-    }
-  };
-
-  auto [a_transposed, lda, a] = check_transpose(a_pre);
-  auto [b_transposed, ldb, b] = check_transpose(b_pre);
-  size_t M = a.shape(-2);
-  size_t N = b.shape(-1);
-  size_t K = a.shape(-1);
-  if (M == 0 || N == 0) {
-    return;
-  }
-  if (K == 0) {
-    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
-    return;
-  }
-
-  for (int i = 0; i < (a.size() / (M * K)); ++i) {
-    cblas_sgemm(
-        CblasRowMajor,
-        a_transposed ? CblasTrans : CblasNoTrans, // transA
-        b_transposed ? CblasTrans : CblasNoTrans, // transB
-        M,
-        N,
-        K,
-        alpha, // alpha
-        a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
-        lda,
-        b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
-        ldb,
-        beta, // beta
-        out.data<float>() + M * N * i,
-        out.shape(-1) // ldc
-    );
-  }
-}
-
-} // namespace
-
-void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
-  if (out.dtype() != float32) {
-    throw std::runtime_error(
-        "[Matmul::eval_cpu] Currently only supports float32.");
-  }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  return matmul_common_general(inputs[0], inputs[1], out);
-}
-
-void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
-  if (out.dtype() != float32) {
-    throw std::runtime_error(
-        "[AddMM::eval_cpu] Currently only supports float32.");
-  }
-
-  // Fill output with C
-  auto& c = inputs[2];
-  CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
-  copy(c, out, ctype);
-
-  return matmul_common_general(inputs[0], inputs[1], out, alpha_, beta_);
-}
-
 } // namespace mlx::core
diff --git a/mlx/backend/common/gemm.h b/mlx/backend/common/gemm.h
new file mode 100644
index 000000000..008c29157
--- /dev/null
+++ b/mlx/backend/common/gemm.h
@@ -0,0 +1,20 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+#include "mlx/array.h"
+
+namespace mlx::core {
+
+template <typename T>
+void matmul(
+    const array& a,
+    const array& b,
+    array& out,
+    bool a_transposed,
+    bool b_transposed,
+    size_t lda,
+    size_t ldb,
+    float alpha,
+    float beta);
+
+} // namespace mlx::core
diff --git a/mlx/backend/common/gemms/bnns.cpp b/mlx/backend/common/gemms/bnns.cpp
new file mode 100644
index 000000000..5c5cee739
--- /dev/null
+++ b/mlx/backend/common/gemms/bnns.cpp
@@ -0,0 +1,157 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include <Accelerate/Accelerate.h>
+
+#include "mlx/array.h"
+#include "mlx/backend/common/gemm.h"
+#include "mlx/backend/common/utils.h"
+#include "mlx/dtype.h"
+
+namespace mlx::core {
+
+BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
+  uint32_t size_bits = size_of(mlx_dtype) * 8;
+  switch (kindof(mlx_dtype)) {
+    case Dtype::Kind::b:
+      return BNNSDataTypeBoolean;
+    case Dtype::Kind::u:
+      return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
+    case Dtype::Kind::i:
+      return BNNSDataType(BNNSDataTypeIntBit | size_bits);
+    case Dtype::Kind::f:
+      return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
+    case Dtype::Kind::V:
+      return BNNSDataTypeBFloat16;
+    case Dtype::Kind::c:
+      throw std::invalid_argument("BNNS does not support complex types");
+  }
+}
+
+void matmul_bnns(
+    const array& a,
+    const array& b,
+    array& out,
+    bool a_transposed,
+    bool b_transposed,
+    size_t lda,
+    size_t ldb,
+    float alpha,
+    float beta) {
+  size_t M = a.shape(-2);
+  size_t N = b.shape(-1);
+  size_t K = a.shape(-1);
+
+  BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
+
+  const BNNSLayerParametersBroadcastMatMul gemm_params{
+      /* float alpha = */ alpha,
+      /* float beta = */ beta,
+      /* bool transA = */ a_transposed,
+      /* bool transB = */ b_transposed,
+      /* bool quadratic = */ false,
+      /* bool a_is_weights = */ false,
+      /* bool b_is_weights = */ false,
+      /* BNNSNDArrayDescriptor iA_desc = */
+      BNNSNDArrayDescriptor{
+          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
+          /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
+
+          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
+          {lda, (M * K) / lda, 0, 0, 0, 0, 0, 0},
+          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
+          {1, lda, 0, 0, 0, 0, 0, 0},
+
+          /* void * _Nullable data = */ nullptr,
+          /* BNNSDataType data_type = */ bnns_dtype,
+
+          /* void * _Nullable table_data = */ nullptr,
+          /* BNNSDataType table_data_type = */ bnns_dtype,
+
+          /* float data_scale = */ 1.0,
+          /* float data_bias = */ 0.0,
+      },
+      /* BNNSNDArrayDescriptor iB_desc = */
+      BNNSNDArrayDescriptor{
+          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
+          /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
+
+          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
+          {ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0},
+          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
+          {1, ldb, 0, 0, 0, 0, 0, 0},
+
+          /* void * _Nullable data = */ nullptr,
+          /* BNNSDataType data_type = */ bnns_dtype,
+
+          /* void * _Nullable table_data = */ nullptr,
+          /* BNNSDataType table_data_type = */ bnns_dtype,
+
+          /* float data_scale = */ 1.0,
+          /* float data_bias = */ 0.0,
+      },
+      /* BNNSNDArrayDescriptor o_desc = */
+      BNNSNDArrayDescriptor{
+          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
+          /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
+
+          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
+          {N, M, 0, 0, 0, 0, 0, 0},
+          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
+          {1, N, 0, 0, 0, 0, 0, 0},
+
+          /* void * _Nullable data = */ nullptr,
+          /* BNNSDataType data_type = */ bnns_dtype,
+
+          /* void * _Nullable table_data = */ nullptr,
+          /* BNNSDataType table_data_type = */ bnns_dtype,
+
+          /* float data_scale = */ 1.0,
+          /* float data_bias = */ 0.0,
+      },
+  };
+
+  auto bnns_filter =
+      BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
+
+  for (int i = 0; i < (a.size() / (M * K)); ++i) {
+    BNNSFilterApplyTwoInput(
+        bnns_filter,
+        a.data<uint8_t>() +
+            elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
+        b.data<uint8_t>() +
+            elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
+        out.data<uint8_t>() + M * N * i * out.itemsize());
+  }
+
+  BNNSFilterDestroy(bnns_filter);
+}
+
+template <>
+void matmul<float16_t>(
+    const array& a,
+    const array& b,
+    array& out,
+    bool a_transposed,
+    bool b_transposed,
+    size_t lda,
+    size_t ldb,
+    float alpha,
+    float beta) {
+  matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
+}
+
+template <>
+void matmul<bfloat16_t>(
+    const array& a,
+    const array& b,
+    array& out,
+    bool a_transposed,
+    bool b_transposed,
+    size_t lda,
+    size_t ldb,
+    float alpha,
+    float beta) {
+  matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/common/gemms/cblas.cpp b/mlx/backend/common/gemms/cblas.cpp
new file mode 100644
index 000000000..e6d07bf84
--- /dev/null
+++ b/mlx/backend/common/gemms/cblas.cpp
@@ -0,0 +1,44 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/gemm.h"
+#include "mlx/backend/common/lapack.h"
+#include "mlx/backend/common/utils.h"
+
+namespace mlx::core {
+
+template <>
+void matmul<float>(
+    const array& a,
+    const array& b,
+    array& out,
+    bool a_transposed,
+    bool b_transposed,
+    size_t lda,
+    size_t ldb,
+    float alpha,
+    float beta) {
+  size_t M = a.shape(-2);
+  size_t N = b.shape(-1);
+  size_t K = a.shape(-1);
+
+  for (int i = 0; i < (a.size() / (M * K)); ++i) {
+    cblas_sgemm(
+        CblasRowMajor,
+        a_transposed ? CblasTrans : CblasNoTrans, // transA
+        b_transposed ? CblasTrans : CblasNoTrans, // transB
+        M,
+        N,
+        K,
+        alpha, // alpha
+        a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
+        lda,
+        b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
+        ldb,
+        beta, // beta
+        out.data<float>() + M * N * i,
+        out.shape(-1) // ldc
+    );
+  }
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/common/gemms/no_bf16.cpp b/mlx/backend/common/gemms/no_bf16.cpp
new file mode 100644
index 000000000..2abcf1536
--- /dev/null
+++ b/mlx/backend/common/gemms/no_bf16.cpp
@@ -0,0 +1,21 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/gemm.h"
+
+namespace mlx::core {
+
+template <>
+void matmul<bfloat16_t>(
+    const array&,
+    const array&,
+    array&,
+    bool,
+    bool,
+    size_t,
+    size_t,
+    float,
+    float) {
+  throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/common/gemms/no_fp16.cpp b/mlx/backend/common/gemms/no_fp16.cpp
new file mode 100644
index 000000000..ccc2f2a31
--- /dev/null
+++ b/mlx/backend/common/gemms/no_fp16.cpp
@@ -0,0 +1,21 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/gemm.h"
+
+namespace mlx::core {
+
+template <>
+void matmul<float16_t>(
+    const array&,
+    const array&,
+    array&,
+    bool,
+    bool,
+    size_t,
+    size_t,
+    float,
+    float) {
+  throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/common/matmul.cpp b/mlx/backend/common/matmul.cpp
new file mode 100644
index 000000000..1966c57b6
--- /dev/null
+++ b/mlx/backend/common/matmul.cpp
@@ -0,0 +1,79 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include <cstring>
+#include "mlx/array.h"
+#include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/gemm.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+void matmul_general(
+    const array& a_pre,
+    const array& b_pre,
+    array& out,
+    float alpha = 1.0f,
+    float beta = 0.0f) {
+  auto check_transpose = [](const array& arr) {
+    auto stx = arr.strides()[arr.ndim() - 2];
+    auto sty = arr.strides()[arr.ndim() - 1];
+    if (stx == arr.shape(-1) && sty == 1) {
+      return std::make_tuple(false, stx, arr);
+    } else if (stx == 1 && sty == arr.shape(-2)) {
+      return std::make_tuple(true, sty, arr);
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy(arr, arr_copy, CopyType::General);
+      stx = arr.shape(-1);
+      return std::make_tuple(false, stx, arr_copy);
+    }
+  };
+
+  auto [a_transposed, lda, a] = check_transpose(a_pre);
+  auto [b_transposed, ldb, b] = check_transpose(b_pre);
+  size_t M = a.shape(-2);
+  size_t N = b.shape(-1);
+  size_t K = a.shape(-1);
+  if (M == 0 || N == 0) {
+    return;
+  }
+
+  if (out.dtype() == float32) {
+    matmul<float>(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
+  } else if (out.dtype() == float16) {
+    matmul<float16_t>(
+        a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
+  } else if (out.dtype() == bfloat16) {
+    matmul<bfloat16_t>(
+        a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
+  } else {
+    throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
+  }
+}
+
+void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  if (inputs[0].shape(-1) == 0) {
+    std::memset(out.data<void>(), 0, out.nbytes());
+    return;
+  }
+  return matmul_general(inputs[0], inputs[1], out);
+}
+
+void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
+  if (out.dtype() != float32) {
+    throw std::runtime_error(
+        "[AddMM::eval_cpu] Currently only supports float32.");
+  }
+
+  // Fill output with C
+  auto& c = inputs[2];
+  CopyType ctype = c.data_size() == 1
+      ? CopyType::Scalar
+      : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
+  copy(c, out, ctype);
+
+  return matmul_general(inputs[0], inputs[1], out, alpha_, beta_);
+}
+
+} // namespace mlx::core
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 8158c88d6..77586d819 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -704,8 +704,6 @@ class Convolution : public UnaryPrimitive {
   std::vector<int> input_dilation_;
   int groups_;
   bool flip_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Copy : public UnaryPrimitive {
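The structural change above: `Matmul::eval_cpu` now funnels every dtype through a single `matmul<T>` template declared in `gemm.h`, and CMake decides at build time which translation unit provides the half-precision specializations (`gemms/bnns.cpp` when `MLX_BUILD_ACCELERATE` is set, the throwing `gemms/no_fp16.cpp`/`no_bf16.cpp` stubs otherwise). Below is a minimal, self-contained sketch of that link-time dispatch pattern; all names and the naive GEMM loop are illustrative stand-ins, not MLX code.

```cpp
#include <cstdio>
#include <stdexcept>

// Single entry point, declared once (the role of gemm.h in this diff).
template <typename T>
void gemm(const T* a, const T* b, T* out, int m, int n, int k);

// Always-available float path (the role of gemms/cblas.cpp), shown here
// as a naive row-major triple loop instead of cblas_sgemm.
template <>
void gemm<float>(
    const float* a, const float* b, float* out, int m, int n, int k) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int l = 0; l < k; ++l) {
        acc += a[i * k + l] * b[l * n + j];
      }
      out[i * n + j] = acc;
    }
  }
}

// Stub path (the role of gemms/no_fp16.cpp): only linked in when no
// accelerated backend exists, so unsupported dtypes fail at run time
// instead of at compile time. float16_t is a hypothetical stand-in type.
struct float16_t {
  unsigned short bits;
};
template <>
void gemm<float16_t>(
    const float16_t*, const float16_t*, float16_t*, int, int, int) {
  throw std::runtime_error("float16 not supported.");
}

int main() {
  const float a[4] = {1, 2, 3, 4}; // 2x2, row-major
  const float b[4] = {5, 6, 7, 8}; // 2x2, row-major
  float out[4];
  gemm<float>(a, b, out, 2, 2, 2);
  std::printf("%g %g\n%g %g\n", out[0], out[1], out[2], out[3]); // 19 22 / 43 50
  return 0;
}
```

The design choice this buys: the generic `matmul_general` (transpose handling, degenerate-shape checks, dtype dispatch) lives in one place in `common/matmul.cpp`, while swapping GEMM providers is purely a link-time decision, with no `#ifdef`s in the dispatch code.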