More jitting (#1132)

* docs + circle min size build

* jit scan, arange, softmax

* add sort

* jit reductions

* remove print

* fix deps

* clean includes / nits
This commit is contained in:
Awni Hannun
2024-05-23 16:23:44 -07:00
committed by GitHub
parent 9401507336
commit 0189ab6ab6
41 changed files with 2377 additions and 1846 deletions

View File

@@ -5,6 +5,13 @@
namespace mlx::core {
MTL::ComputePipelineState* get_arange_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
@@ -15,31 +22,84 @@ MTL::ComputePipelineState* get_unary_kernel(
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
const array&,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
const array&,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_copy_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
const array&,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_softmax_kernel(
metal::Device& d,
const std::string& kernel_name,
bool,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d,
const std::string& kernel_name,
bool,
bool,
const array&,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&,
const array&,
int,
int) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_mb_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&,
const array&,
int,
int) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&,
const array&) {
return d.get_kernel(kernel_name);
}