mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-28 22:28:11 +08:00
jagrit's commit files
This commit is contained in:
38
benchmarks/numpy/single_ops.py
Normal file
38
benchmarks/numpy/single_ops.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import numpy as np
|
||||
|
||||
from time_utils import time_fn
|
||||
|
||||
|
||||
def time_add():
|
||||
a = np.ones((100, 100, 10), dtype=np.float32)
|
||||
b = np.ones((100, 100, 10), dtype=np.float32)
|
||||
time_fn(np.add, a, b)
|
||||
|
||||
|
||||
def time_matmul():
|
||||
a = np.random.rand(1000, 500).astype(np.float32)
|
||||
b = np.random.rand(500, 1000).astype(np.float32)
|
||||
time_fn(np.matmul, a, b)
|
||||
|
||||
|
||||
def time_exp():
|
||||
a = np.random.randn(1000, 100).astype(np.float32)
|
||||
time_fn(np.exp, a)
|
||||
|
||||
|
||||
def time_take():
|
||||
a = np.random.rand(10000, 500)
|
||||
ids = np.random.randint(0, 10000, (20, 10))
|
||||
ids = [idx.reshape(-1) for idx in np.split(ids, 20)]
|
||||
|
||||
def random_take():
|
||||
return [np.take(a, idx, 0) for idx in ids]
|
||||
|
||||
time_fn(random_take)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_add()
|
||||
time_matmul()
|
||||
time_exp()
|
||||
time_take()
|
||||
Reference in New Issue
Block a user