mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-02 13:54:44 +08:00
Activations LeakyReLU / PReLU / Softplus / Mish (#109)
* Leaky_relu / prelu / softplus / mish * added tests * updated bench * remove torch refs, add init to PReLU * added arvix reference to mish * added missing docs
This commit is contained in:
@@ -96,7 +96,35 @@ def softmax_fused(axis, x):
|
||||
def relu(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = mx.maximum(y, 0)
|
||||
y = nn.relu(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def leaky_relu(x: mx.array):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.leaky_relu(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def prelu(x: mx.array):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.prelu(y, mx.ones(1))
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def softplus(x: mx.array):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.softplus(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def mish(x: mx.array):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = nn.mish(y)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
@@ -334,24 +362,26 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "relu":
|
||||
print(bench(relu, x))
|
||||
|
||||
elif args.benchmark == "leaky_relu":
|
||||
print(bench(leaky_relu, x))
|
||||
|
||||
elif args.benchmark == "elu":
|
||||
print(bench(elu, x))
|
||||
|
||||
elif args.benchmark == "relu6":
|
||||
print(bench(relu6, x))
|
||||
|
||||
elif args.benchmark == "softplus":
|
||||
print(bench(softplus, x))
|
||||
|
||||
elif args.benchmark == "celu":
|
||||
print(bench(celu, x))
|
||||
|
||||
elif args.benchmark == "log_sigmoid":
|
||||
print(bench(log_sigmoid, x))
|
||||
|
||||
elif args.benchmark == "leaky_relu":
|
||||
print(bench(leaky_relu, x))
|
||||
elif args.benchmark == "prelu":
|
||||
print(bench(prelu, x))
|
||||
elif args.benchmark == "softplus":
|
||||
print(bench(softplus, x))
|
||||
elif args.benchmark == "mish":
|
||||
print(bench(mish, x))
|
||||
elif args.benchmark == "scalar_mul":
|
||||
print(bench(scalar_mult, x))
|
||||
|
||||
|
@@ -163,6 +163,22 @@ def log_sigmoid(x):
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def prelu(x: torch.Tensor) -> torch.Tensor:
|
||||
y = x
|
||||
for _ in range(100):
|
||||
y = torch.nn.functional.prelu(y, torch.ones(1).to(y.device))
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def mish(x: torch.Tensor) -> torch.Tensor:
|
||||
y = x
|
||||
for _ in range(100):
|
||||
return torch.nn.functional.mish(y)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def scalar_mult(x):
|
||||
y = x
|
||||
@@ -376,6 +392,10 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "log_sigmoid":
|
||||
print(bench(log_sigmoid, x))
|
||||
|
||||
elif args.benchmark == "prelu":
|
||||
print(bench(prelu, x))
|
||||
elif args.benchmark == "mish":
|
||||
print(bench(mish, x))
|
||||
elif args.benchmark == "scalar_mul":
|
||||
print(bench(scalar_mult, x))
|
||||
|
||||
|
@@ -209,6 +209,11 @@ if __name__ == "__main__":
|
||||
compare_filtered("step --size 32x16x1024 --cpu")
|
||||
compare_filtered("selu --size 32x16x1024")
|
||||
compare_filtered("selu --size 32x16x1024 --cpu")
|
||||
# compare_filtered("mish --size 32x16x1024") NOTE: Torch does not implement Mish in MPS atm
|
||||
compare_filtered("mish --size 32x16x1024 --cpu")
|
||||
compare_filtered("prelu --size 32x16x1024")
|
||||
compare_filtered("prelu --size 32x16x1024 --cpu")
|
||||
|
||||
compare_filtered("scalar_mul --size 32x16x1024")
|
||||
compare_filtered("scalar_mul --size 32x16x1024 --cpu")
|
||||
compare_filtered("cross_entropy --size 256x1024")
|
||||
|
Reference in New Issue
Block a user