Activations LeakyReLU / PReLU / Softplus / Mish (#109)

* added leaky_relu / prelu / softplus / mish

* added tests

* updated bench

* removed torch refs, added init to PReLU

* added arXiv reference to Mish

* added missing docs
Author: Diogo
Date: 2023-12-11 22:40:57 -05:00
Committed by: GitHub
Parent: f5df47ec6e
Commit: 02de234ef0
8 changed files with 133 additions and 31 deletions
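
For context, the math these four activations compute, per their standard definitions (the helper names, default slope, and use of basic mx ops below are illustrative; this is not the PR's actual implementation):

import mlx.core as mx

def leaky_relu_ref(x, negative_slope=0.01):
    # LeakyReLU: identity for x > 0, a small fixed slope for x < 0
    return mx.maximum(negative_slope * x, x)

def prelu_ref(x, alpha):
    # PReLU: like LeakyReLU, but the negative slope `alpha` is a learnable array
    return mx.maximum(0, x) + alpha * mx.minimum(0, x)

def softplus_ref(x):
    # Softplus: smooth approximation of ReLU, log(1 + exp(x))
    return mx.logaddexp(x, 0)

def mish_ref(x):
    # Mish (Misra, arXiv:1908.08681): x * tanh(softplus(x))
    return x * mx.tanh(softplus_ref(x))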


@@ -96,7 +96,35 @@ def softmax_fused(axis, x):
def relu(x):
    y = x
    for i in range(100):
        y = nn.relu(y)
    mx.eval(y)

def leaky_relu(x: mx.array):
    y = x
    for i in range(100):
        y = nn.leaky_relu(y)
    mx.eval(y)


def prelu(x: mx.array):
    y = x
    for i in range(100):
        y = nn.prelu(y, mx.ones(1))
    mx.eval(y)


def softplus(x: mx.array):
    y = x
    for i in range(100):
        y = nn.softplus(y)
    mx.eval(y)


def mish(x: mx.array):
    y = x
    for i in range(100):
        y = nn.mish(y)
    mx.eval(y)
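
Each benchmark above applies the op 100 times and then calls mx.eval, which forces MLX's lazy graph to actually execute; without it the loop would only record the computation. The bench() helper used in the dispatch below is not part of this excerpt; a minimal stand-in (hypothetical, for illustration only) might be:

import time

def bench_once(fn, x):
    # Hypothetical timing helper; the repo's real bench() presumably warms up and averages runs.
    tic = time.perf_counter()
    fn(x)  # fn calls mx.eval internally, so the work finishes before the clock stops
    return time.perf_counter() - tic
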
@@ -334,24 +362,26 @@ if __name__ == "__main__":
elif args.benchmark == "relu":
print(bench(relu, x))
elif args.benchmark == "leaky_relu":
print(bench(leaky_relu, x))
elif args.benchmark == "elu":
print(bench(elu, x))
elif args.benchmark == "relu6":
print(bench(relu6, x))
elif args.benchmark == "softplus":
print(bench(softplus, x))
elif args.benchmark == "celu":
print(bench(celu, x))
elif args.benchmark == "log_sigmoid":
print(bench(log_sigmoid, x))
elif args.benchmark == "leaky_relu":
print(bench(leaky_relu, x))
elif args.benchmark == "prelu":
print(bench(prelu, x))
elif args.benchmark == "softplus":
print(bench(softplus, x))
elif args.benchmark == "mish":
print(bench(mish, x))
elif args.benchmark == "scalar_mul":
print(bench(scalar_mult, x))
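
Note that the functional form takes the learnable negative slope explicitly, as in nn.prelu(y, mx.ones(1)) above, while the module form this commit mentions ("added init to PReLU") presumably owns that parameter. A rough sketch of such a module, with a class name and constructor arguments that are assumptions rather than the PR's actual signature:

import mlx.core as mx
import mlx.nn as nn

class PReLUSketch(nn.Module):
    # Illustrative only: parameter name, count, and init value are guesses.
    def __init__(self, num_parameters: int = 1, init: float = 0.25):
        super().__init__()
        self.alpha = mx.full([num_parameters], init)

    def __call__(self, x):
        return nn.prelu(x, self.alpha)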


@@ -163,6 +163,22 @@ def log_sigmoid(x):
    sync_if_needed(x)


@torch.no_grad()
def prelu(x: torch.Tensor) -> torch.Tensor:
    y = x
    for _ in range(100):
        y = torch.nn.functional.prelu(y, torch.ones(1).to(y.device))
    sync_if_needed(x)


@torch.no_grad()
def mish(x: torch.Tensor) -> torch.Tensor:
    y = x
    for _ in range(100):
        y = torch.nn.functional.mish(y)
    sync_if_needed(x)


@torch.no_grad()
def scalar_mult(x):
    y = x
@@ -376,6 +392,10 @@ if __name__ == "__main__":
elif args.benchmark == "log_sigmoid":
print(bench(log_sigmoid, x))
elif args.benchmark == "prelu":
print(bench(prelu, x))
elif args.benchmark == "mish":
print(bench(mish, x))
elif args.benchmark == "scalar_mul":
print(bench(scalar_mult, x))
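
Since the MLX and PyTorch scripts benchmark the same ops, a quick numerical cross-check is possible; a throwaway sketch (shape and tolerance are arbitrary):

import numpy as np
import torch
import mlx.core as mx
import mlx.nn as nn

x = np.random.randn(32, 16).astype(np.float32)
out_mlx = np.array(nn.mish(mx.array(x)))                        # MLX result back to numpy
out_pt = torch.nn.functional.mish(torch.from_numpy(x)).numpy()  # PyTorch reference
print(np.allclose(out_mlx, out_pt, atol=1e-5))                  # should print True if the two agree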


@@ -209,6 +209,11 @@ if __name__ == "__main__":
compare_filtered("step --size 32x16x1024 --cpu")
compare_filtered("selu --size 32x16x1024")
compare_filtered("selu --size 32x16x1024 --cpu")
# compare_filtered("mish --size 32x16x1024") NOTE: Torch does not implement Mish in MPS atm
compare_filtered("mish --size 32x16x1024 --cpu")
compare_filtered("prelu --size 32x16x1024")
compare_filtered("prelu --size 32x16x1024 --cpu")
compare_filtered("scalar_mul --size 32x16x1024")
compare_filtered("scalar_mul --size 32x16x1024 --cpu")
compare_filtered("cross_entropy --size 256x1024")