Conv grad with groups + bugfix (#1449)

* fix bug in flipped conv with groups, start of grad for groups

* fix

* fix

* fix + test
This commit is contained in:
Awni Hannun
2024-10-06 07:08:53 -07:00
committed by GitHub
parent fef3c4ec1d
commit e4534dac17
6 changed files with 197 additions and 176 deletions

View File

@@ -72,7 +72,7 @@ void explicit_gemm_conv_ND_gpu(
wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_reshaped};
std::vector<array> copies = {in_unfolded};
return steel_matmul(
s,
d,
@@ -155,22 +155,27 @@ void explicit_gemm_conv_group_ND_gpu(
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_view, wt_transpose};
return steel_matmul_conv_groups(
std::vector<array> copies = {in_unfolded, wt_transpose};
return steel_matmul_regular(
s,
d,
/*a = */ in_unfolded,
/*b = */ wt_transpose,
/*c = */ out,
/*M = */ implicit_M,
/*N = */ implicit_N,
/*K = */ implicit_K,
/*a_cols = */ implicit_K * groups,
/*b_cols = */ implicit_K,
/*out_cols = */ implicit_N * groups,
/*a_transposed = */ false,
/*b_transposed = */ true,
/* groups = */ groups,
/* a = */ in_unfolded,
/* b = */ wt_transpose,
/* c = */ out,
/* M = */ implicit_M,
/* N = */ implicit_N,
/* K = */ implicit_K,
/* batch_size_out = */ groups,
/* a_cols = */ implicit_K * groups,
/* b_cols = */ implicit_K,
/* out_cols = */ implicit_N * groups,
/* a_transposed = */ false,
/* b_transposed = */ true,
/* batch_shape = */ {1},
/* batch_strides = */ {0},
/* A_batch_strides = */ size_t(implicit_K),
/* B_batch_strides = */ size_t(implicit_N) * implicit_K,
/* matrix_stride_out = */ size_t(implicit_N),
/*copies = */ copies);
}