From a1c06b7d46375c7eb61c2600c364c0bd405d7718 Mon Sep 17 00:00:00 2001 From: m0saan Date: Sun, 24 Dec 2023 23:05:18 +0100 Subject: [PATCH] updated BN implementation to handle input shape as NLC and NWHC^^ --- python/tests/test_nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 36915d55d..9f0eb9ff8 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -413,7 +413,7 @@ class TestNN(mlx_tests.MLXTestCase): expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]]) self.assertTrue(np.allclose(bn._running_mean, expected_mean, atol=1e-5)) self.assertTrue(np.allclose(bn._running_var, expected_var, atol=1e-5)) - + x = mx.random.normal((N, L, C, L, C), dtype=mx.float32) with self.assertRaises(ValueError): y = bn(x)