mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Conv grad with groups + bugfix (#1449)
* fix bug in flipped conv with groups, start of grad for groups * fix * fix * fix + test
This commit is contained in:
@@ -72,7 +72,7 @@ void explicit_gemm_conv_ND_gpu(
|
||||
wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
|
||||
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_unfolded, wt_reshaped};
|
||||
std::vector<array> copies = {in_unfolded};
|
||||
return steel_matmul(
|
||||
s,
|
||||
d,
|
||||
@@ -155,22 +155,27 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
|
||||
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_unfolded, wt_view, wt_transpose};
|
||||
return steel_matmul_conv_groups(
|
||||
std::vector<array> copies = {in_unfolded, wt_transpose};
|
||||
return steel_matmul_regular(
|
||||
s,
|
||||
d,
|
||||
/*a = */ in_unfolded,
|
||||
/*b = */ wt_transpose,
|
||||
/*c = */ out,
|
||||
/*M = */ implicit_M,
|
||||
/*N = */ implicit_N,
|
||||
/*K = */ implicit_K,
|
||||
/*a_cols = */ implicit_K * groups,
|
||||
/*b_cols = */ implicit_K,
|
||||
/*out_cols = */ implicit_N * groups,
|
||||
/*a_transposed = */ false,
|
||||
/*b_transposed = */ true,
|
||||
/* groups = */ groups,
|
||||
/* a = */ in_unfolded,
|
||||
/* b = */ wt_transpose,
|
||||
/* c = */ out,
|
||||
/* M = */ implicit_M,
|
||||
/* N = */ implicit_N,
|
||||
/* K = */ implicit_K,
|
||||
/* batch_size_out = */ groups,
|
||||
/* a_cols = */ implicit_K * groups,
|
||||
/* b_cols = */ implicit_K,
|
||||
/* out_cols = */ implicit_N * groups,
|
||||
/* a_transposed = */ false,
|
||||
/* b_transposed = */ true,
|
||||
/* batch_shape = */ {1},
|
||||
/* batch_strides = */ {0},
|
||||
/* A_batch_strides = */ size_t(implicit_K),
|
||||
/* B_batch_strides = */ size_t(implicit_N) * implicit_K,
|
||||
/* matrix_stride_out = */ size_t(implicit_N),
|
||||
/*copies = */ copies);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user