Fix BN stats to not expand shape (#409)

Author: Awni Hannun

* fix BN stats to not expand shape
* nit
@@ -333,8 +333,8 @@ class BatchNorm(Module):
         """
         reduction_axes = tuple(range(0, x.ndim - 1))
 
-        mean = mx.mean(x, axis=reduction_axes, keepdims=True)
-        var = mx.var(x, axis=reduction_axes, keepdims=True)
+        mean = mx.mean(x, axis=reduction_axes)
+        var = mx.var(x, axis=reduction_axes)
 
         return mean, var
 
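For context, a minimal sketch (not part of the commit) of why keepdims=True expanded the shape: with keepdims, the reduced axes are retained as singleton dimensions, so the batch statistics for channels-last input come out as (1, ..., 1, C) instead of (C,), and the running buffers updated from them broadcast to that expanded shape. Assuming a 4D NHWC input with 16 features:

    import mlx.core as mx

    x = mx.random.normal((4, 8, 8, 16))  # NHWC batch, 16 features
    reduction_axes = tuple(range(0, x.ndim - 1))  # (0, 1, 2)

    # Before the fix: reduced axes kept as singleton dims.
    mean_keep = mx.mean(x, axis=reduction_axes, keepdims=True)
    print(mean_keep.shape)  # (1, 1, 1, 16)

    # After the fix: one statistic per feature.
    mean_flat = mx.mean(x, axis=reduction_axes)
    print(mean_flat.shape)  # (16,)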
@@ -517,35 +517,36 @@ class TestLayers(mlx_tests.MLXTestCase):
         batch_norm = nn.BatchNorm(num_features)
 
         batch_norm.train()
-        running_mean = np.array(batch_norm.running_mean)
-        running_var = np.array(batch_norm.running_var)
+        running_mean = batch_norm.running_mean
+        running_var = batch_norm.running_var
 
         data = mx.random.normal((batch_size, num_features))
 
         normalized_data = batch_norm(data)
-        np_data = np.array(data)
-        means = np.mean(np_data, axis=0)
-        variances = np.var(np_data, axis=0)
+        means = mx.mean(data, axis=0)
+        variances = mx.var(data, axis=0)
         running_mean = (1 - momentum) * running_mean + momentum * means
         running_var = (1 - momentum) * running_var + momentum * variances
-        self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
-        self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
+        self.assertTrue(mx.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(batch_norm.running_var, running_var, atol=1e-5))
 
         batch_norm = nn.BatchNorm(num_features)
 
         batch_norm.train()
-        running_mean = np.array(batch_norm.running_mean)
-        running_var = np.array(batch_norm.running_var)
+        running_mean = batch_norm.running_mean
+        running_var = batch_norm.running_var
         data = mx.random.normal((batch_size, h, w, num_features))
 
         normalized_data = batch_norm(data)
-        np_data = np.array(data)
-        means = np.mean(np_data, axis=(0, 1, 2))
-        variances = np.var(np_data, axis=(0, 1, 2))
+        means = mx.mean(data, axis=(0, 1, 2))
+        variances = mx.var(data, axis=(0, 1, 2))
         running_mean = (1 - momentum) * running_mean + momentum * means
         running_var = (1 - momentum) * running_var + momentum * variances
-        self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
-        self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
+        self.assertTrue(mx.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(batch_norm.running_var, running_var, atol=1e-5))
+
+        self.assertEqual(batch_norm.running_mean.shape, running_mean.shape)
+        self.assertEqual(batch_norm.running_var.shape, running_var.shape)
 
     def test_conv1d(self):
         N = 5
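As a rough usage sketch of what the test exercises (the feature and batch sizes here are illustrative, not taken from the test): in training mode a forward pass updates the running statistics via the exponential moving average running = (1 - momentum) * running + momentum * batch_stat, and after the fix those buffers keep the per-feature shape (num_features,) rather than broadcasting to an expanded shape:

    import mlx.core as mx
    import mlx.nn as nn

    bn = nn.BatchNorm(16)
    bn.train()

    data = mx.random.normal((32, 16))
    bn(data)  # updates bn.running_mean / bn.running_var in training mode

    # The buffers stay per-feature rather than expanding to (1, 16).
    print(bn.running_mean.shape)  # (16,)
    print(bn.running_var.shape)   # (16,)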