mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00

fix conv2d bug + faster conv 1d (#2195)

* fix conv2d bug + faster conv 1d
* revert sort + flaky test

parent 0654543dcc
commit 8576e6fe36
@@ -1,5 +1,4 @@
 // Copyright © 2023-2024 Apple Inc.
-
 #include <algorithm>
 #include <cassert>
 #include <numeric>
@@ -178,83 +177,6 @@ void explicit_gemm_conv_group_ND_gpu(
       /*copies = */ copies);
 }

-void conv_1D_gpu(
-    const Stream& s,
-    metal::Device& d,
-    const array& in,
-    const array& wt,
-    array out,
-    const std::vector<int>& padding,
-    const std::vector<int>& wt_strides,
-    const std::vector<int>& wt_dilation,
-    const std::vector<int>& in_dilation,
-    int groups,
-    bool flip) {
-  // Make conv params
-  MLXConvParams<1> conv_params{
-      /* const int N = */ static_cast<int>(in.shape(0)),
-      /* const int C = */ static_cast<int>(in.shape(2)),
-      /* const int O = */ static_cast<int>(wt.shape(0)),
-      /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
-      /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
-      /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
-      /* const int str[NDIM] = */ {wt_strides[0]},
-      /* const int pad[NDIM] = */ {padding[0]},
-      /* const int kdil[NDIM] = */ {wt_dilation[0]},
-      /* const int idil[NDIM] = */ {in_dilation[0]},
-      /* const size_t in_strides[NDIM + 2] = */
-      {in.strides()[0], in.strides()[1], in.strides()[2]},
-      /* const size_t wt_strides[NDIM + 2] = */
-      {wt.strides()[0], wt.strides()[1], wt.strides()[2]},
-      /* const size_t out_strides[NDIM + 2] = */
-      {out.strides()[0], out.strides()[1], out.strides()[2]},
-      /* const int groups = */ groups,
-      /* const bool flip = */ flip};
-
-  // Direct to explicit gemm conv
-  if (groups > 1) {
-    return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
-  } else {
-    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
-  }
-}
-
-void slow_conv_2D_gpu(
-    const Stream& s,
-    metal::Device& d,
-    const array& in,
-    const array& wt,
-    array out,
-    const MLXConvParams<2>& conv_params) {
-  int bm = 16, bn = 8;
-  int tm = 4, tn = 4;
-
-  std::ostringstream kname;
-  kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn
-        << "_tm" << tm << "_tn" << tn;
-
-  // Encode and dispatch kernel
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
-  compute_encoder.set_compute_pipeline_state(kernel);
-
-  size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
-
-  size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm);
-  size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn);
-  size_t grid_dim_z = conv_params.N;
-
-  MTL::Size group_dims = MTL::Size(bm, bn, 1);
-  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
-
-  compute_encoder.set_input_array(in, 0);
-  compute_encoder.set_input_array(wt, 1);
-  compute_encoder.set_output_array(out, 2);
-
-  compute_encoder.set_bytes(conv_params, 3);
-  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
-}
-
 void implicit_gemm_conv_2D_gpu(
     const Stream& s,
     metal::Device& d,
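Reviewer note: the 1D entry point deleted here is rebuilt further down on top of the new 2D dispatcher, and slow_conv_2D_gpu leaves with no remaining call site in this diff, so it reads as dead-code removal. For context, the explicit-GEMM family these paths fall back to is im2col followed by one large matmul; a minimal Python sketch of that idea (im2col_conv1d is an illustrative name, not an MLX API; stride 1, no padding, groups == 1):

    import mlx.core as mx

    def im2col_conv1d(x, w):
        # x: (N, L, C) input, w: (O, K, C) weight, as used throughout this diff.
        N, L, C = x.shape
        O, K, _ = w.shape
        # Gather every length-K window of x: (N, L - K + 1, K, C).
        cols = mx.stack([x[:, i : i + K, :] for i in range(L - K + 1)], axis=1)
        # One large matmul against the flattened weights.
        return cols.reshape(N, L - K + 1, K * C) @ mx.transpose(w.reshape(O, K * C))

    x = mx.random.uniform(shape=(2, 10, 16))
    w = mx.random.normal(shape=(16, 3, 16))
    assert mx.allclose(im2col_conv1d(x, w), mx.conv1d(x, w), atol=1e-4)

The Metal explicit_gemm_conv kernels follow roughly the same dataflow: materialize the unfolded matrix with a copy, then run a single GEMM.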
@@ -771,6 +693,141 @@ void depthwise_conv_2D_gpu(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }

+void dispatch_conv_2D_gpu(
+    const Stream& s,
+    metal::Device& d,
+    const array& in,
+    const array& wt,
+    array out,
+    const MLXConvParams<2>& conv_params,
+    std::vector<array>& copies) {
+  bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
+  bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
+  bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
+
+  if (is_idil_one && conv_params.groups > 1) {
+    const int C_per_group = conv_params.C / conv_params.groups;
+    const int O_per_group = conv_params.O / conv_params.groups;
+
+    if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
+        conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
+        conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
+        conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
+        conv_params.wt_strides[1] == conv_params.wS[1] &&
+        conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
+      return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
+    }
+
+    if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
+        (O_per_group <= 16 || O_per_group % 16 == 0)) {
+      return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
+    } else {
+      return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
+    }
+  }
+
+  // Direct to winograd conv
+  bool inp_large =
+      (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
+  bool channels_large = (conv_params.C + conv_params.O) >= 256;
+  if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
+      conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
+      conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
+      channels_large) {
+    return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
+  }
+
+  // Direct to implicit gemm conv
+  if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
+      (conv_params.O <= 16 || conv_params.O % 16 == 0)) {
+    return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
+  }
+
+  else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
+    return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
+  }
+
+  // Direct to explicit gemm conv
+  else {
+    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
+  }
+}
+
+void conv_1D_gpu(
+    const Stream& s,
+    metal::Device& d,
+    const array& in,
+    const array& wt,
+    array out,
+    const std::vector<int>& padding,
+    const std::vector<int>& wt_strides,
+    const std::vector<int>& wt_dilation,
+    const std::vector<int>& in_dilation,
+    int groups,
+    bool flip,
+    std::vector<array>& copies) {
+  bool is_idil_one = in_dilation[0] == 1;
+  int C = in.shape(2);
+  int O = wt.shape(0);
+  const int C_per_group = in.shape(2) / groups;
+  const int O_per_group = wt.shape(0) / groups;
+
+  // Direct to implicit gemm conv
+  if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
+      (O_per_group <= 16 || O_per_group % 16 == 0)) {
+    MLXConvParams<2> conv_params{
+        /* const int N = */ static_cast<int>(in.shape(0)),
+        /* const int C = */ C,
+        /* const int O = */ O,
+        /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1)), 1},
+        /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1)), 1},
+        /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1)), 1},
+        /* const int str[NDIM] = */ {wt_strides[0], 1},
+        /* const int pad[NDIM] = */ {padding[0], 0},
+        /* const int kdil[NDIM] = */ {wt_dilation[0], 1},
+        /* const int idil[NDIM] = */ {in_dilation[0], 1},
+        /* const size_t in_strides[NDIM + 2] = */
+        {in.strides()[0], in.strides()[1], 0, in.strides()[2]},
+        /* const size_t wt_strides[NDIM + 2] = */
+        {wt.strides()[0], wt.strides()[1], 0, wt.strides()[2]},
+        /* const size_t out_strides[NDIM + 2] = */
+        {out.strides()[0], out.strides()[1], 0, out.strides()[2]},
+        /* const int groups = */ groups,
+        /* const bool flip = */ flip};
+
+    dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
+    return;
+  }
+
+  // Make conv params
+  MLXConvParams<1> conv_params{
+      /* const int N = */ static_cast<int>(in.shape(0)),
+      /* const int C = */ static_cast<int>(in.shape(2)),
+      /* const int O = */ static_cast<int>(wt.shape(0)),
+      /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
+      /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
+      /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
+      /* const int str[NDIM] = */ {wt_strides[0]},
+      /* const int pad[NDIM] = */ {padding[0]},
+      /* const int kdil[NDIM] = */ {wt_dilation[0]},
+      /* const int idil[NDIM] = */ {in_dilation[0]},
+      /* const size_t in_strides[NDIM + 2] = */
+      {in.strides()[0], in.strides()[1], in.strides()[2]},
+      /* const size_t wt_strides[NDIM + 2] = */
+      {wt.strides()[0], wt.strides()[1], wt.strides()[2]},
+      /* const size_t out_strides[NDIM + 2] = */
+      {out.strides()[0], out.strides()[1], out.strides()[2]},
+      /* const int groups = */ groups,
+      /* const bool flip = */ flip};
+
+  // Direct to explicit gemm conv
+  if (groups > 1) {
+    return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
+  } else {
+    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
+  }
+}
+
 void conv_2D_gpu(
     const Stream& s,
     metal::Device& d,
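Reviewer note: the new router is the heart of this change, so a plain transcription helps audit the thresholds. A Python sketch of the same decision ladder (values copied verbatim from dispatch_conv_2D_gpu above; the P container and function name are hypothetical stand-ins, not MLX APIs):

    from dataclasses import dataclass

    @dataclass
    class P:  # stand-in for MLXConvParams<2>
        N: int; C: int; O: int
        iS: tuple; wS: tuple; oS: tuple
        str: tuple; kdil: tuple; idil: tuple
        wt_strides: tuple; groups: int; flip: bool

    def pick_conv2d_kernel(p: P) -> str:
        stride1 = p.str == (1, 1)
        kdil1 = p.kdil == (1, 1)
        idil1 = p.idil == (1, 1)
        if idil1 and p.groups > 1:
            Cg, Og = p.C // p.groups, p.O // p.groups
            if (Cg == 1 and Og == 1 and kdil1 and max(p.wS) <= 7
                    and max(p.str) <= 2 and p.oS[0] % 8 == 0 and p.oS[1] % 8 == 0
                    and p.wt_strides[1] == p.wS[1]
                    and p.C % 16 == 0 and p.C == p.O):
                return "depthwise_conv_2D_gpu"
            if (Cg <= 4 or Cg % 16 == 0) and (Og <= 16 or Og % 16 == 0):
                return "implicit_gemm_conv_2D_gpu"
            return "explicit_gemm_conv_group_ND_gpu"
        inp_large = p.N * p.iS[0] * p.iS[1] >= 1 << 12
        channels_large = p.C + p.O >= 256
        if (not p.flip and stride1 and kdil1 and idil1 and p.wS == (3, 3)
                and p.C % 32 == 0 and p.O % 32 == 0 and inp_large and channels_large):
            return "winograd_conv_2D_gpu"
        if idil1 and (p.C <= 4 or p.C % 16 == 0) and (p.O <= 16 or p.O % 16 == 0):
            return "implicit_gemm_conv_2D_gpu"
        if p.C % 16 == 0 and p.O % 16 == 0:
            return "implicit_gemm_conv_2D_general_gpu"
        return "explicit_gemm_conv_ND_gpu"

    # A 3x3, stride-1, 128-channel conv on a 64x64 image passes the winograd gate.
    print(pick_conv2d_kernel(P(N=1, C=128, O=128, iS=(64, 64), wS=(3, 3),
                               oS=(64, 64), str=(1, 1), kdil=(1, 1), idil=(1, 1),
                               wt_strides=(1152, 384, 128, 1), groups=1,
                               flip=False)))  # -> winograd_conv_2D_gpu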
@@ -808,57 +865,7 @@ void conv_2D_gpu(
       /* const int groups = */ groups,
       /* const bool flip = */ flip,
   };
-
-  bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
-  bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
-  bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
-
-  if (is_idil_one && groups > 1) {
-    const int C_per_group = conv_params.C / groups;
-    const int O_per_group = conv_params.O / groups;
-
-    if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
-        conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
-        conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
-        conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
-        conv_params.wt_strides[1] == conv_params.wS[1] &&
-        conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
-      return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
-    }
-
-    if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
-        (O_per_group <= 16 || O_per_group % 16 == 0)) {
-      return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
-    } else {
-      return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
-    }
-  }
-
-  // Direct to winograd conv
-  bool inp_large =
-      (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
-  bool channels_large = (conv_params.C + conv_params.O) >= 256;
-  if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
-      conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
-      conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
-      channels_large) {
-    return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
-  }
-
-  // Direct to implicit gemm conv
-  if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
-      (conv_params.O <= 16 || conv_params.O % 16 == 0)) {
-    return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
-  }
-
-  else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
-    return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
-  }
-
-  // Direct to explicit gemm conv
-  else {
-    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
-  }
+  dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
 }

 void conv_3D_gpu(
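Reviewer note: conv_2D_gpu now only assembles MLXConvParams<2> and defers to dispatch_conv_2D_gpu, so the 1D and 2D entry points can no longer drift apart in their routing heuristics. A quick grouped-conv cross-check from Python (hypothetical snippet, not part of this PR; the shapes give C_per_group % 16 == 0 and O_per_group <= 16, which lands in the grouped implicit-GEMM branch):

    import mlx.core as mx

    groups = 2
    x = mx.random.uniform(shape=(1, 8, 8, 32))  # NHWC input
    w = mx.random.normal(shape=(32, 3, 3, 16))  # (O, kH, kW, C // groups)
    out = mx.conv2d(x, w, padding=1, groups=groups)

    # Reference: run each group as an ungrouped conv and concatenate.
    xs = mx.split(x, groups, axis=-1)
    ws = mx.split(w, groups, axis=0)
    ref = mx.concatenate(
        [mx.conv2d(a, b, padding=1) for a, b in zip(xs, ws)], axis=-1)
    assert mx.allclose(out, ref, atol=1e-4)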
@@ -988,7 +995,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
         kernel_dilation_,
         input_dilation_,
         groups_,
-        flip_);
+        flip_,
+        copies);
   }
   // Throw error
   else {
@@ -381,6 +381,7 @@ struct Conv2DWeightBlockLoader {
   const constant MLXConvParams<2>* params;

   int weight_hw;
+  int weight_step;

   const int read_n;
   const bool do_read;
@@ -402,6 +403,7 @@ struct Conv2DWeightBlockLoader {
         src(src_ + bi * src_ld + bj),
         params(params_),
         weight_hw(0),
+        weight_step(params->C / params->groups),
         read_n(offsets.y + bi),
         do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}

@@ -435,13 +437,13 @@ struct Conv2DWeightBlockLoader {
   /* Iteration helper */
   METAL_FUNC void next() {
     if (++weight_hw < (params->wS[1] * params->wS[0])) {
-      src += params->wt_strides[2];
+      src += weight_step;
       return;
     }

     weight_hw = 0;

-    src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2];
+    src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step;
   }
 };

@@ -272,7 +272,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
       return;
     }

-    const device T* curr_src = src + weight_hw * params->wt_strides[2];
+    const device T* curr_src = src + weight_hw * (params->C / params->groups);

     if (BN != 8 || do_read) {
       STEEL_PRAGMA_UNROLL
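Reviewer note: this mirrors the weight_step change in Conv2DWeightBlockLoader above. Both loaders used to step between kernel taps by params->wt_strides[2], which is exactly what the new 1D-as-2D lowering breaks: conv_1D_gpu passes a dummy spatial axis whose weight stride is 0. Computing the step as params->C / params->groups instead matches a contiguous (O, wH, wW, C / groups) weight layout regardless of that dummy axis. A grouped 1D sanity check in the spirit of the new test (hypothetical snippet, not part of this PR):

    import mlx.core as mx

    groups = 2
    x = mx.random.uniform(shape=(2, 10, 32))  # (N, L, C)
    w = mx.random.normal(shape=(8, 3, 16))    # (O, K, C // groups)
    out = mx.conv1d(x, w, padding=1, groups=groups)

    # Reference: per-group ungrouped convs, concatenated along channels.
    xs = mx.split(x, groups, axis=-1)
    ws = mx.split(w, groups, axis=0)
    ref = mx.concatenate(
        [mx.conv1d(a, b, padding=1) for a, b in zip(xs, ws)], axis=-1)
    assert mx.allclose(out, ref, atol=1e-4)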
@@ -3584,21 +3584,21 @@ Shape conv_out_shape(

   if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
     std::ostringstream msg;
-    msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << "for "
+    msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for "
         << spatial_dims << "D convolution.";
     throw std::invalid_argument(msg.str());
   }

   if (kernel_dilation.size() != spatial_dims) {
     std::ostringstream msg;
-    msg << "[conv] Invalid kernel dilation " << kernel_dilation << "for "
+    msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for "
         << spatial_dims << "D convolution.";
     throw std::invalid_argument(msg.str());
   }

   if (input_dilation.size() != spatial_dims) {
     std::ostringstream msg;
-    msg << "[conv] Invalid input dilation " << input_dilation << "for "
+    msg << "[conv] Invalid input dilation " << input_dilation << " for "
         << spatial_dims << "D convolution.";
     throw std::invalid_argument(msg.str());
   }
@@ -1152,6 +1152,27 @@ class TestConv(mlx_tests.MLXTestCase):
         )
         self.assertEqual(grads.shape, k_shape)

+    def test_1d_conv_with_2d(self):
+        x = mx.random.uniform(shape=(2, 10, 16))
+        y = mx.random.normal(shape=(16, 3, 16))
+
+        out = mx.conv1d(x, y, padding=1)
+        out_2d = mx.conv2d(
+            mx.expand_dims(x, axis=2), mx.expand_dims(y, axis=2), padding=(1, 0)
+        )
+
+        self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))
+
+        x = mx.random.uniform(shape=(2, 10, 4))
+        y = mx.random.normal(shape=(4, 3, 4))
+
+        out = mx.conv1d(x, y, padding=1)
+        out_2d = mx.conv2d(
+            mx.expand_dims(x, axis=2), mx.expand_dims(y, axis=2), padding=(1, 0)
+        )
+
+        self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))
+

 if __name__ == "__main__":
     unittest.main()
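Reviewer note: the test leans on the dummy axis passing through the 2D convolution unchanged. With the usual output-size rule, oS = (iS + 2 * pad - kdil * (wS - 1) - 1) / str + 1, the inserted axis (iS = 1, wS = 1, pad = 0, str = 1) gives oS = 1, so squeeze(2) recovers exactly the conv1d result:

    def out_size(i, w, p, s, d):
        # Standard convolution output-length formula (floor division).
        return (i + 2 * p - d * (w - 1) - 1) // s + 1

    assert out_size(10, 3, 1, 1, 1) == 10  # the length-10 axis in the test
    assert out_size(1, 1, 0, 1, 1) == 1    # the inserted dummy axis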
@@ -634,6 +634,7 @@ class TestVmap(mlx_tests.MLXTestCase):
         self.assertEqual(fy.shape, (4, 5, 6, 7))

     def test_leaks(self):
+        mx.synchronize()
         if mx.metal.is_available():
             mem_pre = mx.get_active_memory()
         else:
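Reviewer note: this looks like the flaky-test half of the commit message. Sampling mx.get_active_memory() while earlier tests still have work in flight can inflate the baseline, so the leak test now drains the queue first; a sketch of the pattern:

    import mlx.core as mx

    mx.synchronize()                  # wait for pending work to finish
    mem_pre = mx.get_active_memory()  # stable baseline for the leak check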