mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00

More fixes for arrays with large sizes (#1405)

* compile works for big arrays when contiguous
* style
* nits in docs
* a bunch more stuff
* update jit
* update jit
* use constant for shapes and strides and remove elem_to_loc overload
* use kernel instantiation
* docs nits
* update binary and ternary
* comments

This commit is contained in:
parent c6739ba7f3
commit 4f46e9c997
@@ -13,6 +13,7 @@ simple functions.

   :template: nn-module-template.rst

   elu
   celu
   gelu
   gelu_approx
   gelu_fast_approx

@@ -13,6 +13,7 @@ Layers

   AvgPool1d
   AvgPool2d
   BatchNorm
   CELU
   Conv1d
   Conv2d
   Conv3d
@@ -23,6 +24,7 @@ Layers
   Dropout2d
   Dropout3d
   Embedding
   ELU
   GELU
   GLU
   GroupNorm
@@ -34,6 +36,8 @@ Layers
   LayerNorm
   LeakyReLU
   Linear
   LogSigmoid
   LogSoftmax
   LSTM
   MaxPool1d
   MaxPool2d
@@ -49,6 +53,7 @@ Layers
   RoPE
   SELU
   Sequential
   Sigmoid
   SiLU
   SinusoidalPositionalEncoding
   Softmin
@@ -51,6 +51,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
   ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)

 if(IOS)
88  mlx/backend/common/utils.cpp  Normal file
@@ -0,0 +1,88 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/common/utils.h"

namespace mlx::core {

template <typename stride_t>
std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
collapse_contiguous_dims_impl(
    const std::vector<int>& shape,
    const std::vector<std::vector<stride_t>>& strides,
    stride_t size_cap) {
  // Make a vector that has axes separated with -1. Collapse all axes between
  // -1.
  std::vector<int> to_collapse;
  if (shape.size() > 0) {
    if (shape[0] != 1) {
      to_collapse.push_back(0);
    }
    size_t size = shape[0];
    for (int i = 1; i < shape.size(); i++) {
      bool contiguous = true;
      size *= shape[i];
      for (const std::vector<stride_t>& st : strides) {
        if (st[i] * shape[i] != st[i - 1] || size > size_cap) {
          contiguous = false;
          size = shape[i];
          break;
        }
      }
      if (!contiguous) {
        to_collapse.push_back(-1);
      }
      if (shape[i] != 1) {
        to_collapse.push_back(i);
      }
    }
    to_collapse.push_back(-1);
  }

  std::vector<int> out_shape;
  std::vector<std::vector<stride_t>> out_strides(strides.size());
  for (int i = 0;;) {
    while (i < to_collapse.size() && to_collapse[i] == -1) {
      ++i;
    }
    if (i == to_collapse.size()) {
      break;
    }
    int current_shape = shape[to_collapse[i]];
    int k = i;
    while (to_collapse[++k] != -1) {
      current_shape *= shape[to_collapse[k]];
    }
    out_shape.push_back(current_shape);
    for (int j = 0; j < strides.size(); j++) {
      const std::vector<stride_t>& st = strides[j];
      out_strides[j].push_back(st[to_collapse[k - 1]]);
    }
    i = k + 1;
  }

  if (!shape.empty() && out_shape.empty()) {
    out_shape.push_back(1);
    for (auto& out_stride : out_strides) {
      out_stride.push_back(0);
    }
  }
  return std::make_tuple(out_shape, out_strides);
}

std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
collapse_contiguous_dims(
    const std::vector<int>& shape,
    const std::vector<std::vector<int64_t>>& strides,
    int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
  return collapse_contiguous_dims_impl(shape, strides, size_cap);
}

std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
    const std::vector<int>& shape,
    const std::vector<std::vector<size_t>>& strides,
    size_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
  return collapse_contiguous_dims_impl(shape, strides, size_cap);
}

} // namespace mlx::core
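For intuition, here is a standalone C++ sketch of the collapse rule the new file implements: axis i folds into the group before it when st[i] * shape[i] == st[i - 1] for every array and the merged extent stays under size_cap. The helper name collapse_one and its single-array signature are illustrative only, not part of the MLX API.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Merge adjacent axes whose strides chain contiguously, as long as the
// merged extent stays under size_cap. Shape (2, 3, 4) with row-major
// strides (12, 4, 1) collapses to shape (24) with stride (1).
std::pair<std::vector<int>, std::vector<int64_t>> collapse_one(
    const std::vector<int>& shape,
    const std::vector<int64_t>& strides,
    int64_t size_cap) {
  std::vector<int> out_shape;
  std::vector<int64_t> out_strides;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (!out_shape.empty() && strides[i] * shape[i] == out_strides.back() &&
        int64_t(out_shape.back()) * shape[i] <= size_cap) {
      // Axis i is a contiguous continuation of the group so far: fold it
      // in and keep the innermost stride.
      out_shape.back() *= shape[i];
      out_strides.back() = strides[i];
    } else {
      out_shape.push_back(shape[i]);
      out_strides.push_back(strides[i]);
    }
  }
  return {out_shape, out_strides};
}

int main() {
  auto [s, st] = collapse_one({2, 3, 4}, {12, 4, 1}, INT32_MAX);
  for (size_t i = 0; i < s.size(); ++i) {
    std::printf("dim %zu: shape %d stride %lld\n", i, s[i], (long long)st[i]);
  }
}

Running this prints a single collapsed axis of extent 24 with stride 1; with a small size_cap the merge stops early and several axes survive, which is exactly how the cap keeps each collapsed extent representable as an int.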
@@ -44,58 +44,26 @@ std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
 //
 // When multiple arrays are passed they should all have the same shape. The
 // collapsed axes are also the same so one shape is returned.
-template <typename stride_t>
-inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
-collapse_contiguous_dims(
-    const std::vector<int>& shape,
-    const std::vector<std::vector<stride_t>> strides) {
-  // Make a vector that has axes separated with -1. Collapse all axes between
-  // -1.
-  std::vector<int> to_collapse;
-  if (shape.size() > 0) {
-    to_collapse.push_back(0);
-    for (int i = 1; i < shape.size(); i++) {
-      bool contiguous = true;
-      for (const std::vector<stride_t>& st : strides) {
-        if (st[i] * shape[i] != st[i - 1]) {
-          contiguous = false;
-        }
-        if (!contiguous) {
-          break;
-        }
-      }
-      if (!contiguous) {
-        to_collapse.push_back(-1);
-      }
-      to_collapse.push_back(i);
-    }
-    to_collapse.push_back(-1);
-  }
-
-  std::vector<int> out_shape;
-  std::vector<std::vector<stride_t>> out_strides(strides.size());
-  for (int i = 0; i < to_collapse.size(); i++) {
-    int current_shape = shape[to_collapse[i]];
-    while (to_collapse[++i] != -1) {
-      current_shape *= shape[to_collapse[i]];
-    }
-    out_shape.push_back(current_shape);
-    for (int j = 0; j < strides.size(); j++) {
-      const std::vector<stride_t>& st = strides[j];
-      out_strides[j].push_back(st[to_collapse[i - 1]]);
-    }
-  }
-
-  return std::make_tuple(out_shape, out_strides);
-}
+std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
+collapse_contiguous_dims(
+    const std::vector<int>& shape,
+    const std::vector<std::vector<int64_t>>& strides,
+    int64_t size_cap = std::numeric_limits<int32_t>::max());
+std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
+collapse_contiguous_dims(
+    const std::vector<int>& shape,
+    const std::vector<std::vector<size_t>>& strides,
+    size_t size_cap = std::numeric_limits<int32_t>::max());

 inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
-collapse_contiguous_dims(const std::vector<array>& xs) {
+collapse_contiguous_dims(
+    const std::vector<array>& xs,
+    size_t size_cap = std::numeric_limits<size_t>::max()) {
   std::vector<std::vector<size_t>> strides;
   for (auto& x : xs) {
     strides.emplace_back(x.strides());
   }
-  return collapse_contiguous_dims(xs[0].shape(), strides);
+  return collapse_contiguous_dims(xs[0].shape(), strides, size_cap);
 }

 template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
@@ -19,7 +19,7 @@

 namespace mlx::core {

-constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
+constexpr int MAX_BINARY_SPECIALIZED_DIMS = 3;

 std::string get_kernel_name(
     BinaryOpType bopt,
@@ -69,46 +69,61 @@ void binary_op_gpu_inplace(
   }

   // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
-  auto& strides_a = strides[0];
-  auto& strides_b = strides[1];
-  auto& strides_out = strides[2];
+  auto maybe_collapse = [bopt, &a, &b, &out]() {
+    if (bopt == BinaryOpType::General) {
+      // The size cap here should ideally be `UINT32_MAX` but we are
+      // limited by the shape being an int.
+      auto [shape, strides] = collapse_contiguous_dims(
+          {a, b, out},
+          /* size_cap = */ INT32_MAX);
+      return std::make_tuple(shape, strides[0], strides[1], strides[2]);
+    } else {
+      std::vector<size_t> e;
+      return std::make_tuple(std::vector<int>{}, e, e, e);
+    }
+  };
+  auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();

   bool use_2d = out.data_size() > UINT32_MAX;
   std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
   auto& d = metal::device(s.device);

-  auto kernel =
-      get_binary_two_kernel(d, kernel_name, a.dtype(), outputs[0].dtype(), op);
+  auto kernel = outputs.size() == 2
+      ? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
+      : get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);

   // - If a is donated it goes to the first output
   // - If b is donated it goes to the first output if a was not donated
-  //   otherwise it goes to the second output
+  //   otherwise it goes to the second output.
+  // - If there is only one output only one of a and b will be donated.
   bool donate_a = a.data_shared_ptr() == nullptr;
   bool donate_b = b.data_shared_ptr() == nullptr;
-  compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
+  int arg_idx = 0;
+  compute_encoder.set_input_array(donate_a ? outputs[0] : a, arg_idx++);
   compute_encoder.set_input_array(
-      donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
-  compute_encoder.set_output_array(outputs[0], 2);
-  compute_encoder.set_output_array(outputs[1], 3);
+      donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, arg_idx++);
+  compute_encoder.set_output_array(outputs[0], arg_idx++);
+  if (outputs.size() == 2) {
+    compute_encoder.set_output_array(outputs[1], arg_idx++);
+  }

   if (bopt == BinaryOpType::General) {
     auto ndim = shape.size();
     if (ndim > 3) {
-      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
+      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
+      compute_encoder->setBytes(
+          strides_a.data(), ndim * sizeof(size_t), arg_idx++);
+      compute_encoder->setBytes(
+          strides_b.data(), ndim * sizeof(size_t), arg_idx++);
+      compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
     } else {
       // The shape is implicit in the grid for <= 3D
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
-    }
-
-    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
-      compute_encoder->setBytes(&ndim, sizeof(int), 7);
+      compute_encoder->setBytes(
+          strides_a.data(), ndim * sizeof(size_t), arg_idx++);
+      compute_encoder->setBytes(
+          strides_b.data(), ndim * sizeof(size_t), arg_idx++);
     }

     // Launch up to 3D grid of threads
@@ -125,9 +140,8 @@ void binary_op_gpu_inplace(
   } else {
     // Launch a 1D or 2D grid of threads
     size_t nthreads = out.data_size();
-    MTL::Size grid_dims = use_2d
-        ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
-        : MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
+                                 : MTL::Size(nthreads, 1, 1);
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
@@ -164,72 +178,8 @@ void binary_op_gpu_inplace(
     array& out,
     const std::string& op,
     const Stream& s) {
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-  auto bopt = get_binary_op_type(a, b);
-  if (out.size() == 0) {
-    return;
-  }
-
-  // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
-  auto& strides_a = strides[0];
-  auto& strides_b = strides[1];
-  auto& strides_out = strides[2];
-
-  bool use_2d = out.data_size() > UINT32_MAX;
-  std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
-  auto& d = metal::device(s.device);
-
-  auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
-  bool donate_a = a.data_shared_ptr() == nullptr;
-  bool donate_b = b.data_shared_ptr() == nullptr;
-  compute_encoder.set_input_array(donate_a ? out : a, 0);
-  compute_encoder.set_input_array(donate_b ? out : b, 1);
-  compute_encoder.set_output_array(out, 2);
-
-  if (bopt == BinaryOpType::General) {
-    auto ndim = shape.size();
-    if (ndim > 3) {
-      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
-    } else {
-      // The shape is implicit in the grid for <= 3D
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
-    }
-
-    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
-      compute_encoder->setBytes(&ndim, sizeof(int), 6);
-    }
-
-    // Launch up to 3D grid of threads
-    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
-    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
-    size_t rest = out.size() / (dim0 * dim1);
-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-    if (thread_group_size != 1024) {
-      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
-    }
-    auto group_dims = get_block_dims(dim0, dim1, rest);
-    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
-  } else {
-    // Launch a 1D or 2D grid of threads
-    size_t nthreads = out.data_size();
-    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
-                                 : MTL::Size(nthreads, 1, 1);
-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-    if (thread_group_size > nthreads) {
-      thread_group_size = nthreads;
-    }
-    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
-  }
+  std::vector<array> outputs = {out};
+  binary_op_gpu_inplace(inputs, outputs, op, s);
 }

 void binary_op_gpu(
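A quick standalone demonstration of why the INT32_MAX cap and the use_2d path above matter: a flat 32-bit thread index can only address UINT32_MAX elements, so larger contiguous launches are spread over a 2D grid and the flat index is rebuilt with a 64-bit product, mirroring the generated `size_t index = pos.x + grid.x * size_t(pos.y);`. The grid split below is purely illustrative.

#include <cstdint>
#include <cstdio>

int main() {
  uint64_t nthreads = 5'000'000'000; // more elements than UINT32_MAX
  // Hypothetical 2D split: fix the x extent, spread the rest over y.
  uint32_t grid_x = 1u << 31;
  uint32_t grid_y = (nthreads + grid_x - 1) / grid_x;
  // Inside the kernel the flat index is rebuilt with a 64-bit product so
  // the multiplication cannot wrap at 2^32.
  uint32_t pos_x = 123, pos_y = 2; // example thread coordinates
  uint64_t index = pos_x + grid_x * uint64_t(pos_y);
  std::printf("grid = (%u, %u), sample index = %llu\n", grid_x, grid_y,
              (unsigned long long)index);
}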
@@ -22,7 +22,8 @@ inline void build_kernel(
     const std::unordered_set<uintptr_t>& constant_ids,
     bool contiguous,
     int ndim,
-    bool dynamic_dims) {
+    bool dynamic_dims,
+    bool use_big_index = false) {
   // All outputs should have the exact same shape and will be row contiguous
   auto output_shape = outputs[0].shape();
   auto output_strides = outputs[0].strides();
@@ -84,9 +85,15 @@ inline void build_kernel(

   // The thread index in the whole grid
   os << "  uint3 pos [[thread_position_in_grid]]," << std::endl
-     << "  uint3 grid [[threads_per_grid]]) {" << std::endl
-     << "  uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);"
-     << std::endl;
+     << "  uint3 grid [[threads_per_grid]]) {" << std::endl;
+  if (use_big_index) {
+    // This is only used for contiguous kernels which don't have
+    // a third grid dimension
+    os << "  size_t index = pos.x + grid.x * size_t(pos.y);";
+  } else {
+    os << "  uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);";
+  }
+  os << std::endl;

   // Extract the indices per axis to individual uints if we have arrays that
   // are broadcasted or transposed
@@ -212,6 +219,17 @@ void Compiled::eval_gpu(
       /* contiguous = */ true,
       /* ndim = */ 0,
       /* dynamic_dims = */ false);
+  build_kernel(
+      kernel,
+      kernel_lib_ + "_contiguous_big",
+      inputs_,
+      outputs_,
+      tape_,
+      constant_ids_,
+      /* contiguous = */ true,
+      /* ndim = */ 0,
+      /* dynamic_dims = */ false,
+      /* use_big_index = */ true);
   for (int i = 1; i < 8; i++) {
     build_kernel(
         kernel,
@@ -285,7 +303,16 @@ void Compiled::eval_gpu(
       initial_strides.push_back(std::move(xstrides));
     }
     std::tie(shape, strides) =
-        collapse_contiguous_dims(output_shape, initial_strides);
+        collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
   }

+  bool use_2d = false;
+  if (contiguous) {
+    size_t max_size = 0;
+    for (auto& in : inputs) {
+      max_size = std::max(max_size, in.data_size());
+    }
+    use_2d = (max_size > UINT32_MAX);
+  }
+
   // Get the kernel from the lib
@@ -298,6 +325,8 @@ void Compiled::eval_gpu(
     } else {
       kernel_name += std::to_string(shape.size());
     }
+  } else if (use_2d) {
+    kernel_name += "_big";
   }
   auto kernel = d.get_kernel(kernel_name, lib);
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -348,8 +377,10 @@ void Compiled::eval_gpu(

   // Launch the kernel
   if (contiguous) {
-    size_t nthreads = outputs[0].size();
-    MTL::Size grid_dims(nthreads, 1, 1);
+    size_t nthreads = outputs[0].data_size();
+    MTL::Size grid_dims = use_2d
+        ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
+        : MTL::Size(nthreads, 1, 1);
     MTL::Size group_dims(
         std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
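get_2d_grid_dims itself is not shown in this diff. As a rough sketch of the kind of splitting such a helper must perform (the real one works from the output's shape and strides, so it has better factors to choose from), a hypothetical split_grid_2d could peel factors off the total until each grid extent fits in 32 bits:

#include <cstdint>
#include <cstdio>
#include <utility>

// Split a total element count into two grid extents that each fit in 32
// bits by peeling off factors of two.
std::pair<uint32_t, uint32_t> split_grid_2d(uint64_t n) {
  uint64_t x = n, y = 1;
  while (x > UINT32_MAX && x % 2 == 0) {
    x /= 2;
    y *= 2;
  }
  // If n is odd and huge this loop cannot split it; a shape-aware helper
  // has whole axes to move between dimensions, or the launch gets padded
  // and bounds-checked instead.
  return {uint32_t(x), uint32_t(y)};
}

int main() {
  auto [gx, gy] = split_grid_2d(uint64_t(1) << 33);
  std::printf("grid = (%u, %u)\n", gx, gy);
}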
@@ -10,7 +10,7 @@

 namespace mlx::core {

-constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
+constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;

 void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
   if (ctype == CopyType::Vector) {
@@ -59,10 +59,20 @@ void copy_gpu_inplace(
   }

   // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(
-      data_shape, std::vector{strides_in_pre, strides_out_pre});
-  auto& strides_in_ = strides[0];
-  auto& strides_out_ = strides[1];
+  auto maybe_collapse =
+      [ctype, &data_shape, &strides_in_pre, &strides_out_pre]() {
+        if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
+          auto [shape, strides] = collapse_contiguous_dims(
+              data_shape,
+              std::vector{strides_in_pre, strides_out_pre},
+              /* size_cap = */ INT32_MAX);
+          return std::make_tuple(shape, strides[0], strides[1]);
+        } else {
+          std::vector<stride_t> e;
+          return std::make_tuple(std::vector<int>{}, e, e);
+        }
+      };
+  auto [shape, strides_in_, strides_out_] = maybe_collapse();

   bool use_2d = out.data_size() > UINT32_MAX;
   auto& d = metal::device(s.device);
@@ -1,100 +0,0 @@
-// Copyright © 2024 Apple Inc.
-
-constexpr std::string_view copy_kernels = R"(
-template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    uint index [[thread_position_in_grid]]);
-template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    uint index [[thread_position_in_grid]]);
-
-template [[host_name("g4_{0}")]] [[kernel]] void
-copy_g_nd<{1}, {2}, 4>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]);
-template [[host_name("gg4_{0}")]] [[kernel]] void
-copy_gg_nd<{1}, {2}, 4>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint3 index [[thread_position_in_grid]]);
-template [[host_name("g5_{0}")]] [[kernel]] void
-copy_g_nd<{1}, {2}, 5>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]);
-template [[host_name("gg5_{0}")]] [[kernel]] void
-copy_gg_nd<{1}, {2}, 5>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint3 index [[thread_position_in_grid]]);
-template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t& src_stride [[buffer(3)]],
-    uint index [[thread_position_in_grid]]);
-template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint2 index [[thread_position_in_grid]],
-    uint2 grid_dim [[threads_per_grid]]);
-template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]);
-template [[host_name("gg1_{0}")]] [[kernel]] void
-copy_gg_nd1<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t& src_stride [[buffer(3)]],
-    constant const int64_t& dst_stride [[buffer(4)]],
-    uint index [[thread_position_in_grid]]);
-template [[host_name("gg2_{0}")]] [[kernel]] void
-copy_gg_nd2<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint2 index [[thread_position_in_grid]]);
-template [[host_name("gg3_{0}")]] [[kernel]] void
-copy_gg_nd3<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint3 index [[thread_position_in_grid]]);
-
-template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int& ndim [[buffer(5)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]);
-template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    constant const int& ndim [[buffer(5)]],
-    uint3 index [[thread_position_in_grid]]);
-)";
@@ -1,9 +1,7 @@
 // Copyright © 2024 Apple Inc.
-#include <map>

 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/jit/arange.h"
-#include "mlx/backend/metal/jit/copy.h"
 #include "mlx/backend/metal/jit/gemv_masked.h"
 #include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/jit/scan.h"
@@ -67,7 +65,7 @@ void add_binary_kernels(
     Dtype out_type,
     const std::string op,
     std::ostringstream& kernel_source) {
-  const std::map<std::string, std::string> kernel_types = {
+  const std::array<std::pair<std::string, std::string>, 11> kernel_types = {{
       {"ss", "binary_ss"},
      {"vs", "binary_vs"},
       {"sv", "binary_sv"},
@@ -78,29 +76,16 @@ void add_binary_kernels(
       {"g1", "binary_g_nd1"},
       {"g2", "binary_g_nd2"},
       {"g3", "binary_g_nd3"},
-      {"g4", "binary_g_nd"},
-      {"g5", "binary_g_nd"},
       {"gn", "binary_g"},
-  };
-  for (auto [name, func] : kernel_types) {
+  }};
+  for (auto& [name, func] : kernel_types) {
     std::string template_def;
-    if (name == "g4" || name == "g5") {
-      int dim = std::stoi(name.substr(1));
-      template_def = get_template_definition(
-          name + lib_name,
-          func,
-          get_type_string(in_type),
-          get_type_string(out_type),
-          op,
-          dim);
-    } else {
-      template_def = get_template_definition(
-          name + lib_name,
-          func,
-          get_type_string(in_type),
-          get_type_string(out_type),
-          op);
-    }
+    template_def = get_template_definition(
+        name + lib_name,
+        func,
+        get_type_string(in_type),
+        get_type_string(out_type),
+        op);
     kernel_source << template_def;
   }
 }
@@ -149,27 +134,19 @@ MTL::ComputePipelineState* get_ternary_kernel(
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
-    const std::map<std::string, std::string> kernel_types = {
+    const std::array<std::pair<std::string, std::string>, 6> kernel_types = {{
         {"v", "ternary_v"},
         {"v2", "ternary_v2"},
         {"g", "ternary_g"},
         {"g1", "ternary_g_nd1"},
         {"g2", "ternary_g_nd2"},
         {"g3", "ternary_g_nd3"},
-        {"g4", "ternary_g_nd"},
-        {"g5", "ternary_g_nd"},
-    };
+    }};
     kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
-    for (auto [name, func] : kernel_types) {
+    for (auto& [name, func] : kernel_types) {
       std::string template_def;
-      if (name == "g4" || name == "g5") {
-        int dim = std::stoi(name.substr(1));
-        template_def = get_template_definition(
-            name + "_" + lib_name, func, get_type_string(type), op, dim);
-      } else {
-        template_def = get_template_definition(
-            name + "_" + lib_name, func, get_type_string(type), op);
-      }
+      template_def = get_template_definition(
+          name + "_" + lib_name, func, get_type_string(type), op);
       kernel_source << template_def;
     }
     lib = d.get_library(lib_name, kernel_source.str());
@@ -186,12 +163,27 @@ MTL::ComputePipelineState* get_copy_kernel(
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::copy()
-                  << fmt::format(
-                         copy_kernels,
-                         lib_name,
-                         get_type_string(in.dtype()),
-                         get_type_string(out.dtype()));
+    auto in_type = get_type_string(in.dtype());
+    auto out_type = get_type_string(out.dtype());
+    kernel_source
+        << metal::utils() << metal::copy()
+        << get_template_definition("s_" + lib_name, "copy_s", in_type, out_type)
+        << get_template_definition("v_" + lib_name, "copy_v", in_type, out_type)
+        << get_template_definition(
+               "g1_" + lib_name, "copy_g_nd1", in_type, out_type)
+        << get_template_definition(
+               "g2_" + lib_name, "copy_g_nd2", in_type, out_type)
+        << get_template_definition(
+               "g3_" + lib_name, "copy_g_nd3", in_type, out_type)
+        << get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
+        << get_template_definition(
+               "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
+        << get_template_definition(
+               "gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
+        << get_template_definition(
+               "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
+        << get_template_definition(
+               "gg_" + lib_name, "copy_gg", in_type, out_type);
     lib = d.get_library(lib_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib);
@@ -296,11 +288,11 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
   if (lib == nullptr) {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::sort();
-    std::vector<std::pair<std::string, std::string>> kernel_types = {
-        {"sort_", "mb_block_sort"},
-        {"partition_", "mb_block_partition"},
-        {"merge_", "mb_block_merge"}};
-    for (auto [name, func] : kernel_types) {
+    std::array<std::pair<std::string, std::string>, 3> kernel_types = {
+        {{"sort_", "mb_block_sort"},
+         {"partition_", "mb_block_partition"},
+         {"merge_", "mb_block_merge"}}};
+    for (auto& [name, func] : kernel_types) {
       kernel_source << get_template_definition(
           name + lib_name,
           func,
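get_template_definition, which the JIT now uses for every copy kernel, is not shown in this diff. A guess at the general shape of such a helper, with an illustrative name and no claim to match MLX's actual signature: it stringifies one explicit template instantiation per host name.

#include <iostream>
#include <sstream>
#include <string>

// Hypothetical builder: emits something like
//   template [[host_name("v_add")]] [[kernel]]
//       decltype(binary_v<float, float, Add>) binary_v<float, float, Add>;
template <typename... Args>
std::string make_template_definition(
    const std::string& name, const std::string& func, Args... args) {
  std::ostringstream s;
  auto put_args = [&] {
    bool first = true;
    // Comma-separate the template arguments with a fold expression.
    ((s << (first ? "" : ", ") << args, first = false), ...);
  };
  s << "template [[host_name(\"" << name << "\")]] [[kernel]] decltype("
    << func << "<";
  put_args();
  s << ">) " << func << "<";
  put_args();
  s << ">;\n";
  return s.str();
}

int main() {
  std::cout << make_template_definition(
      "v_add", "binary_v", "float", "float", "Add");
}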
@@ -70,16 +70,16 @@ IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
       simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)};
 }

-template <typename T, typename Op, int N_READS>
+template <typename T, typename Op, int N_READS = 4>
 [[kernel]] void arg_reduce_general(
     const device T* in [[buffer(0)]],
     device uint32_t* out [[buffer(1)]],
-    const device int* shape [[buffer(2)]],
-    const device size_t* in_strides [[buffer(3)]],
-    const device size_t* out_strides [[buffer(4)]],
-    const device size_t& ndim [[buffer(5)]],
-    const device size_t& axis_stride [[buffer(6)]],
-    const device size_t& axis_size [[buffer(7)]],
+    const constant int* shape [[buffer(2)]],
+    const constant size_t* in_strides [[buffer(3)]],
+    const constant size_t* out_strides [[buffer(4)]],
+    const constant size_t& ndim [[buffer(5)]],
+    const constant size_t& axis_stride [[buffer(6)]],
+    const constant size_t& axis_size [[buffer(7)]],
     uint gid [[thread_position_in_grid]],
     uint lid [[thread_position_in_threadgroup]],
     uint lsize [[threads_per_threadgroup]],
@@ -159,28 +159,12 @@ template <typename T, typename Op, int N_READS>
   }
 }

-#define instantiate_arg_reduce_helper(name, itype, op)      \
-  template [[host_name(name)]] [[kernel]] void              \
-  arg_reduce_general<itype, op<itype>, 4>(                  \
-      const device itype* in [[buffer(0)]],                 \
-      device uint32_t* out [[buffer(1)]],                   \
-      const device int* shape [[buffer(2)]],                \
-      const device size_t* in_strides [[buffer(3)]],        \
-      const device size_t* out_strides [[buffer(4)]],       \
-      const device size_t& ndim [[buffer(5)]],              \
-      const device size_t& axis_stride [[buffer(6)]],       \
-      const device size_t& axis_size [[buffer(7)]],         \
-      uint gid [[thread_position_in_grid]],                 \
-      uint lid [[thread_position_in_threadgroup]],          \
-      uint lsize [[threads_per_threadgroup]],               \
-      uint simd_size [[threads_per_simdgroup]],             \
-      uint simd_lane_id [[thread_index_in_simdgroup]],      \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
-
 // clang-format off
 #define instantiate_arg_reduce(name, itype)                       \
-  instantiate_arg_reduce_helper("argmin_" #name, itype, ArgMin)   \
-  instantiate_arg_reduce_helper("argmax_" #name, itype, ArgMax)
+  instantiate_kernel(                                             \
+      "argmin_" #name, arg_reduce_general, itype, ArgMin<itype>)  \
+  instantiate_kernel(                                             \
+      "argmax_" #name, arg_reduce_general, itype, ArgMax<itype>)

 instantiate_arg_reduce(bool_, bool)
 instantiate_arg_reduce(uint8, uint8_t)
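The win here is that instantiate_kernel no longer restates the full parameter list for each instantiation. A plain C++ analogue of the same decltype trick (a sketch only: the Metal host_name/kernel attributes are omitted, and the macro name reuse is illustrative):

#include <iostream>

template <typename T>
T twice(T x) {
  return x + x;
}

// Explicit instantiation via decltype: the function's signature is never
// repeated, only its name and template arguments.
#define instantiate_fn(func, ...) \
  template decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;

instantiate_fn(twice, int)
instantiate_fn(twice, double)

int main() {
  std::cout << twice(21) << "\n"; // prints 42
}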
@@ -93,7 +93,7 @@ template <typename T, typename U, typename Op>
     uint2 grid_dim [[threads_per_grid]]) {
   auto a_idx = elem_to_loc_2(index, a_strides);
   auto b_idx = elem_to_loc_2(index, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
   c[out_idx] = Op()(a[a_idx], b[b_idx]);
 }

@@ -109,26 +109,10 @@ template <typename T, typename U, typename Op>
   auto a_idx = elem_to_loc_3(index, a_strides);
   auto b_idx = elem_to_loc_3(index, b_strides);
   size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   c[out_idx] = Op()(a[a_idx], b[b_idx]);
 }

-template <typename T, typename U, typename Op, int DIM>
-[[kernel]] void binary_g_nd(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    constant const int shape[DIM],
-    constant const size_t a_strides[DIM],
-    constant const size_t b_strides[DIM],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
-  size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
-  c[out_idx] = Op()(a[idx.x], b[idx.y]);
-}
-
 template <typename T, typename U, typename Op>
 [[kernel]] void binary_g(
     device const T* a,
@@ -141,6 +125,7 @@ template <typename T, typename U, typename Op>
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
   auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
-  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   c[out_idx] = Op()(a[idx.x], b[idx.y]);
 }
@@ -21,8 +21,6 @@
   instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op)    \
   instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op)    \
   instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op)    \
-  instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4)  \
-  instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)

 #define instantiate_binary_integer(op)                \
   instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
@@ -118,7 +118,7 @@ template <typename T, typename U, typename Op>
     uint2 grid_dim [[threads_per_grid]]) {
   auto a_idx = elem_to_loc_2(index, a_strides);
   auto b_idx = elem_to_loc_2(index, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
   auto out = Op()(a[a_idx], b[b_idx]);
   c[out_idx] = out[0];
   d[out_idx] = out[1];
@@ -137,31 +137,12 @@ template <typename T, typename U, typename Op>
   auto a_idx = elem_to_loc_3(index, a_strides);
   auto b_idx = elem_to_loc_3(index, b_strides);
   size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   auto out = Op()(a[a_idx], b[b_idx]);
   c[out_idx] = out[0];
   d[out_idx] = out[1];
 }

-template <typename T, typename U, typename Op, int DIM>
-[[kernel]] void binary_g_nd(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    device U* d,
-    constant const int shape[DIM],
-    constant const size_t a_strides[DIM],
-    constant const size_t b_strides[DIM],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
-  size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
-  auto out = Op()(a[idx.x], b[idx.y]);
-  c[out_idx] = out[0];
-  d[out_idx] = out[1];
-}
-
 template <typename T, typename U, typename Op>
 [[kernel]] void binary_g(
     device const T* a,
@@ -175,7 +156,8 @@ template <typename T, typename U, typename Op>
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
   auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
-  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   auto out = Op()(a[idx.x], b[idx.y]);
   c[out_idx] = out[0];
   d[out_idx] = out[1];
@@ -19,8 +19,6 @@
   instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op)    \
   instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op)    \
   instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op)    \
-  instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4)  \
-  instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)

 #define instantiate_binary_float(op)              \
   instantiate_binary_all(op, float16, half, half) \
@@ -71,20 +71,6 @@ template <typename T, typename U>
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }

-template <typename T, typename U, int DIM>
-[[kernel]] void copy_g_nd(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
-  int64_t dst_idx =
-      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
-}
-
 template <typename T, typename U>
 [[kernel]] void copy_g(
     device const T* src [[buffer(0)]],
@@ -136,19 +122,6 @@ template <typename T, typename U>
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }

-template <typename T, typename U, int DIM>
-[[kernel]] void copy_gg_nd(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint3 index [[thread_position_in_grid]]) {
-  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
-  auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
-}
-
 template <typename T, typename U>
 [[kernel]] void copy_gg(
     device const T* src [[buffer(0)]],
@@ -16,10 +16,6 @@
   instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype)   \
   instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype)   \
   instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype)   \
-  instantiate_kernel("g4_copy" #tname, copy_g_nd, itype, otype, 4)   \
-  instantiate_kernel("g5_copy" #tname, copy_g_nd, itype, otype, 5)   \
-  instantiate_kernel("gg4_copy" #tname, copy_gg_nd, itype, otype, 4) \
-  instantiate_kernel("gg5_copy" #tname, copy_gg_nd, itype, otype, 5) \
   instantiate_kernel("g_copy" #tname, copy_g, itype, otype)          \
   instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype)
@@ -69,9 +69,9 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
     device char* out,
     device const bool& odd,
     device const uint& bytes_per_key,
-    device const int& ndim,
-    device const int* key_shape,
-    device const size_t* key_strides,
+    constant const int& ndim,
+    constant const int* key_shape,
+    constant const size_t* key_strides,
     uint2 grid_dim [[threads_per_grid]],
     uint2 index [[thread_position_in_grid]]) {
   auto kidx = 2 * index.x;
@@ -342,9 +342,9 @@ template <
     const constant int& in_stride_sorted_axis [[buffer(3)]],
     const constant int& out_stride_sorted_axis [[buffer(4)]],
     const constant int& nc_dim [[buffer(5)]],
-    const device int* nc_shape [[buffer(6)]],
-    const device size_t* in_nc_strides [[buffer(7)]],
-    const device size_t* out_nc_strides [[buffer(8)]],
+    const constant int* nc_shape [[buffer(6)]],
+    const constant size_t* in_nc_strides [[buffer(7)]],
+    const constant size_t* out_nc_strides [[buffer(8)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 lid [[thread_position_in_threadgroup]]) {
   using sort_kernel =
@@ -485,8 +485,8 @@ template <
     const constant int& size_sorted_axis [[buffer(3)]],
     const constant int& stride_sorted_axis [[buffer(4)]],
     const constant int& nc_dim [[buffer(5)]],
-    const device int* nc_shape [[buffer(6)]],
-    const device size_t* nc_strides [[buffer(7)]],
+    const constant int* nc_shape [[buffer(6)]],
+    const constant size_t* nc_strides [[buffer(7)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 lid [[thread_position_in_threadgroup]]) {
   using sort_kernel = KernelMultiBlockMergeSort<
@@ -52,7 +52,7 @@ template <typename T, typename Op>
   auto a_idx = elem_to_loc_2(index, a_strides);
   auto b_idx = elem_to_loc_2(index, b_strides);
   auto c_idx = elem_to_loc_2(index, c_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
   d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
 }

@@ -71,29 +71,10 @@ template <typename T, typename Op>
   auto b_idx = elem_to_loc_3(index, b_strides);
   auto c_idx = elem_to_loc_3(index, c_strides);
   size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
 }

-template <typename T, typename Op, int DIM>
-[[kernel]] void ternary_g_nd(
-    device const bool* a,
-    device const T* b,
-    device const T* c,
-    device T* d,
-    constant const int shape[DIM],
-    constant const size_t a_strides[DIM],
-    constant const size_t b_strides[DIM],
-    constant const size_t c_strides[DIM],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto idx =
-      elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
-  size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
-  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
-}
-
 template <typename T, typename Op>
 [[kernel]] void ternary_g(
     device const bool* a,
@@ -109,6 +90,7 @@ template <typename T, typename Op>
     uint3 grid_dim [[threads_per_grid]]) {
   auto idx =
       elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
-  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
 }
@@ -16,8 +16,6 @@
   instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op)    \
   instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op)    \
   instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op)    \
-  instantiate_kernel("g4_" #op #tname, ternary_g_nd, type, op, 4)  \
-  instantiate_kernel("g5_" #op #tname, ternary_g_nd, type, op, 5)

 #define instantiate_ternary_types(op)       \
   instantiate_ternary_all(op, bool_, bool)  \
@@ -22,8 +22,8 @@ template <typename T, typename Op>
 [[kernel]] void unary_g(
     device const T* in,
     device T* out,
-    device const int* in_shape,
-    device const size_t* in_strides,
+    constant const int* in_shape,
+    constant const size_t* in_strides,
     device const int& ndim,
     uint index [[thread_position_in_grid]]) {
   auto idx = elem_to_loc(index, in_shape, in_strides, ndim);
@@ -83,20 +83,6 @@ struct Limits<complex64_t> {
 ///////////////////////////////////////////////////////////////////////////////
 // Single Array with generic dims

-template <typename stride_t>
-METAL_FUNC stride_t elem_to_loc(
-    uint elem,
-    device const int* shape,
-    device const stride_t* strides,
-    int ndim) {
-  stride_t loc = 0;
-  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
-    loc += (elem % shape[i]) * strides[i];
-    elem /= shape[i];
-  }
-  return loc;
-}
-
 template <typename stride_t>
 METAL_FUNC stride_t elem_to_loc(
     uint elem,
@@ -111,20 +97,6 @@ METAL_FUNC stride_t elem_to_loc(
   return loc;
 }

-template <typename stride_t>
-METAL_FUNC stride_t elem_to_loc(
-    stride_t elem,
-    device const int* shape,
-    device const stride_t* strides,
-    int ndim) {
-  stride_t loc = 0;
-  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
-    loc += (elem % shape[i]) * strides[i];
-    elem /= shape[i];
-  }
-  return loc;
-}
-
 template <typename stride_t>
 METAL_FUNC stride_t elem_to_loc(
     stride_t elem,
@@ -174,78 +146,18 @@ elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
   return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
 }

-template <int NDIM>
-METAL_FUNC size_t elem_to_loc_nd(
-    uint elem,
-    device const int* shape,
-    device const size_t* strides) {
-  size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
-
-  MLX_MTL_PRAGMA_UNROLL
-  for (int d = NDIM - 2; d >= 0; --d) {
-    elem /= shape[d + 1];
-    loc += (elem % shape[d]) * strides[d];
-  }
-
-  return loc;
-}
-
-template <int NDIM>
-METAL_FUNC size_t elem_to_loc_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const size_t strides[NDIM]) {
-  size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
-  for (int d = NDIM - 3; d >= 0; --d) {
-    loc += (elem.z % shape[d]) * strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
-template <int NDIM>
-METAL_FUNC int64_t elem_to_loc_nd(
-    uint elem,
-    constant const int shape[NDIM],
-    constant const int64_t strides[NDIM]) {
-  int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
-
-  MLX_MTL_PRAGMA_UNROLL
-  for (int d = NDIM - 2; d >= 0; --d) {
-    elem /= shape[d + 1];
-    loc += (elem % shape[d]) * strides[d];
-  }
-
-  return loc;
-}
-
-template <int NDIM>
-METAL_FUNC int64_t elem_to_loc_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const int64_t strides[NDIM]) {
-  int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
-  for (int d = NDIM - 3; d >= 0; --d) {
-    loc += (elem.z % shape[d]) * strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
 ///////////////////////////////////////////////////////////////////////////////
 // Multiple Arrays with generic dims

-METAL_FUNC uint2 elem_to_loc_2_nd(
+METAL_FUNC ulong2 elem_to_loc_2_nd(
     uint3 elem,
     constant const int* shape,
     constant const size_t* a_strides,
     constant const size_t* b_strides,
     int ndim) {
-  uint2 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
+  ulong2 loc = {
+      elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
+      elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]};
   for (int d = ndim - 3; d >= 0; --d) {
     uint l = elem.z % shape[d];
     loc.x += l * a_strides[d];
@@ -255,20 +167,17 @@ METAL_FUNC uint2 elem_to_loc_2_nd(
   return loc;
 }

-METAL_FUNC uint3 elem_to_loc_3_nd(
+METAL_FUNC ulong3 elem_to_loc_3_nd(
     uint3 elem,
     constant const int* shape,
     constant const size_t* a_strides,
     constant const size_t* b_strides,
     constant const size_t* c_strides,
     int ndim) {
-  uint3 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
-      static_cast<uint>(
-          elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
+  ulong3 loc = {
+      elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
+      elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2],
+      elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]};
   for (int d = ndim - 3; d >= 0; --d) {
     uint l = elem.z % shape[d];
     loc.x += l * a_strides[d];
@@ -279,53 +188,6 @@ METAL_FUNC uint3 elem_to_loc_3_nd(
   return loc;
 }

-///////////////////////////////////////////////////////////////////////////////
-// Multiple Arrays with fixed N dims
-
-template <int NDIM>
-METAL_FUNC uint2 elem_to_loc_2_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const size_t a_strides[NDIM],
-    constant const size_t b_strides[NDIM]) {
-  uint2 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
-  for (int d = NDIM - 3; d >= 0; --d) {
-    uint l = elem.z % shape[d];
-    loc.x += l * a_strides[d];
-    loc.y += l * b_strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
-template <int NDIM>
-METAL_FUNC uint3 elem_to_loc_3_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const size_t a_strides[NDIM],
-    constant const size_t b_strides[NDIM],
-    constant const size_t c_strides[NDIM]) {
-  uint3 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
-      static_cast<uint>(
-          elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
-  for (int d = NDIM - 3; d >= 0; --d) {
-    uint l = elem.z % shape[d];
-    loc.x += l * a_strides[d];
-    loc.y += l * b_strides[d];
-    loc.z += l * c_strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
 ///////////////////////////////////////////////////////////////////////////////
 // Elem to loc in a loop utils
 ///////////////////////////////////////////////////////////////////////////////
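For reference, elem_to_loc maps a flat row-major element index to a strided memory offset by unraveling the index against the shape and dotting the per-axis coordinates with the strides. The same idea in standalone C++ (a sketch with a worked example, not the Metal code): for shape (2, 3) with transposed strides (1, 2), element 4 is coordinate (1, 1), which lands at offset 1 * 1 + 1 * 2 = 3.

#include <cstdint>
#include <cstdio>
#include <vector>

int64_t elem_to_loc(
    int64_t elem,
    const std::vector<int>& shape,
    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  // Peel coordinates off the innermost axis first.
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  std::printf("%lld\n", (long long)elem_to_loc(4, {2, 3}, {1, 2})); // 3
}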
@@ -8,7 +8,7 @@

 namespace mlx::core {

-constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 5;
+constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 3;

 void ternary_op_gpu_inplace(
     const std::vector<array>& inputs,
@@ -26,11 +26,21 @@ void ternary_op_gpu_inplace(
   }

   // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
-  auto& strides_a = strides[0];
-  auto& strides_b = strides[1];
-  auto& strides_c = strides[2];
-  auto& strides_out = strides[3];
+  auto maybe_collapse = [topt, &a, &b, &c, &out]() {
+    if (topt == TernaryOpType::General) {
+      // The size cap here should ideally be `UINT32_MAX` but we are
+      // limited by the shape being an int.
+      auto [shape, strides] = collapse_contiguous_dims(
+          {a, b, c, out},
+          /* size_cap = */ INT32_MAX);
+      return std::make_tuple(
+          shape, strides[0], strides[1], strides[2], strides[3]);
+    } else {
+      std::vector<size_t> e;
+      return std::make_tuple(std::vector<int>{}, e, e, e, e);
+    }
+  };
+  auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();

   bool use_2d = out.data_size() > UINT_MAX;
   std::string kernel_name;
@@ -88,7 +98,7 @@ void ternary_op_gpu_inplace(
     size_t rest = out.size() / (dim0 * dim1);
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size != 1024) {
-      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
+      throw std::runtime_error("[Metal::ternary] Must use 1024 sized block");
     }
     MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
@@ -6,4 +6,5 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../common/reduce_utils.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp)
+  ${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/../common/utils.cpp)
@@ -243,7 +243,7 @@ void init_fast(nb::module_& parent_module) {
           template_args.emplace_back(name, dtype);
         } else {
           throw std::invalid_argument(
-              "[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
+              "[metal_kernel] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
         }
       }
     }
@@ -271,25 +271,24 @@ void init_fast(nb::module_& parent_module) {
       nb::sig(
          "def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
       R"pbdoc(
        Run the kernel.

        Args:
          inputs (List[array]): The inputs passed to the Metal kernel.
          output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
          output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
          grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
          threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
          template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
            These will be added as template arguments to the kernel definition. Default: ``None``.
          init_value (float, optional): Optional value to use to initialize all of the output arrays.
            By default, output arrays are uninitialized. Default: ``None``.
          verbose (bool, optional): Whether to print the full generated source code of the kernel
            when it is run. Default: ``False``.
          stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.

        Returns:
-          List[array]: The list of output arrays.
-      )pbdoc");
+          List[array]: The list of output arrays.)pbdoc");
       },
       "name"_a,
       "input_names"_a,
@@ -306,16 +305,16 @@ void init_fast(nb::module_& parent_module) {
        input_names (List[str]): The parameter names of the inputs in the
          function signature.
        output_names (List[str]): The parameter names of the outputs in the
          function signature.
        source (str): Source code. This is the body of a function in Metal,
          the function signature will be automatically generated.
        header (str): Header source code to include before the main function.
          Useful for helper functions or includes that should live outside of
          the main function body.
        ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
          before the kernel runs. Default: ``True``.
        atomic_outputs (bool): Whether to use atomic outputs in the function signature
          e.g. ``device atomic<float>``. Default: ``False``.

        Returns:
          Callable ``metal_kernel``.