mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Adapt the torch benchmark to run in CUDA
This commit is contained in:
parent
5adf185f86
commit
ab7c310914
@ -5,6 +5,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.cuda
|
||||||
import torch.mps
|
import torch.mps
|
||||||
|
|
||||||
|
|
||||||
@ -44,8 +45,10 @@ def bench(f, *args):
|
|||||||
|
|
||||||
|
|
||||||
def sync_if_needed(x):
|
def sync_if_needed(x):
|
||||||
if x.device != torch.device("cpu"):
|
if x.device == torch.device("mps"):
|
||||||
torch.mps.synchronize()
|
torch.mps.synchronize()
|
||||||
|
elif x.device == torch.device("cuda"):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -99,6 +102,14 @@ def reduction(op, axis, x):
|
|||||||
sync_if_needed(x)
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sum_and_add(axis, x, y):
|
||||||
|
z = x.sum(axis=axis, keepdims=True)
|
||||||
|
for i in range(50):
|
||||||
|
z = (z + y).sum(axis=axis, keepdims=True)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def softmax(axis, x):
|
def softmax(axis, x):
|
||||||
ys = []
|
ys = []
|
||||||
@ -340,7 +351,11 @@ if __name__ == "__main__":
|
|||||||
args.axis.pop(0)
|
args.axis.pop(0)
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
device = "cpu" if args.cpu else "mps"
|
device = "mps"
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = "cuda"
|
||||||
|
if args.cpu:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
types = args.dtype
|
types = args.dtype
|
||||||
if not types:
|
if not types:
|
||||||
@ -460,5 +475,8 @@ if __name__ == "__main__":
|
|||||||
elif args.benchmark == "selu":
|
elif args.benchmark == "selu":
|
||||||
print(bench(selu, x))
|
print(bench(selu, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "sum_and_add":
|
||||||
|
print(bench(sum_and_add, axis, *xs))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
||||||
|
Loading…
Reference in New Issue
Block a user