mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	gelu tanh approx (#989)
* gelu tanh approx
* gelu tanh approx
* replace gelu approx with tanh approach
* fix comments
* fix comment
This commit is contained in:
		| @@ -717,16 +717,22 @@ class TestLayers(mlx_tests.MLXTestCase): | ||||
|         expected = np.array( | ||||
|             [1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383] | ||||
|         ) | ||||
|         # From: jax.nn.gelu(np.array(inputs), approximate=True) | ||||
|         expected_approx = np.array( | ||||
|             [1.0091482, -0.1693441, 0.22918446, 0.60491, 0.4945476] | ||||
|         ) | ||||
|  | ||||
|         out = nn.GELU()(mx.array(inputs)) | ||||
|         self.assertTrue(np.allclose(out, expected)) | ||||
|         out_approx = nn.GELU(approx="precise")(mx.array(inputs)) | ||||
|         self.assertTrue(np.allclose(out_approx, expected_approx)) | ||||
|  | ||||
|         # Crudely check the approximations | ||||
|         x = mx.arange(-6.0, 6.0, 12 / 100) | ||||
|         y = nn.gelu(x) | ||||
|         y_hat1 = nn.gelu_approx(x) | ||||
|         y_hat2 = nn.gelu_fast_approx(x) | ||||
|         self.assertLess(mx.abs(y - y_hat1).max(), 0.0003) | ||||
|         self.assertLess(mx.abs(y - y_hat1).max(), 0.0005) | ||||
|         self.assertLess(mx.abs(y - y_hat2).max(), 0.025) | ||||
|  | ||||
|     def test_sin_pe(self): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Shiyu
					Shiyu