mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 18:18:15 +08:00
fix custom metal extension (#2446)
This commit is contained in:
@@ -3,8 +3,10 @@ from mlx_sample_extensions import axpby
|
||||
|
||||
a = mx.ones((3, 4))
|
||||
b = mx.ones((3, 4))
|
||||
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
||||
c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
||||
c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
|
||||
|
||||
print(f"c shape: {c.shape}")
|
||||
print(f"c dtype: {c.dtype}")
|
||||
print(f"c correct: {mx.all(c == 6.0).item()}")
|
||||
print(f"c shape: {c_cpu.shape}")
|
||||
print(f"c dtype: {c_cpu.dtype}")
|
||||
print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
|
||||
print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")
|
||||
|
||||
Reference in New Issue
Block a user