updated BN implementation to handle input shape as NLC and NWHC^^

This commit is contained in:
m0saan 2023-12-24 23:05:18 +01:00
parent 9bf68814a4
commit a1c06b7d46

View File

@ -413,7 +413,7 @@ class TestNN(mlx_tests.MLXTestCase):
expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
self.assertTrue(np.allclose(bn._running_mean, expected_mean, atol=1e-5))
self.assertTrue(np.allclose(bn._running_var, expected_var, atol=1e-5))
x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
with self.assertRaises(ValueError):
y = bn(x)