mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:18:12 +08:00
Add Step, ELU, SELU, Swish activation functions (#117)
* Add Step, ELU, SELU, Swish activation functions This commit adds the Step, ELU, SELU and Swish activations functions * add to the docs * review
This commit is contained in:

committed by
GitHub

parent
b9226c367c
commit
f5df47ec6e
@@ -223,6 +223,20 @@ def topk(axis, x):
|
||||
mx.eval(ys)
|
||||
|
||||
|
||||
def step_function(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.step(x)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def selu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.selu(x)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("benchmark", help="Choose the benchmark to run")
|
||||
@@ -372,5 +386,11 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "topk":
|
||||
print(bench(topk, axis, x))
|
||||
|
||||
elif args.benchmark == "step":
|
||||
print(bench(step_function, x))
|
||||
|
||||
elif args.benchmark == "selu":
|
||||
print(bench(selu, x))
|
||||
|
||||
else:
|
||||
raise ValueError("Unknown benchmark")
|
||||
|
@@ -257,6 +257,14 @@ def topk(axis, x):
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def selu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = torch.nn.functional.selu(y)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("benchmark", help="Choose the benchmark to run")
|
||||
|
@@ -205,6 +205,10 @@ if __name__ == "__main__":
|
||||
compare_filtered("celu --size 32x16x1024 --cpu")
|
||||
compare_filtered("log_sigmoid --size 32x16x1024")
|
||||
compare_filtered("log_sigmoid --size 32x16x1024 --cpu")
|
||||
compare_filtered("step --size 32x16x1024")
|
||||
compare_filtered("step --size 32x16x1024 --cpu")
|
||||
compare_filtered("selu --size 32x16x1024")
|
||||
compare_filtered("selu --size 32x16x1024 --cpu")
|
||||
compare_filtered("scalar_mul --size 32x16x1024")
|
||||
compare_filtered("scalar_mul --size 32x16x1024 --cpu")
|
||||
compare_filtered("cross_entropy --size 256x1024")
|
||||
|
Reference in New Issue
Block a user