Add Step, ELU, SELU, Swish activation functions (#117)

* Add Step, ELU, SELU, Swish activation functions

This commit adds the Step, ELU, SELU and Swish activations functions

* add to the docs

* review
This commit is contained in:
Nicholas Santavas
2023-12-12 02:04:07 +01:00
committed by GitHub
parent b9226c367c
commit f5df47ec6e
7 changed files with 132 additions and 2 deletions

View File

@@ -223,6 +223,20 @@ def topk(axis, x):
mx.eval(ys)
def step_function(x):
y = x
for i in range(100):
y = nn.step(x)
mx.eval(y)
def selu(x):
y = x
for i in range(100):
y = nn.selu(x)
mx.eval(y)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("benchmark", help="Choose the benchmark to run")
@@ -372,5 +386,11 @@ if __name__ == "__main__":
elif args.benchmark == "topk":
print(bench(topk, axis, x))
elif args.benchmark == "step":
print(bench(step_function, x))
elif args.benchmark == "selu":
print(bench(selu, x))
else:
raise ValueError("Unknown benchmark")

View File

@@ -257,6 +257,14 @@ def topk(axis, x):
sync_if_needed(x)
@torch.no_grad()
def selu(x):
y = x
for i in range(100):
y = torch.nn.functional.selu(y)
sync_if_needed(x)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("benchmark", help="Choose the benchmark to run")

View File

@@ -205,6 +205,10 @@ if __name__ == "__main__":
compare_filtered("celu --size 32x16x1024 --cpu")
compare_filtered("log_sigmoid --size 32x16x1024")
compare_filtered("log_sigmoid --size 32x16x1024 --cpu")
compare_filtered("step --size 32x16x1024")
compare_filtered("step --size 32x16x1024 --cpu")
compare_filtered("selu --size 32x16x1024")
compare_filtered("selu --size 32x16x1024 --cpu")
compare_filtered("scalar_mul --size 32x16x1024")
compare_filtered("scalar_mul --size 32x16x1024 --cpu")
compare_filtered("cross_entropy --size 256x1024")