mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 04:56:41 +08:00
updated BN implementation to handle input shape as NLC and NWHC^^
This commit is contained in:
parent
9bf68814a4
commit
a1c06b7d46
@ -413,7 +413,7 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
|
||||
self.assertTrue(np.allclose(bn._running_mean, expected_mean, atol=1e-5))
|
||||
self.assertTrue(np.allclose(bn._running_var, expected_var, atol=1e-5))
|
||||
|
||||
|
||||
x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
|
||||
with self.assertRaises(ValueError):
|
||||
y = bn(x)
|
||||
|
Loading…
Reference in New Issue
Block a user