Reduce specializations (#1607)

* start of reduce specializations

* fix all reduce

* fix many dims

* fix

* non-jit tests clear

* cleanup instantiations

* cpu merges

* change dim specializations

* optimize

* fix jit

* fix jit

* use higher precision for integer sum+prod

* fixes
This commit is contained in:
Awni Hannun
2024-11-21 19:53:00 -08:00
committed by GitHub
parent dcca0d7477
commit 0c5eea226b
14 changed files with 733 additions and 406 deletions

View File

@@ -1,3 +1,4 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
@@ -99,7 +100,7 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
const std::string& kernel_name,
const std::string&,
const std::string&,
const array&) {
const Dtype&) {
return d.get_kernel(kernel_name);
}
@@ -108,8 +109,9 @@ MTL::ComputePipelineState* get_reduce_kernel(
const std::string& kernel_name,
const std::string&,
const std::string&,
const array&,
const array&,
const Dtype&,
const Dtype&,
const std::string&,
int,
int,
int) {