Added activation functions: leaky_relu relu6 softplus elu celu logsigmoid (#108)

* added leaky_relu relu6 softplus elu celu logsigmoid
* minor fixes for docstring and benchmark imports
* fixed elu implementation and added tests
* added tests for optional param, changed leaky_relu param to fit pytorch documentation
This commit is contained in:
Jason
2023-12-10 19:31:38 -05:00
committed by GitHub
parent 71d1fff90a
commit b0cd092b7f
6 changed files with 344 additions and 0 deletions

View File

@@ -115,6 +115,54 @@ def relu(x):
sync_if_needed(x)
@torch.no_grad()
def leaky_relu(x):
y = x
for i in range(100):
y = torch.nn.functional.leaky_relu(y)
sync_if_needed(x)
@torch.no_grad()
def elu(x):
y = x
for i in range(100):
y = torch.nn.functional.elu(y)
sync_if_needed(x)
@torch.no_grad()
def celu(x):
y = x
for i in range(100):
y = torch.nn.functional.celu(y)
sync_if_needed(x)
@torch.no_grad()
def relu6(x):
y = x
for i in range(100):
y = torch.nn.functional.relu6(y)
sync_if_needed(x)
@torch.no_grad()
def softplus(x):
y = x
for i in range(100):
y = torch.nn.functional.softplus(y)
sync_if_needed(x)
@torch.no_grad()
def log_sigmoid(x):
y = x
for i in range(100):
y = torch.nn.functional.logsigmoid(y)
sync_if_needed(x)
@torch.no_grad()
def scalar_mult(x):
y = x
@@ -302,6 +350,24 @@ if __name__ == "__main__":
elif args.benchmark == "relu":
print(bench(relu, x))
elif args.benchmark == "leaky_relu":
print(bench(leaky_relu, x))
elif args.benchmark == "elu":
print(bench(elu, x))
elif args.benchmark == "relu6":
print(bench(relu6, x))
elif args.benchmark == "softplus":
print(bench(softplus, x))
elif args.benchmark == "celu":
print(bench(celu, x))
elif args.benchmark == "log_sigmoid":
print(bench(log_sigmoid, x))
elif args.benchmark == "scalar_mul":
print(bench(scalar_mult, x))