mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Added activation functions: leaky_relu relu6 softplus elu celu logsigmoid (#108)
* added leaky_relu relu6 softplus elu celu logsigmoid * minor fixes for docstring and benchmark imports * fixed elu implementation and added tests * added tests for optional param, changed leaky_relu param to fit pytorch documentation
This commit is contained in:
@@ -6,6 +6,7 @@ import os
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def int_or_list(x):
|
||||
@@ -99,6 +100,48 @@ def relu(x):
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def leaky_relu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.leaky_relu(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def elu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.elu(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def relu6(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.relu6(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def softplus(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.softplus(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def celu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.celu(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def log_sigmoid(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.log_sigmoid(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def scalar_mult(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
@@ -277,6 +320,24 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "relu":
|
||||
print(bench(relu, x))
|
||||
|
||||
elif args.benchmark == "leaky_relu":
|
||||
print(bench(leaky_relu, x))
|
||||
|
||||
elif args.benchmark == "elu":
|
||||
print(bench(elu, x))
|
||||
|
||||
elif args.benchmark == "relu6":
|
||||
print(bench(relu6, x))
|
||||
|
||||
elif args.benchmark == "softplus":
|
||||
print(bench(softplus, x))
|
||||
|
||||
elif args.benchmark == "celu":
|
||||
print(bench(celu, x))
|
||||
|
||||
elif args.benchmark == "log_sigmoid":
|
||||
print(bench(log_sigmoid, x))
|
||||
|
||||
elif args.benchmark == "scalar_mul":
|
||||
print(bench(scalar_mult, x))
|
||||
|
||||
|
Reference in New Issue
Block a user