Mirror of https://github.com/ml-explore/mlx.git (synced 2025-11-01 00:28:11 +08:00)
Add Step, ELU, SELU, Swish activation functions (#117)

* Add Step, ELU, SELU, Swish activation functions: this commit adds the Step, ELU, SELU, and Swish activation functions
* add to the docs
* review
This commit is contained in:
Nicholas Santavas, committed by GitHub
parent b9226c367c
commit f5df47ec6e
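A quick usage sketch (not part of the commit) of the APIs this change introduces, assuming an MLX build that includes it: nn.step / nn.Step and nn.selu / nn.SELU, alongside the existing nn.silu (Swish).

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.arange(-3, 4)  # [-3, -2, -1, 0, 1, 2, 3]

    # Functional forms added by this commit
    print(nn.step(x, threshold=0.0))  # binary 0/1 output
    print(nn.selu(x))                 # scaled exponential linear unit

    # Module forms, usable inside nn.Sequential
    print(nn.Step(threshold=2.0)(x))
    print(nn.SELU()(x))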
			| @@ -223,6 +223,20 @@ def topk(axis, x): | ||||
|     mx.eval(ys) | ||||
|  | ||||
|  | ||||
| def step_function(x): | ||||
|     y = x | ||||
|     for i in range(100): | ||||
|         y = nn.step(x) | ||||
|     mx.eval(y) | ||||
|  | ||||
|  | ||||
| def selu(x): | ||||
|     y = x | ||||
|     for i in range(100): | ||||
|         y = nn.selu(x) | ||||
|     mx.eval(y) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument("benchmark", help="Choose the benchmark to run") | ||||
| @@ -372,5 +386,11 @@ if __name__ == "__main__": | ||||
|     elif args.benchmark == "topk": | ||||
|         print(bench(topk, axis, x)) | ||||
|  | ||||
|     elif args.benchmark == "step": | ||||
|         print(bench(step_function, x)) | ||||
|  | ||||
|     elif args.benchmark == "selu": | ||||
|         print(bench(selu, x)) | ||||
|  | ||||
|     else: | ||||
|         raise ValueError("Unknown benchmark") | ||||
|   | ||||
| @@ -257,6 +257,14 @@ def topk(axis, x): | ||||
|     sync_if_needed(x) | ||||
|  | ||||
|  | ||||
| @torch.no_grad() | ||||
| def selu(x): | ||||
|     y = x | ||||
|     for i in range(100): | ||||
|         y = torch.nn.functional.selu(y) | ||||
|     sync_if_needed(x) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument("benchmark", help="Choose the benchmark to run") | ||||
|   | ||||
| @@ -205,6 +205,10 @@ if __name__ == "__main__": | ||||
|     compare_filtered("celu --size 32x16x1024 --cpu") | ||||
|     compare_filtered("log_sigmoid --size 32x16x1024") | ||||
|     compare_filtered("log_sigmoid --size 32x16x1024 --cpu") | ||||
|     compare_filtered("step --size 32x16x1024") | ||||
|     compare_filtered("step --size 32x16x1024 --cpu") | ||||
|     compare_filtered("selu --size 32x16x1024") | ||||
|     compare_filtered("selu --size 32x16x1024 --cpu") | ||||
|     compare_filtered("scalar_mul --size 32x16x1024") | ||||
|     compare_filtered("scalar_mul --size 32x16x1024 --cpu") | ||||
|     compare_filtered("cross_entropy --size 256x1024") | ||||
|   | ||||
| @@ -97,7 +97,7 @@ Updating the parameters | ||||
|  | ||||
| MLX modules allow accessing and updating individual parameters. However, most | ||||
| times we need to update large subsets of a module's parameters. This action is | ||||
| performed by :meth:`Module.update`.  | ||||
| performed by :meth:`Module.update`. | ||||
|  | ||||
| Value and grad | ||||
| -------------- | ||||
| @@ -148,6 +148,8 @@ Neural Network Layers | ||||
|    ReLU | ||||
|    GELU | ||||
|    SiLU | ||||
|    Step | ||||
|    SELU | ||||
|    Linear | ||||
|    Conv1d | ||||
|    Conv2d | ||||
| @@ -170,6 +172,8 @@ simple functions. | ||||
|    gelu_fast_approx | ||||
|    relu | ||||
|    silu | ||||
|    step | ||||
|    selu | ||||
|  | ||||
| Loss Functions | ||||
| -------------- | ||||
|   | ||||
| @@ -4,12 +4,14 @@ from mlx.nn.layers.activations import ( | ||||
|     CELU, | ||||
|     ELU, | ||||
|     GELU, | ||||
|     SELU, | ||||
|     LeakyReLU, | ||||
|     LogSigmoid, | ||||
|     ReLU, | ||||
|     ReLU6, | ||||
|     SiLU, | ||||
|     Softplus, | ||||
|     Step, | ||||
|     celu, | ||||
|     elu, | ||||
|     gelu, | ||||
| @@ -19,8 +21,10 @@ from mlx.nn.layers.activations import ( | ||||
|     log_sigmoid, | ||||
|     relu, | ||||
|     relu6, | ||||
|     selu, | ||||
|     silu, | ||||
|     softplus, | ||||
|     step, | ||||
| ) | ||||
| from mlx.nn.layers.base import Module | ||||
| from mlx.nn.layers.containers import Sequential | ||||
|   | ||||
| @@ -74,7 +74,7 @@ def celu(x, alpha=1.0): | ||||
|  | ||||
|  | ||||
| def silu(x): | ||||
|     r"""Applies the Sigmoid Linear Unit. | ||||
|     r"""Applies the Sigmoid Linear Unit. Also known as Swish. | ||||
|  | ||||
|     Applies :math:`x \sigma(x)` element wise, where :math:`\sigma(\cdot)` is | ||||
|     the logistic sigmoid. | ||||
| @@ -143,6 +143,41 @@ class Sigmoid(Module): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| def step(x: mx.array, threshold: float = 0.0): | ||||
|     r"""Applies the Step Activation Function. | ||||
|  | ||||
|     This function implements a binary step activation, where the output is set | ||||
|     to 1 if the input is greater than a specified threshold, and 0 otherwise. | ||||
|  | ||||
|     .. math:: | ||||
|         \text{step}(x) = \begin{cases} | ||||
|         0 & \text{if } x < \text{threshold} \\ | ||||
|         1 & \text{if } x \geq \text{threshold} | ||||
|         \end{cases} | ||||
|  | ||||
|     Args: | ||||
|         threshold: The value to threshold at. | ||||
|     """ | ||||
|  | ||||
|     return mx.where(x > threshold, 1, 0) | ||||
|  | ||||
|  | ||||
| def selu(x): | ||||
|     r"""Applies the Scaled Exponential Linear Unit. | ||||
|  | ||||
|     .. math:: | ||||
|         \text{selu}(x) = \begin{cases} | ||||
|         \lambda x & \text{if } x > 0 \\ | ||||
|         \lambda \alpha (\exp(x) - 1) & \text{if } x \leq 0 | ||||
|         \end{cases} | ||||
|  | ||||
|     where :math:`\lambda = 1.0507` and :math:`\alpha = 1.67326`. | ||||
|  | ||||
|     See also :func:`elu`. | ||||
|     """ | ||||
|     return elu(x, 1.67326) * 1.0507 | ||||
|  | ||||
|  | ||||
| @_make_activation_module(relu) | ||||
| class ReLU(Module): | ||||
|     pass | ||||
| @@ -274,3 +309,32 @@ def tanh(x): | ||||
| @_make_activation_module(tanh) | ||||
| class Tanh(Module): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class Step(Module): | ||||
|     r"""Applies the Step Activation Function. | ||||
|  | ||||
|     This function implements a binary step activation, where the output is set | ||||
|     to 1 if the input is greater than a specified threshold, and 0 otherwise. | ||||
|  | ||||
|     .. math:: | ||||
|         \text{step}(x) = \begin{cases} | ||||
|         0 & \text{if } x < \text{threshold} \\ | ||||
|         1 & \text{if } x \geq \text{threshold} | ||||
|         \end{cases} | ||||
|  | ||||
|     Args: | ||||
|         threshold: The value to threshold at. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, threshold: float = 0.0): | ||||
|         super().__init__() | ||||
|         self.threshold = threshold | ||||
|  | ||||
|     def __call__(self, x: mx.array): | ||||
|         return step(x, self.threshold) | ||||
|  | ||||
|  | ||||
| @_make_activation_module(selu) | ||||
| class SELU(Module): | ||||
|     pass | ||||
|   | ||||
| @@ -449,6 +449,32 @@ class TestNN(mlx_tests.MLXTestCase): | ||||
|         self.assertEqual(y.shape, [3]) | ||||
|         self.assertEqual(y.dtype, mx.float32) | ||||
|  | ||||
|     def test_step_activation(self): | ||||
|         x = mx.arange(-3, 4) | ||||
|         expected = mx.array([0, 0, 0, 0, 0, 1, 1]) | ||||
|         y = nn.Step()(x) | ||||
|         self.assertTrue(mx.array_equal(y, expected)) | ||||
|  | ||||
|         y = nn.Step(2)(x) | ||||
|         expected = mx.array([0, 0, 0, 0, 0, 0, 1]) | ||||
|         self.assertTrue(mx.array_equal(y, expected)) | ||||
|  | ||||
|     def test_selu(self): | ||||
|         x = mx.arange(-3, 4) | ||||
|         expected = mx.array( | ||||
|             [ | ||||
|                 -1.670563817024231, | ||||
|                 -1.5201621055603027, | ||||
|                 -1.1113275289535522, | ||||
|                 0.0, | ||||
|                 1.0506999492645264, | ||||
|                 2.1013998985290527, | ||||
|                 3.152099847793579, | ||||
|             ] | ||||
|         ) | ||||
|         y = nn.SELU()(x) | ||||
|         self.assertTrue(mx.allclose(y, expected)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
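A minimal sanity check, not part of the diff, that the new selu matches its documented definition (1.0507 * elu(x, 1.67326)) and that step thresholds the way the tests above expect; again this assumes an MLX build containing this commit.

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.arange(-3, 4).astype(mx.float32)

    # selu(x) = 1.0507 * elu(x, alpha=1.67326), per the docstring in this diff
    manual = nn.elu(x, 1.67326) * 1.0507
    assert mx.allclose(nn.selu(x), manual)

    # step uses a strict comparison (x > threshold), so x == threshold maps to 0
    assert mx.array_equal(nn.step(x, 2.0), mx.array([0, 0, 0, 0, 0, 0, 1]))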