mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 19:48:15 +08:00
Activations LeakyReLU / PReLU / Softplus / Mish (#109)
* Leaky_relu / prelu / softplus / mish * added tests * updated bench * remove torch refs, add init to PReLU * added arvix reference to mish * added missing docs
This commit is contained in:
@@ -163,6 +163,22 @@ def log_sigmoid(x):
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def prelu(x: torch.Tensor) -> torch.Tensor:
|
||||
y = x
|
||||
for _ in range(100):
|
||||
y = torch.nn.functional.prelu(y, torch.ones(1).to(y.device))
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def mish(x: torch.Tensor) -> torch.Tensor:
|
||||
y = x
|
||||
for _ in range(100):
|
||||
return torch.nn.functional.mish(y)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def scalar_mult(x):
|
||||
y = x
|
||||
@@ -376,6 +392,10 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "log_sigmoid":
|
||||
print(bench(log_sigmoid, x))
|
||||
|
||||
elif args.benchmark == "prelu":
|
||||
print(bench(prelu, x))
|
||||
elif args.benchmark == "mish":
|
||||
print(bench(mish, x))
|
||||
elif args.benchmark == "scalar_mul":
|
||||
print(bench(scalar_mult, x))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user