mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Reduce update (#783)
* Split reduction files to reduce compile times * Add small and medium axis size specializations for row reductions * Add non-row-reduction options for small and med kernels
This commit is contained in:
@@ -380,10 +380,6 @@ if __name__ == "__main__":
|
||||
if len(args.axis) > 1:
|
||||
args.axis.pop(0)
|
||||
|
||||
if args.print_pid:
|
||||
print(os.getpid())
|
||||
input("Press enter to run")
|
||||
|
||||
if args.cpu:
|
||||
mx.set_default_device(mx.cpu)
|
||||
else:
|
||||
@@ -406,6 +402,10 @@ if __name__ == "__main__":
|
||||
x = xs[0]
|
||||
axis = args.axis[0]
|
||||
|
||||
if args.print_pid:
|
||||
print(os.getpid())
|
||||
input("Press enter to run")
|
||||
|
||||
if args.benchmark == "matmul_square":
|
||||
print(bench(matmul_square, x))
|
||||
|
||||
|
Reference in New Issue
Block a user