mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-17 14:58:13 +08:00
Added activation functions: leaky_relu relu6 softplus elu celu logsigmoid (#108)
* added leaky_relu relu6 softplus elu celu logsigmoid * minor fixes for docstring and benchmark imports * fixed elu implementation and added tests * added tests for optional param, changed leaky_relu param to fit pytorch documentation
This commit is contained in:
@@ -6,6 +6,7 @@ import os
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def int_or_list(x):
|
||||
@@ -99,6 +100,48 @@ def relu(x):
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def leaky_relu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.leaky_relu(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def elu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.elu(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def relu6(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.relu6(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def softplus(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.softplus(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def celu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.celu(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def log_sigmoid(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.log_sigmoid(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def scalar_mult(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
@@ -277,6 +320,24 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "relu":
|
||||
print(bench(relu, x))
|
||||
|
||||
elif args.benchmark == "leaky_relu":
|
||||
print(bench(leaky_relu, x))
|
||||
|
||||
elif args.benchmark == "elu":
|
||||
print(bench(elu, x))
|
||||
|
||||
elif args.benchmark == "relu6":
|
||||
print(bench(relu6, x))
|
||||
|
||||
elif args.benchmark == "softplus":
|
||||
print(bench(softplus, x))
|
||||
|
||||
elif args.benchmark == "celu":
|
||||
print(bench(celu, x))
|
||||
|
||||
elif args.benchmark == "log_sigmoid":
|
||||
print(bench(log_sigmoid, x))
|
||||
|
||||
elif args.benchmark == "scalar_mul":
|
||||
print(bench(scalar_mult, x))
|
||||
|
||||
|
@@ -115,6 +115,54 @@ def relu(x):
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def leaky_relu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = torch.nn.functional.leaky_relu(y)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def elu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = torch.nn.functional.elu(y)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def celu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = torch.nn.functional.celu(y)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def relu6(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = torch.nn.functional.relu6(y)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def softplus(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = torch.nn.functional.softplus(y)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def log_sigmoid(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = torch.nn.functional.logsigmoid(y)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def scalar_mult(x):
|
||||
y = x
|
||||
@@ -302,6 +350,24 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "relu":
|
||||
print(bench(relu, x))
|
||||
|
||||
elif args.benchmark == "leaky_relu":
|
||||
print(bench(leaky_relu, x))
|
||||
|
||||
elif args.benchmark == "elu":
|
||||
print(bench(elu, x))
|
||||
|
||||
elif args.benchmark == "relu6":
|
||||
print(bench(relu6, x))
|
||||
|
||||
elif args.benchmark == "softplus":
|
||||
print(bench(softplus, x))
|
||||
|
||||
elif args.benchmark == "celu":
|
||||
print(bench(celu, x))
|
||||
|
||||
elif args.benchmark == "log_sigmoid":
|
||||
print(bench(log_sigmoid, x))
|
||||
|
||||
elif args.benchmark == "scalar_mul":
|
||||
print(bench(scalar_mult, x))
|
||||
|
||||
|
@@ -193,6 +193,18 @@ if __name__ == "__main__":
|
||||
compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused --cpu")
|
||||
compare_filtered("relu --size 32x16x1024")
|
||||
compare_filtered("relu --size 32x16x1024 --cpu")
|
||||
compare_filtered("leaky_relu --size 32x16x1024")
|
||||
compare_filtered("leaky_relu --size 32x16x1024 --cpu")
|
||||
compare_filtered("elu --size 32x16x1024")
|
||||
compare_filtered("elu --size 32x16x1024 --cpu")
|
||||
compare_filtered("relu6 --size 32x16x1024")
|
||||
compare_filtered("relu6 --size 32x16x1024 --cpu")
|
||||
compare_filtered("softplus --size 32x16x1024")
|
||||
compare_filtered("softplus --size 32x16x1024 --cpu")
|
||||
compare_filtered("celu --size 32x16x1024")
|
||||
compare_filtered("celu --size 32x16x1024 --cpu")
|
||||
compare_filtered("log_sigmoid --size 32x16x1024")
|
||||
compare_filtered("log_sigmoid --size 32x16x1024 --cpu")
|
||||
compare_filtered("scalar_mul --size 32x16x1024")
|
||||
compare_filtered("scalar_mul --size 32x16x1024 --cpu")
|
||||
compare_filtered("cross_entropy --size 256x1024")
|
||||
|
Reference in New Issue
Block a user