mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Up to 10x faster scatter. (#709)
* Faster scatter. Add specialization for 1-d index tensors. * Address review comments. - Check for row contiguity of index, update tensors instead of checking strides. - Add support for 1d specialization with col contiguous update tensor, along with a test. * Nit1 Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Nit2 Co-authored-by: Awni Hannun <awni.hannun@gmail.com> --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
@@ -7,12 +7,14 @@ import torch
|
||||
from time_utils import measure_runtime
|
||||
|
||||
|
||||
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
|
||||
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
|
||||
def scatter(dst, x, idx):
|
||||
dst[idx] = x
|
||||
dst[*idx] = x
|
||||
mx.eval(dst)
|
||||
|
||||
idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape)
|
||||
idx = []
|
||||
for idx_shape in idx_shapes:
|
||||
idx.append(mx.random.randint(0, dst_shape[0] - 1, idx_shape))
|
||||
x = mx.random.normal(x_shape).astype(mx.float32)
|
||||
dst = mx.random.normal(dst_shape).astype(mx.float32)
|
||||
|
||||
@@ -20,13 +22,15 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
|
||||
print(f"MLX: {runtime:.3f}ms")
|
||||
|
||||
|
||||
def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device):
|
||||
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
|
||||
def gather(dst, x, idx, device):
|
||||
dst[idx] = x
|
||||
dst[*idx] = x
|
||||
if device == torch.device("mps"):
|
||||
torch.mps.synchronize()
|
||||
|
||||
idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)
|
||||
idx = []
|
||||
for idx_shape in idx_shapes:
|
||||
idx.append(torch.randint(0, dst_shape[0] - 1, idx_shape).to(device))
|
||||
x = torch.randn(x_shape, dtype=torch.float32).to(device)
|
||||
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
|
||||
|
||||
@@ -45,9 +49,45 @@ if __name__ == "__main__":
|
||||
else:
|
||||
device = torch.device("mps")
|
||||
|
||||
dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)]
|
||||
idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)]
|
||||
x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)]
|
||||
dst_shapes = [
|
||||
(10, 64),
|
||||
(100_000, 64),
|
||||
(1_000_000, 64),
|
||||
(100_000,),
|
||||
(2_000_00,),
|
||||
(20_000_000,),
|
||||
(10000, 64),
|
||||
(100, 64),
|
||||
(100, 10_000, 64),
|
||||
(10, 100, 100, 21),
|
||||
(1_000, 1_000, 10),
|
||||
]
|
||||
idx_shapes = [
|
||||
[(1_000_000,)],
|
||||
[(1_000_000,)],
|
||||
[(100_000,)],
|
||||
[(1_000_000,)],
|
||||
[(20_000_000,)],
|
||||
[(20_000_000,)],
|
||||
[(1000000,)],
|
||||
[(10000000,)],
|
||||
[(1_000,)],
|
||||
[(10_000,)],
|
||||
[(1_000,), (1_000,)],
|
||||
]
|
||||
x_shapes = [
|
||||
(1_000_000, 64),
|
||||
(1_000_000, 64),
|
||||
(100_000, 64),
|
||||
(1_000_000,),
|
||||
(20_000_000,),
|
||||
(20_000_000,),
|
||||
(1000000, 64),
|
||||
(10000000, 64),
|
||||
(1_000, 10_000, 64),
|
||||
(10_000, 100, 100, 21),
|
||||
(1_000, 10),
|
||||
]
|
||||
|
||||
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
|
||||
print("=" * 20)
|
||||
|
Reference in New Issue
Block a user