mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 13:07:29 +08:00
Add Step, ELU, SELU, Swish activation functions (#117)
* Add Step, ELU, SELU, Swish activation functions This commit adds the Step, ELU, SELU and Swish activations functions * add to the docs * review
This commit is contained in:

committed by
GitHub

parent
b9226c367c
commit
f5df47ec6e
@@ -223,6 +223,20 @@ def topk(axis, x):
|
||||
mx.eval(ys)
|
||||
|
||||
|
||||
def step_function(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.step(x)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def selu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.selu(x)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("benchmark", help="Choose the benchmark to run")
|
||||
@@ -372,5 +386,11 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "topk":
|
||||
print(bench(topk, axis, x))
|
||||
|
||||
elif args.benchmark == "step":
|
||||
print(bench(step_function, x))
|
||||
|
||||
elif args.benchmark == "selu":
|
||||
print(bench(selu, x))
|
||||
|
||||
else:
|
||||
raise ValueError("Unknown benchmark")
|
||||
|
Reference in New Issue
Block a user