Added activation functions: leaky_relu, relu6, softplus, elu, celu, logsigmoid (#108)

* Added leaky_relu, relu6, softplus, elu, celu, and logsigmoid
* Minor fixes for docstrings and benchmark imports
* Fixed the elu implementation and added tests
* Added tests for the optional parameter; changed the leaky_relu parameter to match the PyTorch documentation
Author: Jason
Date: 2023-12-10 19:31:38 -05:00
Committed by: GitHub
Parent: 71d1fff90a
Commit: b0cd092b7f
6 changed files with 344 additions and 0 deletions
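Only the benchmark script's diff is reproduced below; the mlx.nn implementations live in the other changed files. As a rough reference, here is a minimal sketch of what these activations compute, assuming the standard PyTorch-style formulas the commit message refers to (the default negative_slope and alpha values below are assumptions for illustration, not the committed code):

import mlx.core as mx

def leaky_relu(x, negative_slope=0.01):
    # equals x for x > 0 and negative_slope * x otherwise
    return mx.maximum(negative_slope * x, x)

def relu6(x):
    # relu clipped at 6
    return mx.minimum(mx.maximum(x, 0), 6)

def softplus(x):
    # smooth approximation of relu; a stable implementation would use a
    # log-sum-exp formulation instead of the naive one below
    return mx.log(1 + mx.exp(x))

def elu(x, alpha=1.0):
    return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))

def celu(x, alpha=1.0):
    return mx.maximum(x, 0.0) + mx.minimum(0.0, alpha * (mx.exp(x / alpha) - 1))

def log_sigmoid(x):
    # log(sigmoid(x)) = -softplus(-x)
    return -softplus(-x)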


@@ -6,6 +6,7 @@ import os
import time
import mlx.core as mx
import mlx.nn as nn


def int_or_list(x):
@@ -99,6 +100,48 @@ def relu(x):
    mx.eval(y)


def leaky_relu(x):
    # Each benchmark applies the op 100 times to build a lazy graph,
    # then mx.eval forces the whole graph to actually run.
    y = x
    for i in range(100):
        y = nn.leaky_relu(y)
    mx.eval(y)


def elu(x):
    y = x
    for i in range(100):
        y = nn.elu(y)
    mx.eval(y)


def relu6(x):
    y = x
    for i in range(100):
        y = nn.relu6(y)
    mx.eval(y)


def softplus(x):
    y = x
    for i in range(100):
        y = nn.softplus(y)
    mx.eval(y)


def celu(x):
    y = x
    for i in range(100):
        y = nn.celu(y)
    mx.eval(y)


def log_sigmoid(x):
    y = x
    for i in range(100):
        y = nn.log_sigmoid(y)
    mx.eval(y)


def scalar_mult(x):
    y = x
    for i in range(100):
@@ -277,6 +320,24 @@ if __name__ == "__main__":
elif args.benchmark == "relu":
print(bench(relu, x))
elif args.benchmark == "leaky_relu":
print(bench(leaky_relu, x))
elif args.benchmark == "elu":
print(bench(elu, x))
elif args.benchmark == "relu6":
print(bench(relu6, x))
elif args.benchmark == "softplus":
print(bench(softplus, x))
elif args.benchmark == "celu":
print(bench(celu, x))
elif args.benchmark == "log_sigmoid":
print(bench(log_sigmoid, x))
elif args.benchmark == "scalar_mul":
print(bench(scalar_mult, x))
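The bench() helper called by these branches is defined or imported elsewhere in the script and is not part of this diff. For orientation, a hypothetical sketch of such a timing wrapper (the warm-up and iteration counts are illustrative assumptions, not the repo's code):

import time

def bench(f, *args):
    # warm-up runs so one-time graph construction and allocation are not timed
    for _ in range(5):
        f(*args)
    # average wall-clock time over a handful of timed runs
    runs = 10
    start = time.time()
    for _ in range(runs):
        f(*args)
    return (time.time() - start) / runs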