Add groups to Conv1d (#948)

* Add conv1d grouped convs on CPU

* Add GPU support

* Parallelize inside metal kernel

* Cleanup

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* New unfold kernel + remove unused code

* Remove copy and refactor

* Update vjp and reuse steel gemm

* Fixed groups on cpu

* Fix metal validation

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Author: Rifur13
Date: 2024-04-27 09:24:57 -04:00
Committed by: GitHub
Parent: 86f495985b
Commit: c4a471c99d
11 changed files with 633 additions and 55 deletions
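
For context, a minimal usage sketch of the feature this change enables (illustrative, not part of the diff): calling `conv1d` with `groups > 1` through the C++ API. MLX convolutions are channels-last, so the input is `(batch, length, C_in)` and the weight is `(C_out, kernel, C_in / groups)`; the shapes below are assumptions chosen for the example.

```cpp
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto in = ones({1, 8, 4}); // (batch, length, C_in): 4 input channels
  auto wt = ones({6, 3, 2}); // (C_out, kernel, C_in / groups): 2 channels per group
  // groups = 2 splits the channels into two independent halves: the first
  // 3 output channels see only input channels 0-1, the last 3 only 2-3.
  auto out =
      conv1d(in, wt, /* stride */ 1, /* padding */ 0, /* dilation */ 1, /* groups */ 2);
  eval(out); // out has shape (1, 6, 6)
}
```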


@@ -260,6 +260,110 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
///////////////////////////////////////////////////////////////////////////////
// Steel matmul fallback
///////////////////////////////////////////////////////////////////////////////

void steel_matmul_conv_groups(
    const Stream& s,
    metal::Device& d,
    const array& a,
    const array& b,
    array& out,
    int M,
    int N,
    int K,
    int lda,
    int ldb,
    int ldd,
    bool transpose_a,
    bool transpose_b,
    int groups,
    std::vector<array>& copies) {
  using namespace mlx::steel;

  /////////////////////////////////////////////////////////////////////////////
  // Regular kernel dispatch

  // Determine dispatch kernel
  int bm = 32, bn = 32, bk = 16;
  int wm = 2, wn = 2;

  if ((size_t)M * N >= 1ul << 20) {
    if (!transpose_a && transpose_b) {
      bm = 64;
      bn = (out.dtype() == float32) ? 64 : 32;
      bk = (out.dtype() == float32) ? 16 : 32;
    } else {
      bm = 64;
      bn = 64;
    }
  }

  // Prepare kernel name
  std::ostringstream kname;
  kname << "steel_gemm_" << (transpose_a ? 't' : 'n')
        << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
        << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
        << "_wm" << wm << "_wn" << wn << "_MN_"
        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
        << ((K % bk == 0) ? "t" : "n") << "aligned";

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  compute_encoder->setComputePipelineState(kernel);

  // Use problem size to determine threadblock swizzle
  int tn = (N + bn - 1) / bn;
  int tm = (M + bm - 1) / bm;

  // TODO: Explore device-based tuning for swizzle
  int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);

  // Prepare steel matmul params. The batch strides reuse the batch index as
  // the group index: each group is offset by K elements in A, N * K elements
  // in B, and N elements in the output.
  GEMMParams params{
      /* const int M = */ M,
      /* const int N = */ N,
      /* const int K = */ K,
      /* const int lda = */ lda,
      /* const int ldb = */ ldb,
      /* const int ldd = */ ldd,
      /* const int tiles_n = */ tn,
      /* const int tiles_m = */ tm,
      /* const int batch_stride_a = */ K,
      /* const int batch_stride_b = */ N * K,
      /* const int batch_stride_d = */ N,
      /* const int swizzle_log = */ swizzle_log,
      /* const int gemm_k_iterations_aligned = */ (K / bk),
      /* const int batch_ndim = */ 1};

  // Prepare launch grid params
  int tile = 1 << swizzle_log;
  tm = (tm + tile - 1) / tile;
  tn = tn * tile;

  // One z slice of the grid per group
  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(tn, tm, groups);

  std::vector<int> batch_shape = {1};
  std::vector<size_t> batch_strides = {0};

  // Launch kernel
  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_output_array(out, 3);

  compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
  compute_encoder->setBytes(
      batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
  compute_encoder->setBytes(
      batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);

  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);

  // Clear copies
  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
  return;
}
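
The `GEMMParams` above reuse the kernel's batch machinery as a group index: grid dimension z enumerates groups, and the batch strides place each group's sub-problem at a fixed element offset. A small illustrative helper (an assumption about the caller's layout, not code from this commit) spelling out those offsets:

```cpp
#include <cstddef>

// Hypothetical helper mirroring batch_stride_a / batch_stride_b /
// batch_stride_d above. Assumes A is [M, groups * K], B is `groups`
// stacked filter blocks of N * K elements each, and the output is
// [M, groups * N], all row-major.
struct GroupOffsets {
  std::size_t a, b, d;
};

inline GroupOffsets group_offsets(int g, int K, int N) {
  return {
      /* a = */ static_cast<std::size_t>(g) * K, // skip K columns per group
      /* b = */ static_cast<std::size_t>(g) * N * K, // skip one filter block
      /* d = */ static_cast<std::size_t>(g) * N, // skip N output columns
  };
}
```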
void steel_matmul(
const Stream& s,
metal::Device& d,