@@ -34,7 +34,7 @@ void explicit_gemm_conv_ND_gpu(
   int implicit_K = wt.size() / conv_params.O;
   int implicit_N = conv_params.O;
   // Prepare unfolding array
-  std::vector<int> unfolded_shape{implicit_M, implicit_K};
+  Shape unfolded_shape{implicit_M, implicit_K};
   array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
 
   in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
@@ -113,7 +113,7 @@ void explicit_gemm_conv_group_ND_gpu(
   }
 
   // Prepare unfolding array
-  std::vector<int> unfolded_shape{implicit_M, implicit_K * groups};
+  Shape unfolded_shape{implicit_M, implicit_K * groups};
   array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
   in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
 
@@ -192,12 +192,12 @@ void conv_1D_gpu(
     bool flip) {
   // Make conv params
   MLXConvParams<1> conv_params{
-      /* const int N = */ in.shape(0),
-      /* const int C = */ in.shape(2),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in.shape(1)},
-      /* const int wS[NDIM] = */ {wt.shape(1)},
-      /* const int oS[NDIM] = */ {out.shape(1)},
+      /* const int N = */ static_cast<int>(in.shape(0)),
+      /* const int C = */ static_cast<int>(in.shape(2)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
+      /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
+      /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
       /* const int str[NDIM] = */ {wt_strides[0]},
       /* const int pad[NDIM] = */ {padding[0]},
       /* const int kdil[NDIM] = */ {wt_dilation[0]},
@@ -541,7 +541,7 @@ void winograd_conv_2D_gpu(
     array out,
     const MLXConvParams<2>& conv_params,
     std::vector<array>& copies_w) {
-  std::vector<int> padded_shape = {
+  Shape padded_shape = {
       conv_params.N,
       conv_params.iS[0] + 2 * conv_params.pad[0],
       conv_params.iS[1] + 2 * conv_params.pad[1],
@@ -550,7 +550,7 @@ void winograd_conv_2D_gpu(
   padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
   padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
 
-  array in_padded(padded_shape, in.dtype(), nullptr, {});
+  array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
 
   // Fill with zeros
   array zero_arr = array(0, in.dtype());
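
Aside: this hunk (and the filter/input/output transform hunks further down) also starts passing the shape with std::move when constructing the temporary array. A minimal sketch of why that helps, assuming, as the call sites suggest, that the constructor takes the shape by value; the types and names below are illustrative stand-ins, not MLX's:

// Illustrative only: a value-taking constructor lets the caller choose between
// copying and moving. Moving a local shape that is not used afterwards avoids
// one vector copy per temporary array.
#include <cstdint>
#include <utility>
#include <vector>

using Shape = std::vector<int64_t>; // assumed element type

struct Buffer {
  Shape shape;
  explicit Buffer(Shape s) : shape(std::move(s)) {} // takes shape by value
};

Buffer make_padded(const Shape& base) {
  Shape padded = base;                // local working copy
  padded[1] += 2;                     // e.g. add padding, as the diff does
  return Buffer(std::move(padded));   // move: `padded` is dead after this line
}
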
@@ -575,12 +575,16 @@ void winograd_conv_2D_gpu(
   copies_w.push_back(in_padded);
 
   MLXConvParams<2> conv_params_updated{
-      /* const int N = */ in_padded.shape(0),
-      /* const int C = */ in_padded.shape(3),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in_padded.shape(1), in_padded.shape(2)},
-      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
-      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
+      /* const int N = */ static_cast<int>(in_padded.shape(0)),
+      /* const int C = */ static_cast<int>(in_padded.shape(3)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */
+      {static_cast<int>(in_padded.shape(1)),
+       static_cast<int>(in_padded.shape(2))},
+      /* const int wS[NDIM] = */
+      {static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
+      /* const int oS[NDIM] = */
+      {static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
       /* const int str[NDIM] = */ {1, 1},
       /* const int pad[NDIM] = */ {0, 0},
       /* const int kdil[NDIM] = */ {1, 1},
@@ -607,8 +611,8 @@ void winograd_conv_2D_gpu(
   int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
 
   // Do filter transform
-  std::vector<int> filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
-  array filt_wg(filt_wg_shape, wt.dtype(), nullptr, {});
+  Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
+  array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
   filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
   copies_w.push_back(filt_wg);
   {
@@ -634,8 +638,8 @@ void winograd_conv_2D_gpu(
   }
 
   // Do input transform
-  std::vector<int> inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
-  array inp_wg(inp_wg_shape, in.dtype(), nullptr, {});
+  Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
+  array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
   inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
   copies_w.push_back(inp_wg);
   {
@@ -661,8 +665,8 @@ void winograd_conv_2D_gpu(
   }
 
   // Do batched gemm
-  std::vector<int> out_wg_shape = {8 * 8, N_tiles, conv_params.O};
-  array out_wg(out_wg_shape, in.dtype(), nullptr, {});
+  Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
+  array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
   out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
   copies_w.push_back(out_wg);
   {
@@ -723,12 +727,15 @@ void conv_2D_gpu(
     std::vector<array>& copies) {
   // Make conv params
   MLXConvParams<2> conv_params{
-      /* const int N = */ in.shape(0),
-      /* const int C = */ in.shape(3),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in.shape(1), in.shape(2)},
-      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
-      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
+      /* const int N = */ static_cast<int>(in.shape(0)),
+      /* const int C = */ static_cast<int>(in.shape(3)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */
+      {static_cast<int>(in.shape(1)), static_cast<int>(in.shape(2))},
+      /* const int wS[NDIM] = */
+      {static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
+      /* const int oS[NDIM] = */
+      {static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
       /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
       /* const int pad[NDIM] = */ {padding[0], padding[1]},
       /* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
@@ -800,12 +807,21 @@ void conv_3D_gpu(
     std::vector<array>& copies) {
   // Make conv params
   MLXConvParams<3> conv_params{
-      /* const int N = */ in.shape(0),
-      /* const int C = */ in.shape(4),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
-      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
-      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
+      /* const int N = */ static_cast<int>(in.shape(0)),
+      /* const int C = */ static_cast<int>(in.shape(4)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */
+      {static_cast<int>(in.shape(1)),
+       static_cast<int>(in.shape(2)),
+       static_cast<int>(in.shape(3))},
+      /* const int wS[NDIM] = */
+      {static_cast<int>(wt.shape(1)),
+       static_cast<int>(wt.shape(2)),
+       static_cast<int>(wt.shape(3))},
+      /* const int oS[NDIM] = */
+      {static_cast<int>(out.shape(1)),
+       static_cast<int>(out.shape(2)),
+       static_cast<int>(out.shape(3))},
       /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
       /* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
       /* const int kdil[NDIM] = */
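
The recurring pattern across these hunks is: build temporary shapes with the Shape alias instead of std::vector<int>, and narrow shape elements explicitly with static_cast<int> when filling the int fields of MLXConvParams. A minimal sketch of why the casts become necessary, assuming Shape holds a wider integer type such as int64_t; the alias and struct below are stand-ins for illustration, not the MLX definitions:

#include <cstdint>
#include <vector>

using ShapeElem = int64_t;            // assumption: wider than int
using Shape = std::vector<ShapeElem>; // stand-in for the Shape alias in the diff

struct Params1D {
  int N, C, O;
  int iS[1];
};

// Expects an (N, L, C)-shaped input, mirroring the conv_1D_gpu layout.
Params1D make_params(const Shape& in_shape, const Shape& wt_shape) {
  return Params1D{
      // Braced (list-)initialization forbids narrowing conversions, so an
      // int64_t element cannot initialize an int field without an explicit cast.
      static_cast<int>(in_shape[0]),   // N
      static_cast<int>(in_shape[2]),   // C
      static_cast<int>(wt_shape[0]),   // O
      {static_cast<int>(in_shape[1])}, // iS[0]
  };
}
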