mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 01:46:37 +08:00
The naive_conv_2d is no longer used (#2496)
This commit is contained in:
parent
6441c21a94
commit
4abb218d21
@ -166,115 +166,6 @@ instantiate_naive_unfold_nd_dims(float32, float);
|
||||
instantiate_naive_unfold_nd_dims(float16, half);
|
||||
instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Slow and naive conv2d kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const int BC = 16>
|
||||
[[kernel]] void naive_conv_2d(
|
||||
const device T* in [[buffer(0)]],
|
||||
const device T* wt [[buffer(1)]],
|
||||
device T* out [[buffer(2)]],
|
||||
const constant MLXConvParams<2>& params [[buffer(3)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
(void)simd_gid;
|
||||
(void)simd_lid;
|
||||
|
||||
out += tid.z * params.out_strides[0];
|
||||
in += tid.z * params.in_strides[0];
|
||||
|
||||
int out_o = tid.y * BN * TN + lid.y * TN;
|
||||
int out_hw = tid.x * BM * TM + lid.x * TM;
|
||||
|
||||
int out_h[TM];
|
||||
int out_w[TN];
|
||||
|
||||
for (int m = 0; m < TM; ++m) {
|
||||
int mm = (out_hw + m);
|
||||
out_h[m] = mm / params.oS[1];
|
||||
out_w[m] = mm % params.oS[1];
|
||||
}
|
||||
|
||||
T in_local[TM];
|
||||
T wt_local[TN];
|
||||
T out_local[TM * TN] = {T(0)};
|
||||
|
||||
for (int h = 0; h < params.wS[0]; ++h) {
|
||||
for (int w = 0; w < params.wS[1]; ++w) {
|
||||
for (int c = 0; c < params.C; ++c) {
|
||||
// Local in
|
||||
for (int m = 0; m < TM; m++) {
|
||||
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
|
||||
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
|
||||
|
||||
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
|
||||
in_local[m] = valid
|
||||
? in[i * params.in_strides[1] + j * params.in_strides[2] + c]
|
||||
: T(0);
|
||||
}
|
||||
|
||||
// Load weight
|
||||
for (int n = 0; n < TN; ++n) {
|
||||
int o = out_o + n;
|
||||
wt_local[n] = o < params.O
|
||||
? wt[o * params.wt_strides[0] + h * params.wt_strides[1] +
|
||||
w * params.wt_strides[2] + c]
|
||||
: T(0);
|
||||
}
|
||||
|
||||
// Accumulate
|
||||
for (int m = 0; m < TM; ++m) {
|
||||
for (int n = 0; n < TN; ++n) {
|
||||
out_local[m * TN + n] += in_local[m] * wt_local[n];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int m = 0; m < TM; ++m) {
|
||||
for (int n = 0; n < TN; ++n) {
|
||||
if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] &&
|
||||
(out_o + n) < params.O)
|
||||
out[out_h[m] * params.out_strides[1] +
|
||||
out_w[m] * params.out_strides[2] + out_o + n] =
|
||||
out_local[m * TN + n];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Instantiations
|
||||
|
||||
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
naive_conv_2d<itype, bm, bn, tm, tn>( \
|
||||
const device itype* in [[buffer(0)]], \
|
||||
const device itype* wt [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant MLXConvParams<2>& params [[buffer(3)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_naive_conv_2d_blocks(name, itype) \
|
||||
instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
|
||||
instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)
|
||||
|
||||
instantiate_naive_conv_2d_blocks(float32, float);
|
||||
instantiate_naive_conv_2d_blocks(float16, half);
|
||||
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Depthwise convolution kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
Loading…
Reference in New Issue
Block a user