mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
Fix overflow in large filter small channels (#2520)
This commit is contained in:
parent
512281781c
commit
25c1e03205
@ -83,7 +83,7 @@ struct Conv2DInputBlockLoaderSmallChannels {
|
||||
const constant MLXConvParams<2>* params;
|
||||
const constant ImplicitGemmConv2DParams* gemm_params;
|
||||
|
||||
short weight_hw;
|
||||
int weight_hw;
|
||||
|
||||
const device T* src[n_rows];
|
||||
|
||||
|
@ -1186,6 +1186,13 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
y_hat = mx.conv2d(x, w)
|
||||
self.assertTrue(mx.allclose(y, y_hat))
|
||||
|
||||
def test_conv2d_large_filter_small_channels(self):
|
||||
x = mx.random.normal(shape=(1, 181, 181, 1))
|
||||
w = mx.random.normal(shape=(1, 182, 182, 1))
|
||||
y = mx.conv2d(x, w, (1, 1), (1, 1), stream=mx.cpu)
|
||||
y_hat = mx.conv2d(x, w, (1, 1), (1, 1))
|
||||
self.assertTrue(mx.allclose(y, y_hat, rtol=1e-3, atol=1e-3))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
Loading…
Reference in New Issue
Block a user