Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-19 02:38:09 +08:00).
Commit adcc88e208 — "Conv cpu improvements (#1410)", committed via GitHub.
Parent commit: d6492b0163.
@@ -684,6 +684,32 @@ void dispatch_slow_conv_3D(
|
||||
// Explicit gemm conv
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
void flip_spatial_dims_inplace(array& wt) {
|
||||
T* x = wt.data<T>();
|
||||
size_t out_channels = wt.shape(0);
|
||||
size_t in_channels = wt.shape(-1);
|
||||
|
||||
// Calculate the total size of the spatial dimensions
|
||||
int spatial_size = 1;
|
||||
for (int d = 1; d < wt.ndim() - 1; ++d) {
|
||||
spatial_size *= wt.shape(d);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < out_channels; i++) {
|
||||
T* top = x + i * spatial_size * in_channels;
|
||||
T* bottom =
|
||||
x + i * spatial_size * in_channels + (spatial_size - 1) * in_channels;
|
||||
for (size_t j = 0; j < spatial_size / 2; j++) {
|
||||
for (size_t k = 0; k < in_channels; k++) {
|
||||
std::swap(top[k], bottom[k]);
|
||||
}
|
||||
top += in_channels;
|
||||
bottom -= in_channels;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void explicit_gemm_conv_1D_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
@@ -910,7 +936,8 @@ void explicit_gemm_conv_ND_cpu(
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
const std::vector<int>& wt_dilation,
|
||||
const bool flip) {
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const auto iDim = std::vector<int>(
|
||||
in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
|
||||
@@ -1000,6 +1027,14 @@ void explicit_gemm_conv_ND_cpu(
|
||||
copy(wt, gemm_wt, ctype);
|
||||
}
|
||||
|
||||
if (flip) {
|
||||
auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
|
||||
copy(gemm_wt, gemm_wt_, CopyType::Vector);
|
||||
|
||||
flip_spatial_dims_inplace<float>(gemm_wt_);
|
||||
gemm_wt = gemm_wt_;
|
||||
}
|
||||
|
||||
if (out.dtype() != float32) {
|
||||
gemm_out = array(out.shape(), float32, nullptr, {});
|
||||
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
|
||||
@@ -1042,10 +1077,15 @@ void conv_1D_cpu(
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
const int groups = in.shape().back() / wt.shape().back();
|
||||
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
|
||||
return explicit_gemm_conv_1D_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation);
|
||||
}
|
||||
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
|
||||
return explicit_gemm_conv_ND_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, flip);
|
||||
}
|
||||
|
||||
return dispatch_slow_conv_1D(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
@@ -1060,6 +1100,13 @@ void conv_2D_cpu(
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
const int groups = in.shape().back() / wt.shape().back();
|
||||
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
|
||||
in_dilation[1] == 1 && groups == 1) {
|
||||
return explicit_gemm_conv_ND_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, flip);
|
||||
}
|
||||
|
||||
return dispatch_slow_conv_2D(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
}
|
||||
@@ -1073,6 +1120,14 @@ void conv_3D_cpu(
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
const int groups = in.shape().back() / wt.shape().back();
|
||||
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && wt_dilation[2] == 1 &&
|
||||
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
|
||||
groups == 1) {
|
||||
return explicit_gemm_conv_ND_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, flip);
|
||||
}
|
||||
|
||||
return dispatch_slow_conv_3D(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
}
|
||||
|
@@ -136,6 +136,167 @@ inline void copy_general_dim4(const array& src, array& dst) {
|
||||
src, dst, src.shape(), src.strides(), 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
void copy_general_dim5(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
int64_t i_offset) {
|
||||
const SrcT* src_ptr = src.data<SrcT>() + i_offset;
|
||||
DstT* dst_ptr = dst.data<DstT>();
|
||||
|
||||
// Pre-compute loop bounds and strides
|
||||
const int d0 = data_shape[0], d1 = data_shape[1], d2 = data_shape[2],
|
||||
d3 = data_shape[3], d4 = data_shape[4];
|
||||
const stride_t s0 = i_strides[0], s1 = i_strides[1], s2 = i_strides[2],
|
||||
s3 = i_strides[3], s4 = i_strides[4];
|
||||
|
||||
// Pre-compute stride adjustments
|
||||
const stride_t s3_adj = s3 - s4 * d4;
|
||||
const stride_t s2_adj = s2 - s3 * d3;
|
||||
const stride_t s1_adj = s1 - s2 * d2;
|
||||
const stride_t s0_adj = s0 - s1 * d1;
|
||||
|
||||
stride_t src_idx = 0;
|
||||
stride_t dst_idx = 0;
|
||||
|
||||
for (int i = 0; i < d0; ++i) {
|
||||
for (int j = 0; j < d1; ++j) {
|
||||
for (int k = 0; k < d2; ++k) {
|
||||
for (int l = 0; l < d3; ++l) {
|
||||
for (int m = 0; m < d4; ++m) {
|
||||
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||
src_idx += s4;
|
||||
}
|
||||
src_idx += s3_adj;
|
||||
}
|
||||
src_idx += s2_adj;
|
||||
}
|
||||
src_idx += s1_adj;
|
||||
}
|
||||
src_idx += s0_adj;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general_dim5(const array& src, array& dst) {
|
||||
return copy_general_dim5<SrcT, DstT, size_t>(
|
||||
src, dst, src.shape(), src.strides(), 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
void copy_general_dim6(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
int64_t i_offset) {
|
||||
const SrcT* src_ptr = src.data<SrcT>() + i_offset;
|
||||
DstT* dst_ptr = dst.data<DstT>();
|
||||
|
||||
// Pre-compute loop bounds and strides
|
||||
const int d0 = data_shape[0], d1 = data_shape[1], d2 = data_shape[2],
|
||||
d3 = data_shape[3], d4 = data_shape[4], d5 = data_shape[5];
|
||||
const stride_t s0 = i_strides[0], s1 = i_strides[1], s2 = i_strides[2],
|
||||
s3 = i_strides[3], s4 = i_strides[4], s5 = i_strides[5];
|
||||
|
||||
// Pre-compute stride adjustments
|
||||
const stride_t s4_adj = s4 - s5 * d5;
|
||||
const stride_t s3_adj = s3 - s4 * d4;
|
||||
const stride_t s2_adj = s2 - s3 * d3;
|
||||
const stride_t s1_adj = s1 - s2 * d2;
|
||||
const stride_t s0_adj = s0 - s1 * d1;
|
||||
|
||||
stride_t src_idx = 0;
|
||||
stride_t dst_idx = 0;
|
||||
|
||||
for (int i = 0; i < d0; ++i) {
|
||||
for (int j = 0; j < d1; ++j) {
|
||||
for (int k = 0; k < d2; ++k) {
|
||||
for (int l = 0; l < d3; ++l) {
|
||||
for (int m = 0; m < d4; ++m) {
|
||||
for (int n = 0; n < d5; ++n) {
|
||||
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||
src_idx += s5;
|
||||
}
|
||||
src_idx += s4_adj;
|
||||
}
|
||||
src_idx += s3_adj;
|
||||
}
|
||||
src_idx += s2_adj;
|
||||
}
|
||||
src_idx += s1_adj;
|
||||
}
|
||||
src_idx += s0_adj;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general_dim6(const array& src, array& dst) {
|
||||
return copy_general_dim6<SrcT, DstT, size_t>(
|
||||
src, dst, src.shape(), src.strides(), 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
void copy_general_dim7(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
int64_t i_offset) {
|
||||
const SrcT* src_ptr = src.data<SrcT>() + i_offset;
|
||||
DstT* dst_ptr = dst.data<DstT>();
|
||||
|
||||
// Pre-compute loop bounds and strides
|
||||
const int d0 = data_shape[0], d1 = data_shape[1], d2 = data_shape[2],
|
||||
d3 = data_shape[3], d4 = data_shape[4], d5 = data_shape[5],
|
||||
d6 = data_shape[6];
|
||||
const stride_t s0 = i_strides[0], s1 = i_strides[1], s2 = i_strides[2],
|
||||
s3 = i_strides[3], s4 = i_strides[4], s5 = i_strides[5],
|
||||
s6 = i_strides[6];
|
||||
|
||||
// Pre-compute stride adjustments
|
||||
const stride_t s5_adj = s5 - s6 * d6;
|
||||
const stride_t s4_adj = s4 - s5 * d5;
|
||||
const stride_t s3_adj = s3 - s4 * d4;
|
||||
const stride_t s2_adj = s2 - s3 * d3;
|
||||
const stride_t s1_adj = s1 - s2 * d2;
|
||||
const stride_t s0_adj = s0 - s1 * d1;
|
||||
|
||||
stride_t src_idx = 0;
|
||||
stride_t dst_idx = 0;
|
||||
|
||||
for (int i = 0; i < d0; ++i) {
|
||||
for (int j = 0; j < d1; ++j) {
|
||||
for (int k = 0; k < d2; ++k) {
|
||||
for (int l = 0; l < d3; ++l) {
|
||||
for (int m = 0; m < d4; ++m) {
|
||||
for (int n = 0; n < d5; ++n) {
|
||||
for (int p = 0; p < d6; ++p) {
|
||||
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||
src_idx += s6;
|
||||
}
|
||||
src_idx += s5_adj;
|
||||
}
|
||||
src_idx += s4_adj;
|
||||
}
|
||||
src_idx += s3_adj;
|
||||
}
|
||||
src_idx += s2_adj;
|
||||
}
|
||||
src_idx += s1_adj;
|
||||
}
|
||||
src_idx += s0_adj;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general_dim7(const array& src, array& dst) {
|
||||
return copy_general_dim7<SrcT, DstT, size_t>(
|
||||
src, dst, src.shape(), src.strides(), 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
void copy_general(
|
||||
const array& src,
|
||||
@@ -162,6 +323,18 @@ void copy_general(
|
||||
copy_general_dim4<SrcT, DstT, stride_t>(
|
||||
src, dst, new_shape, new_strides[0], i_offset);
|
||||
return;
|
||||
case 5:
|
||||
copy_general_dim5<SrcT, DstT, stride_t>(
|
||||
src, dst, new_shape, new_strides[0], i_offset);
|
||||
return;
|
||||
case 6:
|
||||
copy_general_dim6<SrcT, DstT, stride_t>(
|
||||
src, dst, new_shape, new_strides[0], i_offset);
|
||||
return;
|
||||
case 7:
|
||||
copy_general_dim7<SrcT, DstT, stride_t>(
|
||||
src, dst, new_shape, new_strides[0], i_offset);
|
||||
return;
|
||||
}
|
||||
|
||||
auto src_ptr = src.data<SrcT>() + i_offset;
|
||||
|
Reference in New Issue
Block a user