mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Unify CPU matmuls, remove unused accelerate conv (#1814)
* unify matmuls * Update mlx/backend/common/matmul.cpp Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com> --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent
ded914f442
commit
c6fc07f1f4
@ -1,6 +1,2 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp)
|
||||
|
@ -1,20 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#include <simd/vector.h>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
|
||||
// TODO: Add accelerate based optimizations for CPU conv
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -1,253 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
#include "mlx/backend/accelerate/utils.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
std::tuple<bool, size_t, array> check_transpose(const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
}
|
||||
|
||||
inline void matmul_cblas_general(
|
||||
const array& a_pre,
|
||||
const array& b_pre,
|
||||
array& out,
|
||||
float alpha = 1.0f,
|
||||
float beta = 0.0f) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[matmul_cblas] on CPU currently only supports float32");
|
||||
}
|
||||
|
||||
auto [a_transposed, lda, a] = check_transpose(a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
if (K == 0) {
|
||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
alpha, // alpha
|
||||
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
|
||||
lda,
|
||||
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
|
||||
ldb,
|
||||
beta, // beta
|
||||
out.data<float>() + M * N * i,
|
||||
out.shape(-1) // ldc
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[matmul_cblas] on CPU currently only supports float32");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
return matmul_cblas_general(a_pre, b_pre, out);
|
||||
}
|
||||
|
||||
inline void matmul_bnns_general(
|
||||
const array& a_pre,
|
||||
const array& b_pre,
|
||||
array& out,
|
||||
float alpha = 1.0f,
|
||||
float beta = 0.0f) {
|
||||
// TODO: Update to utilize BNNS broadcasting
|
||||
|
||||
auto [a_transposed, lda, a] = check_transpose(a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
if (K == 0) {
|
||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||
return;
|
||||
}
|
||||
|
||||
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
|
||||
|
||||
const BNNSLayerParametersBroadcastMatMul gemm_params{
|
||||
/* float alpha = */ alpha,
|
||||
/* float beta = */ beta,
|
||||
/* bool transA = */ a_transposed,
|
||||
/* bool transB = */ b_transposed,
|
||||
/* bool quadratic = */ false,
|
||||
/* bool a_is_weights = */ false,
|
||||
/* bool b_is_weights = */ false,
|
||||
/* BNNSNDArrayDescriptor iA_desc = */
|
||||
BNNSNDArrayDescriptor{
|
||||
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
|
||||
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
|
||||
|
||||
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{lda, (M * K) / lda, 0, 0, 0, 0, 0, 0},
|
||||
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{1, lda, 0, 0, 0, 0, 0, 0},
|
||||
|
||||
/* void * _Nullable data = */ nullptr,
|
||||
/* BNNSDataType data_type = */ bnns_dtype,
|
||||
|
||||
/* void * _Nullable table_data = */ nullptr,
|
||||
/* BNNSDataType table_data_type = */ bnns_dtype,
|
||||
|
||||
/* float data_scale = */ 1.0,
|
||||
/* float data_bias = */ 0.0,
|
||||
},
|
||||
/* BNNSNDArrayDescriptor iB_desc = */
|
||||
BNNSNDArrayDescriptor{
|
||||
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
|
||||
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
|
||||
|
||||
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0},
|
||||
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{1, ldb, 0, 0, 0, 0, 0, 0},
|
||||
|
||||
/* void * _Nullable data = */ nullptr,
|
||||
/* BNNSDataType data_type = */ bnns_dtype,
|
||||
|
||||
/* void * _Nullable table_data = */ nullptr,
|
||||
/* BNNSDataType table_data_type = */ bnns_dtype,
|
||||
|
||||
/* float data_scale = */ 1.0,
|
||||
/* float data_bias = */ 0.0,
|
||||
},
|
||||
/* BNNSNDArrayDescriptor o_desc = */
|
||||
BNNSNDArrayDescriptor{
|
||||
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
|
||||
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
|
||||
|
||||
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{N, M, 0, 0, 0, 0, 0, 0},
|
||||
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{1, N, 0, 0, 0, 0, 0, 0},
|
||||
|
||||
/* void * _Nullable data = */ nullptr,
|
||||
/* BNNSDataType data_type = */ bnns_dtype,
|
||||
|
||||
/* void * _Nullable table_data = */ nullptr,
|
||||
/* BNNSDataType table_data_type = */ bnns_dtype,
|
||||
|
||||
/* float data_scale = */ 1.0,
|
||||
/* float data_bias = */ 0.0,
|
||||
},
|
||||
};
|
||||
|
||||
auto bnns_filter =
|
||||
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
|
||||
|
||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||
BNNSFilterApplyTwoInput(
|
||||
bnns_filter,
|
||||
a.data<uint8_t>() +
|
||||
elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
|
||||
b.data<uint8_t>() +
|
||||
elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
|
||||
out.data<uint8_t>() + M * N * i * out.itemsize());
|
||||
}
|
||||
|
||||
BNNSFilterDestroy(bnns_filter);
|
||||
}
|
||||
|
||||
inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
|
||||
// TODO: Update to utilize BNNS broadcasting
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
return matmul_bnns_general(a_pre, b_pre, out);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void mask_matrix(
|
||||
T* data,
|
||||
const bool* mask,
|
||||
int tile_size,
|
||||
const int X,
|
||||
const int Y,
|
||||
const size_t X_data_str,
|
||||
const size_t Y_data_str,
|
||||
const size_t X_mask_str,
|
||||
const size_t Y_mask_str) {
|
||||
int tX = (X + tile_size - 1) / tile_size;
|
||||
int tY = (Y + tile_size - 1) / tile_size;
|
||||
|
||||
for (int i = 0; i < tX; i++) {
|
||||
for (int j = 0; j < tY; j++) {
|
||||
bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
|
||||
if (!do_mask) {
|
||||
int loc_x = i * tile_size;
|
||||
int loc_y = j * tile_size;
|
||||
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
|
||||
|
||||
int size_x = std::min(tile_size, X - loc_x);
|
||||
int size_y = std::min(tile_size, Y - loc_y);
|
||||
for (int ii = 0; ii < size_x; ii++) {
|
||||
for (int jj = 0; jj < size_y; jj++) {
|
||||
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() == float32) {
|
||||
return matmul_cblas(inputs[0], inputs[1], out);
|
||||
}
|
||||
return matmul_bnns(inputs[0], inputs[1], out);
|
||||
}
|
||||
|
||||
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// Fill output with C
|
||||
auto& c = inputs[2];
|
||||
CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
|
||||
copy(c, out, ctype);
|
||||
|
||||
if (out.dtype() == float32) {
|
||||
return matmul_cblas_general(inputs[0], inputs[1], out, alpha_, beta_);
|
||||
}
|
||||
return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -1,28 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#include "mlx/dtype.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
|
||||
uint32_t size_bits = size_of(mlx_dtype) * 8;
|
||||
switch (kindof(mlx_dtype)) {
|
||||
case Dtype::Kind::b:
|
||||
return BNNSDataTypeBoolean;
|
||||
case Dtype::Kind::u:
|
||||
return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
|
||||
case Dtype::Kind::i:
|
||||
return BNNSDataType(BNNSDataTypeIntBit | size_bits);
|
||||
case Dtype::Kind::f:
|
||||
return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
|
||||
case Dtype::Kind::V:
|
||||
return BNNSDataTypeBFloat16;
|
||||
case Dtype::Kind::c:
|
||||
throw std::invalid_argument("BNNS does not support complex types");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -50,6 +50,8 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cblas.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
@ -71,6 +73,13 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
|
||||
|
||||
if(MLX_BUILD_ACCELERATE)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
|
||||
endif()
|
||||
|
||||
if(IOS)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
|
||||
else()
|
||||
|
@ -1128,7 +1128,7 @@ void conv_3D_cpu(
|
||||
|
||||
} // namespace
|
||||
|
||||
void Convolution::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
@ -1,11 +1,6 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#define DEFAULT(primitive) \
|
||||
@ -21,89 +16,7 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
DEFAULT(Convolution)
|
||||
DEFAULT(Reduce)
|
||||
DEFAULT(Scan)
|
||||
|
||||
namespace {
|
||||
|
||||
inline void matmul_common_general(
|
||||
const array& a_pre,
|
||||
const array& b_pre,
|
||||
array& out,
|
||||
float alpha = 1.0f,
|
||||
float beta = 0.0f) {
|
||||
auto check_transpose = [](const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
|
||||
auto [a_transposed, lda, a] = check_transpose(a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
if (K == 0) {
|
||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
alpha, // alpha
|
||||
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
|
||||
lda,
|
||||
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
|
||||
ldb,
|
||||
beta, // beta
|
||||
out.data<float>() + M * N * i,
|
||||
out.shape(-1) // ldc
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[Matmul::eval_cpu] Currently only supports float32.");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
return matmul_common_general(inputs[0], inputs[1], out);
|
||||
}
|
||||
|
||||
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[AddMM::eval_cpu] Currently only supports float32.");
|
||||
}
|
||||
|
||||
// Fill output with C
|
||||
auto& c = inputs[2];
|
||||
CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
|
||||
copy(c, out, ctype);
|
||||
|
||||
return matmul_common_general(inputs[0], inputs[1], out, alpha_, beta_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
20
mlx/backend/common/gemm.h
Normal file
20
mlx/backend/common/gemm.h
Normal file
@ -0,0 +1,20 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
void matmul(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
float alpha,
|
||||
float beta);
|
||||
|
||||
} // namespace mlx::core
|
157
mlx/backend/common/gemms/bnns.cpp
Normal file
157
mlx/backend/common/gemms/bnns.cpp
Normal file
@ -0,0 +1,157 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/gemm.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/dtype.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
|
||||
uint32_t size_bits = size_of(mlx_dtype) * 8;
|
||||
switch (kindof(mlx_dtype)) {
|
||||
case Dtype::Kind::b:
|
||||
return BNNSDataTypeBoolean;
|
||||
case Dtype::Kind::u:
|
||||
return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
|
||||
case Dtype::Kind::i:
|
||||
return BNNSDataType(BNNSDataTypeIntBit | size_bits);
|
||||
case Dtype::Kind::f:
|
||||
return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
|
||||
case Dtype::Kind::V:
|
||||
return BNNSDataTypeBFloat16;
|
||||
case Dtype::Kind::c:
|
||||
throw std::invalid_argument("BNNS does not support complex types");
|
||||
}
|
||||
}
|
||||
|
||||
void matmul_bnns(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
float alpha,
|
||||
float beta) {
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
|
||||
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
|
||||
|
||||
const BNNSLayerParametersBroadcastMatMul gemm_params{
|
||||
/* float alpha = */ alpha,
|
||||
/* float beta = */ beta,
|
||||
/* bool transA = */ a_transposed,
|
||||
/* bool transB = */ b_transposed,
|
||||
/* bool quadratic = */ false,
|
||||
/* bool a_is_weights = */ false,
|
||||
/* bool b_is_weights = */ false,
|
||||
/* BNNSNDArrayDescriptor iA_desc = */
|
||||
BNNSNDArrayDescriptor{
|
||||
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
|
||||
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
|
||||
|
||||
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{lda, (M * K) / lda, 0, 0, 0, 0, 0, 0},
|
||||
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{1, lda, 0, 0, 0, 0, 0, 0},
|
||||
|
||||
/* void * _Nullable data = */ nullptr,
|
||||
/* BNNSDataType data_type = */ bnns_dtype,
|
||||
|
||||
/* void * _Nullable table_data = */ nullptr,
|
||||
/* BNNSDataType table_data_type = */ bnns_dtype,
|
||||
|
||||
/* float data_scale = */ 1.0,
|
||||
/* float data_bias = */ 0.0,
|
||||
},
|
||||
/* BNNSNDArrayDescriptor iB_desc = */
|
||||
BNNSNDArrayDescriptor{
|
||||
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
|
||||
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
|
||||
|
||||
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0},
|
||||
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{1, ldb, 0, 0, 0, 0, 0, 0},
|
||||
|
||||
/* void * _Nullable data = */ nullptr,
|
||||
/* BNNSDataType data_type = */ bnns_dtype,
|
||||
|
||||
/* void * _Nullable table_data = */ nullptr,
|
||||
/* BNNSDataType table_data_type = */ bnns_dtype,
|
||||
|
||||
/* float data_scale = */ 1.0,
|
||||
/* float data_bias = */ 0.0,
|
||||
},
|
||||
/* BNNSNDArrayDescriptor o_desc = */
|
||||
BNNSNDArrayDescriptor{
|
||||
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
|
||||
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
|
||||
|
||||
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{N, M, 0, 0, 0, 0, 0, 0},
|
||||
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{1, N, 0, 0, 0, 0, 0, 0},
|
||||
|
||||
/* void * _Nullable data = */ nullptr,
|
||||
/* BNNSDataType data_type = */ bnns_dtype,
|
||||
|
||||
/* void * _Nullable table_data = */ nullptr,
|
||||
/* BNNSDataType table_data_type = */ bnns_dtype,
|
||||
|
||||
/* float data_scale = */ 1.0,
|
||||
/* float data_bias = */ 0.0,
|
||||
},
|
||||
};
|
||||
|
||||
auto bnns_filter =
|
||||
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
|
||||
|
||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||
BNNSFilterApplyTwoInput(
|
||||
bnns_filter,
|
||||
a.data<uint8_t>() +
|
||||
elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
|
||||
b.data<uint8_t>() +
|
||||
elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
|
||||
out.data<uint8_t>() + M * N * i * out.itemsize());
|
||||
}
|
||||
|
||||
BNNSFilterDestroy(bnns_filter);
|
||||
}
|
||||
|
||||
template <>
|
||||
void matmul<float16_t>(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
float alpha,
|
||||
float beta) {
|
||||
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
|
||||
}
|
||||
|
||||
template <>
|
||||
void matmul<bfloat16_t>(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
float alpha,
|
||||
float beta) {
|
||||
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
44
mlx/backend/common/gemms/cblas.cpp
Normal file
44
mlx/backend/common/gemms/cblas.cpp
Normal file
@ -0,0 +1,44 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/gemm.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <>
|
||||
void matmul<float>(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
float alpha,
|
||||
float beta) {
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
|
||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
alpha, // alpha
|
||||
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
|
||||
lda,
|
||||
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
|
||||
ldb,
|
||||
beta, // beta
|
||||
out.data<float>() + M * N * i,
|
||||
out.shape(-1) // ldc
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
21
mlx/backend/common/gemms/no_bf16.cpp
Normal file
21
mlx/backend/common/gemms/no_bf16.cpp
Normal file
@ -0,0 +1,21 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <>
|
||||
void matmul<bfloat16_t>(
|
||||
const array&,
|
||||
const array&,
|
||||
array&,
|
||||
bool,
|
||||
bool,
|
||||
size_t,
|
||||
size_t,
|
||||
float,
|
||||
float) {
|
||||
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
21
mlx/backend/common/gemms/no_fp16.cpp
Normal file
21
mlx/backend/common/gemms/no_fp16.cpp
Normal file
@ -0,0 +1,21 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <>
|
||||
void matmul<float16_t>(
|
||||
const array&,
|
||||
const array&,
|
||||
array&,
|
||||
bool,
|
||||
bool,
|
||||
size_t,
|
||||
size_t,
|
||||
float,
|
||||
float) {
|
||||
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
79
mlx/backend/common/matmul.cpp
Normal file
79
mlx/backend/common/matmul.cpp
Normal file
@ -0,0 +1,79 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cstring>
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/gemm.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void matmul_general(
|
||||
const array& a_pre,
|
||||
const array& b_pre,
|
||||
array& out,
|
||||
float alpha = 1.0f,
|
||||
float beta = 0.0f) {
|
||||
auto check_transpose = [](const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
|
||||
auto [a_transposed, lda, a] = check_transpose(a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (out.dtype() == float32) {
|
||||
matmul<float>(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
|
||||
} else if (out.dtype() == float16) {
|
||||
matmul<float16_t>(
|
||||
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
matmul<bfloat16_t>(
|
||||
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
|
||||
} else {
|
||||
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
if (inputs[0].shape(-1) == 0) {
|
||||
std::memset(out.data<void>(), 0, out.nbytes());
|
||||
return;
|
||||
}
|
||||
return matmul_general(inputs[0], inputs[1], out);
|
||||
}
|
||||
|
||||
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[AddMM::eval_cpu] Currently only supports float32.");
|
||||
}
|
||||
|
||||
// Fill output with C
|
||||
auto& c = inputs[2];
|
||||
CopyType ctype = c.data_size() == 1
|
||||
? CopyType::Scalar
|
||||
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
copy(c, out, ctype);
|
||||
|
||||
return matmul_general(inputs[0], inputs[1], out, alpha_, beta_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -704,8 +704,6 @@ class Convolution : public UnaryPrimitive {
|
||||
std::vector<int> input_dilation_;
|
||||
int groups_;
|
||||
bool flip_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Copy : public UnaryPrimitive {
|
||||
|
Loading…
Reference in New Issue
Block a user