diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 36915d55d..9f0eb9ff8 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -413,7 +413,7 @@ class TestNN(mlx_tests.MLXTestCase): expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]]) self.assertTrue(np.allclose(bn._running_mean, expected_mean, atol=1e-5)) self.assertTrue(np.allclose(bn._running_var, expected_var, atol=1e-5)) - + x = mx.random.normal((N, L, C, L, C), dtype=mx.float32) with self.assertRaises(ValueError): y = bn(x)