mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Faster indexing math in a few kernels (#1589)
* wip: faster compiled kernels * faster general unary with uint specialization * index type in compiled, unary, binary, ternary, copy * fix jit * jit fix * specialize gather + scatter * nit in docs
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view gather_kernels = R"(
|
||||
[[kernel]] void gather{0}_{3}_{6}(
|
||||
[[kernel]] void gather{0}_{3}_{6}_{7}(
|
||||
const device {1}* src [[buffer(0)]],
|
||||
device {1}* out [[buffer(1)]],
|
||||
const constant int* src_shape [[buffer(2)]],
|
||||
@@ -19,7 +19,7 @@ constexpr std::string_view gather_kernels = R"(
|
||||
Indices<{2}, {3}> idxs{{
|
||||
{{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
|
||||
|
||||
return gather_impl<{1}, {2}, {3}, {6}>(
|
||||
return gather_impl<{1}, {2}, {3}, {6}, {7}>(
|
||||
src,
|
||||
out,
|
||||
src_shape,
|
||||
@@ -34,7 +34,7 @@ constexpr std::string_view gather_kernels = R"(
|
||||
)";
|
||||
|
||||
constexpr std::string_view scatter_kernels = R"(
|
||||
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}(
|
||||
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}(
|
||||
const device {1}* updates [[buffer(1)]],
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* upd_shape [[buffer(3)]],
|
||||
@@ -54,7 +54,7 @@ constexpr std::string_view scatter_kernels = R"(
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
|
||||
|
||||
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}>(
|
||||
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>(
|
||||
updates,
|
||||
out,
|
||||
upd_shape,
|
||||
|
||||
Reference in New Issue
Block a user