mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Activations LeakyReLU / PReLU / Softplus / Mish (#109)
* Leaky_relu / prelu / softplus / mish * added tests * updated bench * remove torch refs, add init to PReLU * added arvix reference to mish * added missing docs
This commit is contained in:
@@ -96,7 +96,35 @@ def softmax_fused(axis, x):
|
||||
def relu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = mx.maximum(y, 0)
|
||||
y = nn.relu(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def leaky_relu(x: mx.array):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.leaky_relu(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def prelu(x: mx.array):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.prelu(y, mx.ones(1))
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def softplus(x: mx.array):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.softplus(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def mish(x: mx.array):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.mish(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
@@ -334,24 +362,26 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "relu":
|
||||
print(bench(relu, x))
|
||||
|
||||
elif args.benchmark == "leaky_relu":
|
||||
print(bench(leaky_relu, x))
|
||||
|
||||
elif args.benchmark == "elu":
|
||||
print(bench(elu, x))
|
||||
|
||||
elif args.benchmark == "relu6":
|
||||
print(bench(relu6, x))
|
||||
|
||||
elif args.benchmark == "softplus":
|
||||
print(bench(softplus, x))
|
||||
|
||||
elif args.benchmark == "celu":
|
||||
print(bench(celu, x))
|
||||
|
||||
elif args.benchmark == "log_sigmoid":
|
||||
print(bench(log_sigmoid, x))
|
||||
|
||||
elif args.benchmark == "leaky_relu":
|
||||
print(bench(leaky_relu, x))
|
||||
elif args.benchmark == "prelu":
|
||||
print(bench(prelu, x))
|
||||
elif args.benchmark == "softplus":
|
||||
print(bench(softplus, x))
|
||||
elif args.benchmark == "mish":
|
||||
print(bench(mish, x))
|
||||
elif args.benchmark == "scalar_mul":
|
||||
print(bench(scalar_mult, x))
|
||||
|
||||
|
Reference in New Issue
Block a user