Reduce update (#783)

* Split reduction files to reduce compile times

* Add small and medium axis size specializations for row reductions

* Add non-row-reduction options for small and med kernels
This commit is contained in:
Jagrit Digani
2024-03-04 19:09:51 -08:00
committed by GitHub
parent c096a77b9b
commit 6686e61ca4
13 changed files with 949 additions and 667 deletions

View File

@@ -380,10 +380,6 @@ if __name__ == "__main__":
if len(args.axis) > 1:
args.axis.pop(0)
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.cpu:
mx.set_default_device(mx.cpu)
else:
@@ -406,6 +402,10 @@ if __name__ == "__main__":
x = xs[0]
axis = args.axis[0]
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.benchmark == "matmul_square":
print(bench(matmul_square, x))