More fixes for arrays with large sizes (#1405)
* compile works for big arrays when contiguous
* style
* nits in docs
* a bunch more stuff
* update jit
* update jit
* use constant for shapes and strides and remove elem_to_loc overload
* use kernel instantiation
* docs nits
* update binary and ternary
* comments
Parent: c6739ba7f3
Commit: 4f46e9c997
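The common thread in the changes below is 32-bit index overflow: a flat `uint` thread index wraps once an array holds more than 2^32 elements, so large contiguous dispatches switch to a 2D grid with the flat index computed in 64 bits. A self-contained C++ sketch of the wraparound being guarded against (grid sizes here are illustrative, not from the commit):

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Simulate a thread position in a 2D grid covering 2^33 elements
      // (grid.x = 2^20 columns, grid.y = 2^13 rows).
      uint32_t grid_x = 1u << 20;
      uint32_t pos_x = 123, pos_y = (1u << 13) - 1;

      // All-32-bit arithmetic: the product exceeds 2^32 and wraps.
      uint32_t small_index = pos_x + grid_x * pos_y;
      // Promoting one operand to 64 bits first keeps the full value.
      size_t big_index = pos_x + grid_x * size_t(pos_y);

      std::printf("32-bit: %u, 64-bit: %zu\n", small_index, big_index);
      return 0;
    }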
@@ -13,6 +13,7 @@ simple functions.
    :template: nn-module-template.rst
 
    elu
+   celu
    gelu
    gelu_approx
    gelu_fast_approx
@@ -13,6 +13,7 @@ Layers
    AvgPool1d
    AvgPool2d
    BatchNorm
+   CELU
    Conv1d
    Conv2d
    Conv3d
@@ -23,6 +24,7 @@ Layers
    Dropout2d
    Dropout3d
    Embedding
+   ELU
    GELU
    GLU
    GroupNorm
@@ -34,6 +36,8 @@ Layers
    LayerNorm
    LeakyReLU
    Linear
+   LogSigmoid
+   LogSoftmax
    LSTM
    MaxPool1d
    MaxPool2d
@@ -49,6 +53,7 @@ Layers
    RoPE
    SELU
    Sequential
+   Sigmoid
    SiLU
    SinusoidalPositionalEncoding
    Softmin
@@ -51,6 +51,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
   ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
 
 if(IOS)
mlx/backend/common/utils.cpp (new file, 88 lines)
@@ -0,0 +1,88 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include "mlx/backend/common/utils.h"
+
+namespace mlx::core {
+
+template <typename stride_t>
+std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
+collapse_contiguous_dims_impl(
+    const std::vector<int>& shape,
+    const std::vector<std::vector<stride_t>>& strides,
+    stride_t size_cap) {
+  // Make a vector that has axes separated with -1. Collapse all axes between
+  // -1.
+  std::vector<int> to_collapse;
+  if (shape.size() > 0) {
+    if (shape[0] != 1) {
+      to_collapse.push_back(0);
+    }
+    size_t size = shape[0];
+    for (int i = 1; i < shape.size(); i++) {
+      bool contiguous = true;
+      size *= shape[i];
+      for (const std::vector<stride_t>& st : strides) {
+        if (st[i] * shape[i] != st[i - 1] || size > size_cap) {
+          contiguous = false;
+          size = shape[i];
+          break;
+        }
+      }
+      if (!contiguous) {
+        to_collapse.push_back(-1);
+      }
+      if (shape[i] != 1) {
+        to_collapse.push_back(i);
+      }
+    }
+    to_collapse.push_back(-1);
+  }
+
+  std::vector<int> out_shape;
+  std::vector<std::vector<stride_t>> out_strides(strides.size());
+  for (int i = 0;;) {
+    while (i < to_collapse.size() && to_collapse[i] == -1) {
+      ++i;
+    };
+    if (i == to_collapse.size()) {
+      break;
+    }
+    int current_shape = shape[to_collapse[i]];
+    int k = i;
+    while (to_collapse[++k] != -1) {
+      current_shape *= shape[to_collapse[k]];
+    }
+    out_shape.push_back(current_shape);
+    for (int j = 0; j < strides.size(); j++) {
+      const std::vector<stride_t>& st = strides[j];
+      out_strides[j].push_back(st[to_collapse[k - 1]]);
+    }
+    i = k + 1;
+  }
+
+  if (!shape.empty() && out_shape.empty()) {
+    out_shape.push_back(1);
+    for (auto& out_stride : out_strides) {
+      out_stride.push_back(0);
+    }
+  }
+  return std::make_tuple(out_shape, out_strides);
+}
+
+std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
+collapse_contiguous_dims(
+    const std::vector<int>& shape,
+    const std::vector<std::vector<int64_t>>& strides,
+    int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
+  return collapse_contiguous_dims_impl(shape, strides, size_cap);
+}
+
+std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
+collapse_contiguous_dims(
+    const std::vector<int>& shape,
+    const std::vector<std::vector<size_t>>& strides,
+    size_t size_cap /* = std::numeric_limits<int32>::max() */) {
+  return collapse_contiguous_dims_impl(shape, strides, size_cap);
+}
+
+} // namespace mlx::core
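To make the collapsing behavior concrete, here is a small hypothetical driver (not part of the commit) that feeds a row-contiguous 3D shape through the new function. Since every adjacent axis pair satisfies st[i] * shape[i] == st[i - 1], the three axes merge into one:

    #include <cstdint>
    #include <cstdio>
    #include <vector>
    #include "mlx/backend/common/utils.h"

    int main() {
      using namespace mlx::core;
      // Row-contiguous strides for shape (4, 5, 6).
      std::vector<int> shape{4, 5, 6};
      std::vector<std::vector<int64_t>> strides{{30, 6, 1}};

      auto [out_shape, out_strides] = collapse_contiguous_dims(shape, strides);
      // Prints: collapsed shape = 120, stride = 1
      std::printf(
          "collapsed shape = %d, stride = %lld\n",
          out_shape[0],
          (long long)out_strides[0][0]);
      return 0;
    }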
@@ -44,58 +44,26 @@ std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
 //
 // When multiple arrays are passed they should all have the same shape. The
 // collapsed axes are also the same so one shape is returned.
-template <typename stride_t>
-inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
-collapse_contiguous_dims(
-    const std::vector<int>& shape,
-    const std::vector<std::vector<stride_t>> strides) {
-  // Make a vector that has axes separated with -1. Collapse all axes between
-  // -1.
-  std::vector<int> to_collapse;
-  if (shape.size() > 0) {
-    to_collapse.push_back(0);
-    for (int i = 1; i < shape.size(); i++) {
-      bool contiguous = true;
-      for (const std::vector<stride_t>& st : strides) {
-        if (st[i] * shape[i] != st[i - 1]) {
-          contiguous = false;
-        }
-        if (!contiguous) {
-          break;
-        }
-      }
-      if (!contiguous) {
-        to_collapse.push_back(-1);
-      }
-      to_collapse.push_back(i);
-    }
-    to_collapse.push_back(-1);
-  }
-
-  std::vector<int> out_shape;
-  std::vector<std::vector<stride_t>> out_strides(strides.size());
-  for (int i = 0; i < to_collapse.size(); i++) {
-    int current_shape = shape[to_collapse[i]];
-    while (to_collapse[++i] != -1) {
-      current_shape *= shape[to_collapse[i]];
-    }
-    out_shape.push_back(current_shape);
-    for (int j = 0; j < strides.size(); j++) {
-      const std::vector<stride_t>& st = strides[j];
-      out_strides[j].push_back(st[to_collapse[i - 1]]);
-    }
-  }
-
-  return std::make_tuple(out_shape, out_strides);
-}
+std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
+collapse_contiguous_dims(
+    const std::vector<int>& shape,
+    const std::vector<std::vector<int64_t>>& strides,
+    int64_t size_cap = std::numeric_limits<int32_t>::max());
+std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
+collapse_contiguous_dims(
+    const std::vector<int>& shape,
+    const std::vector<std::vector<size_t>>& strides,
+    size_t size_cap = std::numeric_limits<int32_t>::max());
 
 inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
-collapse_contiguous_dims(const std::vector<array>& xs) {
+collapse_contiguous_dims(
+    const std::vector<array>& xs,
+    size_t size_cap = std::numeric_limits<size_t>::max()) {
   std::vector<std::vector<size_t>> strides;
   for (auto& x : xs) {
    strides.emplace_back(x.strides());
   }
-  return collapse_contiguous_dims(xs[0].shape(), strides);
+  return collapse_contiguous_dims(xs[0].shape(), strides, size_cap);
 }
 
 template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
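The new size_cap parameter bounds the element count of any merged axis. A worked example (values chosen for illustration, not from the commit): collapsing shape (65536, 65536) with contiguous strides would produce one axis of 2^32 elements, which does not fit in the `int` used for shapes, so the default cap of INT32_MAX keeps the two axes separate:

    #include <cstdint>
    #include <cstdio>
    #include <vector>
    #include "mlx/backend/common/utils.h"

    int main() {
      using namespace mlx::core;
      std::vector<int> shape{65536, 65536};
      std::vector<std::vector<int64_t>> strides{{65536, 1}};

      // The running product 65536 * 65536 = 2^32 exceeds the default cap
      // (INT32_MAX), so the axes are left uncollapsed.
      auto [s1, st1] = collapse_contiguous_dims(shape, strides);
      std::printf("axes after collapse: %zu\n", s1.size()); // prints 2

      // With an unbounded cap the merged extent 2^32 would overflow the
      // int-typed shape entry, which is exactly what the cap prevents.
      return 0;
    }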
@@ -19,7 +19,7 @@
 
 namespace mlx::core {
 
-constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
+constexpr int MAX_BINARY_SPECIALIZED_DIMS = 3;
 
 std::string get_kernel_name(
     BinaryOpType bopt,
@@ -69,46 +69,61 @@ void binary_op_gpu_inplace(
   }
 
   // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
-  auto& strides_a = strides[0];
-  auto& strides_b = strides[1];
-  auto& strides_out = strides[2];
+  auto maybe_collapse = [bopt, &a, &b, &out]() {
+    if (bopt == BinaryOpType::General) {
+      // The size cap here should ideally be `UINT32_MAX` but we are
+      // limitied by the shape being an int.
+      auto [shape, strides] = collapse_contiguous_dims(
+          {a, b, out},
+          /* size_cap = */ INT32_MAX);
+      return std::make_tuple(shape, strides[0], strides[1], strides[2]);
+    } else {
+      std::vector<size_t> e;
+      return std::make_tuple(std::vector<int>{}, e, e, e);
+    }
+  };
+  auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
 
   bool use_2d = out.data_size() > UINT32_MAX;
   std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
   auto& d = metal::device(s.device);
 
-  auto kernel =
-      get_binary_two_kernel(d, kernel_name, a.dtype(), outputs[0].dtype(), op);
+  auto kernel = outputs.size() == 2
+      ? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
+      : get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
 
   // - If a is donated it goes to the first output
   // - If b is donated it goes to the first output if a was not donated
-  //   otherwise it goes to the second output
+  //   otherwise it goes to the second output.
+  // - If there is only one output only one of a and b will be donated.
   bool donate_a = a.data_shared_ptr() == nullptr;
   bool donate_b = b.data_shared_ptr() == nullptr;
-  compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
-  compute_encoder.set_input_array(
-      donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
-  compute_encoder.set_output_array(outputs[0], 2);
-  compute_encoder.set_output_array(outputs[1], 3);
+  int arg_idx = 0;
+  compute_encoder.set_input_array(donate_a ? outputs[0] : a, arg_idx++);
+  compute_encoder.set_input_array(
+      donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, arg_idx++);
+  compute_encoder.set_output_array(outputs[0], arg_idx++);
+  if (outputs.size() == 2) {
+    compute_encoder.set_output_array(outputs[1], arg_idx++);
+  }
 
   if (bopt == BinaryOpType::General) {
     auto ndim = shape.size();
     if (ndim > 3) {
-      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
+      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
+      compute_encoder->setBytes(
+          strides_a.data(), ndim * sizeof(size_t), arg_idx++);
+      compute_encoder->setBytes(
+          strides_b.data(), ndim * sizeof(size_t), arg_idx++);
+      compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
     } else {
       // The shape is implicit in the grid for <= 3D
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
-    }
-
-    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
-      compute_encoder->setBytes(&ndim, sizeof(int), 7);
+      compute_encoder->setBytes(
+          strides_a.data(), ndim * sizeof(size_t), arg_idx++);
+      compute_encoder->setBytes(
+          strides_b.data(), ndim * sizeof(size_t), arg_idx++);
     }
 
     // Launch up to 3D grid of threads
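The use_2d switch trips when a flat 1D dispatch can no longer index the output. A worked example with an illustrative array size: a (70000, 70000) array holds 4.9e9 elements, which exceeds UINT32_MAX (about 4.29e9), so the launch falls back to a 2D grid:

    #include <cstdint>
    #include <cstdio>

    int main() {
      // data_size() analogue for a hypothetical (70000, 70000) array.
      size_t data_size = 70000ull * 70000ull; // 4,900,000,000 elements
      bool use_2d = data_size > UINT32_MAX;   // 4.9e9 > 4,294,967,295
      std::printf("use_2d = %s\n", use_2d ? "true" : "false"); // true
      return 0;
    }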
@@ -125,9 +140,8 @@ void binary_op_gpu_inplace(
   } else {
     // Launch a 1D or 2D grid of threads
     size_t nthreads = out.data_size();
-    MTL::Size grid_dims = use_2d
-        ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
-        : MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
+                                 : MTL::Size(nthreads, 1, 1);
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
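get_2d_grid_dims is defined elsewhere in the Metal backend and is not shown in this diff. A minimal sketch of the invariant such a helper must provide, under the assumption that it simply factors the flat element count into two dimensions that each fit in 32 bits (the real implementation takes a shape and strides and may choose the split differently):

    #include <cstdint>
    #include <utility>

    // Hypothetical stand-in for MTL::Size: a (rows, cols) grid.
    std::pair<uint32_t, uint32_t> grid_dims_2d(size_t n) {
      // Grow the column count until the row count fits in 32 bits.
      size_t cols = 1;
      while (n / cols > UINT32_MAX) {
        cols *= 2;
      }
      // Round rows up so rows * cols covers all n elements.
      size_t rows = (n + cols - 1) / cols;
      return {static_cast<uint32_t>(rows), static_cast<uint32_t>(cols)};
    }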
@@ -164,72 +178,8 @@ void binary_op_gpu_inplace(
     array& out,
     const std::string& op,
     const Stream& s) {
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-  auto bopt = get_binary_op_type(a, b);
-  if (out.size() == 0) {
-    return;
-  }
-
-  // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
-  auto& strides_a = strides[0];
-  auto& strides_b = strides[1];
-  auto& strides_out = strides[2];
-
-  bool use_2d = out.data_size() > UINT32_MAX;
-  std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
-  auto& d = metal::device(s.device);
-
-  auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
-  bool donate_a = a.data_shared_ptr() == nullptr;
-  bool donate_b = b.data_shared_ptr() == nullptr;
-  compute_encoder.set_input_array(donate_a ? out : a, 0);
-  compute_encoder.set_input_array(donate_b ? out : b, 1);
-  compute_encoder.set_output_array(out, 2);
-
-  if (bopt == BinaryOpType::General) {
-    auto ndim = shape.size();
-    if (ndim > 3) {
-      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
-    } else {
-      // The shape is implicit in the grid for <= 3D
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
-    }
-
-    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
-      compute_encoder->setBytes(&ndim, sizeof(int), 6);
-    }
-
-    // Launch up to 3D grid of threads
-    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
-    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
-    size_t rest = out.size() / (dim0 * dim1);
-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-    if (thread_group_size != 1024) {
-      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
-    }
-    auto group_dims = get_block_dims(dim0, dim1, rest);
-    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
-  } else {
-    // Launch a 1D or 2D grid of threads
-
-    size_t nthreads = out.data_size();
-    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
-                                 : MTL::Size(nthreads, 1, 1);
-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-    if (thread_group_size > nthreads) {
-      thread_group_size = nthreads;
-    }
-    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
-  }
+  std::vector<array> outputs = {out};
+  binary_op_gpu_inplace(inputs, outputs, op, s);
 }
 
 void binary_op_gpu(
@@ -22,7 +22,8 @@ inline void build_kernel(
     const std::unordered_set<uintptr_t>& constant_ids,
     bool contiguous,
     int ndim,
-    bool dynamic_dims) {
+    bool dynamic_dims,
+    bool use_big_index = false) {
   // All outputs should have the exact same shape and will be row contiguous
   auto output_shape = outputs[0].shape();
   auto output_strides = outputs[0].strides();
@@ -84,9 +85,15 @@ inline void build_kernel(
 
   // The thread index in the whole grid
   os << "  uint3 pos [[thread_position_in_grid]]," << std::endl
-     << "  uint3 grid [[threads_per_grid]]) {" << std::endl
-     << "  uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);"
-     << std::endl;
+     << "  uint3 grid [[threads_per_grid]]) {" << std::endl;
+  if (use_big_index) {
+    // This is only used for contiguous kernels which don't have
+    // a third grid dimension
+    os << "  size_t index = pos.x + grid.x * size_t(pos.y);";
+  } else {
+    os << "  uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);";
+  }
+  os << std::endl;
 
   // Extract the indices per axis to individual uints if we have arrays that
   // are broadcasted or transposed
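A tiny stand-in for this codegen branch, showing the two index lines build_kernel writes into the generated Metal source (a sketch mirroring the strings above; the surrounding generated kernel is produced elsewhere in build_kernel):

    #include <iostream>
    #include <sstream>
    #include <string>

    std::string emit_index_line(bool use_big_index) {
      std::ostringstream os;
      if (use_big_index) {
        // "Big" contiguous kernels launch on a 2D grid, so pos.z is unused
        // and the flat index is computed in 64 bits.
        os << "  size_t index = pos.x + grid.x * size_t(pos.y);";
      } else {
        os << "  uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);";
      }
      return os.str();
    }

    int main() {
      std::cout << emit_index_line(false) << "\n"
                << emit_index_line(true) << "\n";
      return 0;
    }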
@@ -212,6 +219,17 @@ void Compiled::eval_gpu(
       /* contiguous = */ true,
       /* ndim = */ 0,
       /* dynamic_dims = */ false);
+  build_kernel(
+      kernel,
+      kernel_lib_ + "_contiguous_big",
+      inputs_,
+      outputs_,
+      tape_,
+      constant_ids_,
+      /* contiguous = */ true,
+      /* ndim = */ 0,
+      /* dynamic_dims = */ false,
+      /* use_big_index = */ true);
   for (int i = 1; i < 8; i++) {
     build_kernel(
         kernel,
@@ -285,7 +303,16 @@ void Compiled::eval_gpu(
       initial_strides.push_back(std::move(xstrides));
     }
     std::tie(shape, strides) =
-        collapse_contiguous_dims(output_shape, initial_strides);
+        collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
+  }
+
+  bool use_2d = false;
+  if (contiguous) {
+    size_t max_size = 0;
+    for (auto& in : inputs) {
+      max_size = std::max(max_size, in.data_size());
+    }
+    use_2d = (max_size > UINT32_MAX);
   }
 
   // Get the kernel from the lib
@@ -298,6 +325,8 @@ void Compiled::eval_gpu(
     } else {
       kernel_name += std::to_string(shape.size());
     }
+  } else if (use_2d) {
+    kernel_name += "_big";
   }
   auto kernel = d.get_kernel(kernel_name, lib);
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -348,8 +377,10 @@ void Compiled::eval_gpu(
 
   // Launch the kernel
   if (contiguous) {
-    size_t nthreads = outputs[0].size();
-    MTL::Size grid_dims(nthreads, 1, 1);
+    size_t nthreads = outputs[0].data_size();
+    MTL::Size grid_dims = use_2d
+        ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
+        : MTL::Size(nthreads, 1, 1);
     MTL::Size group_dims(
         std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
@@ -10,7 +10,7 @@
 
 namespace mlx::core {
 
-constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
+constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
 
 void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
   if (ctype == CopyType::Vector) {
@@ -59,10 +59,20 @@ void copy_gpu_inplace(
   }
 
   // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(
-      data_shape, std::vector{strides_in_pre, strides_out_pre});
-  auto& strides_in_ = strides[0];
-  auto& strides_out_ = strides[1];
+  auto maybe_collapse =
+      [ctype, &data_shape, &strides_in_pre, &strides_out_pre]() {
+        if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
+          auto [shape, strides] = collapse_contiguous_dims(
+              data_shape,
+              std::vector{strides_in_pre, strides_out_pre},
+              /* size_cap = */ INT32_MAX);
+          return std::make_tuple(shape, strides[0], strides[1]);
+        } else {
+          std::vector<stride_t> e;
+          return std::make_tuple(std::vector<int>{}, e, e);
+        }
+      };
+  auto [shape, strides_in_, strides_out_] = maybe_collapse();
 
   bool use_2d = out.data_size() > UINT32_MAX;
   auto& d = metal::device(s.device);
@@ -1,100 +0,0 @@
-// Copyright © 2024 Apple Inc.
-
-constexpr std::string_view copy_kernels = R"(
-template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    uint index [[thread_position_in_grid]]);
-template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    uint index [[thread_position_in_grid]]);
-
-template [[host_name("g4_{0}")]] [[kernel]] void
-copy_g_nd<{1}, {2}, 4>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]);
-template [[host_name("gg4_{0}")]] [[kernel]] void
-copy_gg_nd<{1}, {2}, 4>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint3 index [[thread_position_in_grid]]);
-template [[host_name("g5_{0}")]] [[kernel]] void
-copy_g_nd<{1}, {2}, 5>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]);
-template [[host_name("gg5_{0}")]] [[kernel]] void
-copy_gg_nd<{1}, {2}, 5>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint3 index [[thread_position_in_grid]]);
-template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t& src_stride [[buffer(3)]],
-    uint index [[thread_position_in_grid]]);
-template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint2 index [[thread_position_in_grid]],
-    uint2 grid_dim [[threads_per_grid]]);
-template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]);
-template [[host_name("gg1_{0}")]] [[kernel]] void
-copy_gg_nd1<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t& src_stride [[buffer(3)]],
-    constant const int64_t& dst_stride [[buffer(4)]],
-    uint index [[thread_position_in_grid]]);
-template [[host_name("gg2_{0}")]] [[kernel]] void
-copy_gg_nd2<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint2 index [[thread_position_in_grid]]);
-template [[host_name("gg3_{0}")]] [[kernel]] void
-copy_gg_nd3<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint3 index [[thread_position_in_grid]]);
-
-template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int& ndim [[buffer(5)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]);
-template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    constant const int& ndim [[buffer(5)]],
-    uint3 index [[thread_position_in_grid]]);
-)";
@@ -1,9 +1,7 @@
 // Copyright © 2024 Apple Inc.
-#include <map>
 
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/jit/arange.h"
-#include "mlx/backend/metal/jit/copy.h"
 #include "mlx/backend/metal/jit/gemv_masked.h"
 #include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/jit/scan.h"
@@ -67,7 +65,7 @@ void add_binary_kernels(
     Dtype out_type,
     const std::string op,
     std::ostringstream& kernel_source) {
-  const std::map<std::string, std::string> kernel_types = {
+  const std::array<std::pair<std::string, std::string>, 11> kernel_types = {{
       {"ss", "binary_ss"},
       {"vs", "binary_vs"},
       {"sv", "binary_sv"},
@@ -78,29 +76,16 @@ void add_binary_kernels(
       {"g1", "binary_g_nd1"},
       {"g2", "binary_g_nd2"},
       {"g3", "binary_g_nd3"},
-      {"g4", "binary_g_nd"},
-      {"g5", "binary_g_nd"},
       {"gn", "binary_g"},
-  };
-  for (auto [name, func] : kernel_types) {
+  }};
+  for (auto& [name, func] : kernel_types) {
     std::string template_def;
-    if (name == "g4" || name == "g5") {
-      int dim = std::stoi(name.substr(1));
-      template_def = get_template_definition(
-          name + lib_name,
-          func,
-          get_type_string(in_type),
-          get_type_string(out_type),
-          op,
-          dim);
-    } else {
-      template_def = get_template_definition(
-          name + lib_name,
-          func,
-          get_type_string(in_type),
-          get_type_string(out_type),
-          op);
-    }
+    template_def = get_template_definition(
+        name + lib_name,
+        func,
+        get_type_string(in_type),
+        get_type_string(out_type),
+        op);
     kernel_source << template_def;
   }
 }
@@ -149,27 +134,19 @@ MTL::ComputePipelineState* get_ternary_kernel(
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
-    const std::map<std::string, std::string> kernel_types = {
+    const std::array<std::pair<std::string, std::string>, 6> kernel_types = {{
         {"v", "ternary_v"},
         {"v2", "ternary_v2"},
         {"g", "ternary_g"},
         {"g1", "ternary_g_nd1"},
         {"g2", "ternary_g_nd2"},
         {"g3", "ternary_g_nd3"},
-        {"g4", "ternary_g_nd"},
-        {"g5", "ternary_g_nd"},
-    };
+    }};
     kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
-    for (auto [name, func] : kernel_types) {
+    for (auto& [name, func] : kernel_types) {
       std::string template_def;
-      if (name == "g4" || name == "g5") {
-        int dim = std::stoi(name.substr(1));
-        template_def = get_template_definition(
-            name + "_" + lib_name, func, get_type_string(type), op, dim);
-      } else {
-        template_def = get_template_definition(
-            name + "_" + lib_name, func, get_type_string(type), op);
-      }
+      template_def = get_template_definition(
+          name + "_" + lib_name, func, get_type_string(type), op);
      kernel_source << template_def;
     }
     lib = d.get_library(lib_name, kernel_source.str());
@@ -186,12 +163,27 @@ MTL::ComputePipelineState* get_copy_kernel(
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::copy()
-                  << fmt::format(
-                         copy_kernels,
-                         lib_name,
-                         get_type_string(in.dtype()),
-                         get_type_string(out.dtype()));
+    auto in_type = get_type_string(in.dtype());
+    auto out_type = get_type_string(out.dtype());
+    kernel_source
+        << metal::utils() << metal::copy()
+        << get_template_definition("s_" + lib_name, "copy_s", in_type, out_type)
+        << get_template_definition("v_" + lib_name, "copy_v", in_type, out_type)
+        << get_template_definition(
+               "g1_" + lib_name, "copy_g_nd1", in_type, out_type)
+        << get_template_definition(
+               "g2_" + lib_name, "copy_g_nd2", in_type, out_type)
+        << get_template_definition(
+               "g3_" + lib_name, "copy_g_nd3", in_type, out_type)
+        << get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
+        << get_template_definition(
+               "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
+        << get_template_definition(
+               "gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
+        << get_template_definition(
+               "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
+        << get_template_definition(
+               "gg_" + lib_name, "copy_gg", in_type, out_type);
     lib = d.get_library(lib_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib);
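get_template_definition itself is defined in the common compiled headers and is not part of this diff. A hypothetical sketch of what such a helper must emit, modeled on the host-named instantiations in the deleted copy_kernels preamble above (the repository's actual helper may differ in details):

    #include <iostream>
    #include <sstream>
    #include <string>

    // Illustrative helper: renders a Metal host-named template instantiation
    // as a string, e.g. for "s_<lib_name>" instantiating copy_s.
    template <typename... Args>
    std::string template_definition(
        const std::string& name,
        const std::string& func,
        const Args&... template_args) {
      std::ostringstream s;
      s << "template [[host_name(\"" << name << "\")]] [[kernel]] decltype("
        << func << "<";
      bool first = true;
      ((s << (first ? "" : ", ") << template_args, first = false), ...);
      s << ">) " << func << "<";
      first = true;
      ((s << (first ? "" : ", ") << template_args, first = false), ...);
      s << ">;\n";
      return s.str();
    }

    int main() {
      std::cout << template_definition(
          "s_copyfloat32", "copy_s", "float", "float");
      return 0;
    }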
@@ -296,11 +288,11 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
   if (lib == nullptr) {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::sort();
-    std::vector<std::pair<std::string, std::string>> kernel_types = {
-        {"sort_", "mb_block_sort"},
+    std::array<std::pair<std::string, std::string>, 3> kernel_types = {
+        {{"sort_", "mb_block_sort"},
          {"partition_", "mb_block_partition"},
-        {"merge_", "mb_block_merge"}};
-    for (auto [name, func] : kernel_types) {
+         {"merge_", "mb_block_merge"}}};
+    for (auto& [name, func] : kernel_types) {
       kernel_source << get_template_definition(
           name + lib_name,
           func,
@@ -70,16 +70,16 @@ IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
       simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)};
 }
 
-template <typename T, typename Op, int N_READS>
+template <typename T, typename Op, int N_READS = 4>
 [[kernel]] void arg_reduce_general(
     const device T* in [[buffer(0)]],
     device uint32_t* out [[buffer(1)]],
-    const device int* shape [[buffer(2)]],
-    const device size_t* in_strides [[buffer(3)]],
-    const device size_t* out_strides [[buffer(4)]],
-    const device size_t& ndim [[buffer(5)]],
-    const device size_t& axis_stride [[buffer(6)]],
-    const device size_t& axis_size [[buffer(7)]],
+    const constant int* shape [[buffer(2)]],
+    const constant size_t* in_strides [[buffer(3)]],
+    const constant size_t* out_strides [[buffer(4)]],
+    const constant size_t& ndim [[buffer(5)]],
+    const constant size_t& axis_stride [[buffer(6)]],
+    const constant size_t& axis_size [[buffer(7)]],
     uint gid [[thread_position_in_grid]],
     uint lid [[thread_position_in_threadgroup]],
     uint lsize [[threads_per_threadgroup]],
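The recurring device-to-constant change in these signatures moves shapes, strides, and sizes into Metal's constant address space, which is read-only and optimized for small values read uniformly by every thread, while device remains general read-write memory. A minimal illustrative kernel (not from the repository) showing the two address spaces side by side:

    // Metal address spaces in a kernel signature (illustrative):
    kernel void strided_copy(
        device const float* in       [[buffer(0)]], // per-element data
        device float* out            [[buffer(1)]], // written by threads
        constant const size_t* step  [[buffer(2)]], // small, uniform, read-only
        uint gid [[thread_position_in_grid]]) {
      out[gid] = in[gid * step[0]];
    }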
@@ -159,28 +159,12 @@ template <typename T, typename Op, int N_READS>
   }
 }
 
-#define instantiate_arg_reduce_helper(name, itype, op)         \
-  template [[host_name(name)]] [[kernel]] void                 \
-  arg_reduce_general<itype, op<itype>, 4>(                     \
-      const device itype* in [[buffer(0)]],                    \
-      device uint32_t* out [[buffer(1)]],                      \
-      const device int* shape [[buffer(2)]],                   \
-      const device size_t* in_strides [[buffer(3)]],           \
-      const device size_t* out_strides [[buffer(4)]],          \
-      const device size_t& ndim [[buffer(5)]],                 \
-      const device size_t& axis_stride [[buffer(6)]],          \
-      const device size_t& axis_size [[buffer(7)]],            \
-      uint gid [[thread_position_in_grid]],                    \
-      uint lid [[thread_position_in_threadgroup]],             \
-      uint lsize [[threads_per_threadgroup]],                  \
-      uint simd_size [[threads_per_simdgroup]],                \
-      uint simd_lane_id [[thread_index_in_simdgroup]],         \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
-
 // clang-format off
-#define instantiate_arg_reduce(name, itype)                        \
-  instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin)   \
-  instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
+#define instantiate_arg_reduce(name, itype)                          \
+  instantiate_kernel(                                                \
+      "argmin_" #name, arg_reduce_general, itype, ArgMin<itype>)     \
+  instantiate_kernel(                                                \
+      "argmax_" #name, arg_reduce_general, itype, ArgMax<itype>)
 
 instantiate_arg_reduce(bool_, bool)
 instantiate_arg_reduce(uint8, uint8_t)
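instantiate_kernel is defined elsewhere in the kernel headers and its definition is not part of this diff. A sketch of the expansion such a macro needs to produce, modeled on the explicit instantiations it replaces here (the exact definition in the repository may differ):

    // Hypothetical macro: expands to a host-named template instantiation,
    // replacing the per-kernel helper macros deleted above.
    #define instantiate_kernel(name, func, ...)                           \
      template [[host_name(name)]] [[kernel]] decltype(func<__VA_ARGS__>) \
          func<__VA_ARGS__>;

    // instantiate_kernel("argmin_float32", arg_reduce_general,
    //                    float, ArgMin<float>)
    // then expands (roughly) to:
    // template [[host_name("argmin_float32")]] [[kernel]]
    //     decltype(arg_reduce_general<float, ArgMin<float>>)
    //     arg_reduce_general<float, ArgMin<float>>;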
@@ -93,7 +93,7 @@ template <typename T, typename U, typename Op>
     uint2 grid_dim [[threads_per_grid]]) {
   auto a_idx = elem_to_loc_2(index, a_strides);
   auto b_idx = elem_to_loc_2(index, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
   c[out_idx] = Op()(a[a_idx], b[b_idx]);
 }
 
@@ -109,26 +109,10 @@ template <typename T, typename U, typename Op>
   auto a_idx = elem_to_loc_3(index, a_strides);
   auto b_idx = elem_to_loc_3(index, b_strides);
   size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   c[out_idx] = Op()(a[a_idx], b[b_idx]);
 }
 
-template <typename T, typename U, typename Op, int DIM>
-[[kernel]] void binary_g_nd(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    constant const int shape[DIM],
-    constant const size_t a_strides[DIM],
-    constant const size_t b_strides[DIM],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
-  size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
-  c[out_idx] = Op()(a[idx.x], b[idx.y]);
-}
-
 template <typename T, typename U, typename Op>
 [[kernel]] void binary_g(
     device const T* a,
@@ -141,6 +125,7 @@ template <typename T, typename U, typename Op>
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
   auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
-  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   c[out_idx] = Op()(a[idx.x], b[idx.y]);
 }
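The size_t(grid_dim.y) promotion in these kernels matters because C arithmetic happens in the widest operand type: once one factor is 64-bit, the whole chain is. A worked example of the difference (grid and index values are illustrative):

    #include <cstdint>
    #include <cstdio>

    int main() {
      uint32_t grid_x = 1u << 16, grid_y = 1u << 16; // 65536 x 65536 grid
      uint32_t ix = 0, iy = 40000, iz = 2;

      // All 32-bit: grid_y * iz + iy = 171072 fits, but the outer product
      // 65536 * 171072 = 11,211,374,592 wraps modulo 2^32.
      uint32_t wrapped = ix + grid_x * (iy + grid_y * iz);

      // Promoting grid_y to size_t makes every later product 64-bit.
      size_t exact = ix + grid_x * (iy + size_t(grid_y) * iz);

      std::printf("wrapped = %u\nexact   = %zu\n", wrapped, exact);
      return 0;
    }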
@@ -21,8 +21,6 @@
   instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
   instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
   instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
-  instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
-  instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)
 
 #define instantiate_binary_integer(op) \
   instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
@@ -118,7 +118,7 @@ template <typename T, typename U, typename Op>
     uint2 grid_dim [[threads_per_grid]]) {
   auto a_idx = elem_to_loc_2(index, a_strides);
   auto b_idx = elem_to_loc_2(index, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
   auto out = Op()(a[a_idx], b[b_idx]);
   c[out_idx] = out[0];
   d[out_idx] = out[1];
@@ -137,31 +137,12 @@ template <typename T, typename U, typename Op>
   auto a_idx = elem_to_loc_3(index, a_strides);
   auto b_idx = elem_to_loc_3(index, b_strides);
   size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   auto out = Op()(a[a_idx], b[b_idx]);
   c[out_idx] = out[0];
   d[out_idx] = out[1];
 }
 
-template <typename T, typename U, typename Op, int DIM>
-[[kernel]] void binary_g_nd(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    device U* d,
-    constant const int shape[DIM],
-    constant const size_t a_strides[DIM],
-    constant const size_t b_strides[DIM],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
-  size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
-  auto out = Op()(a[idx.x], b[idx.y]);
-  c[out_idx] = out[0];
-  d[out_idx] = out[1];
-}
-
 template <typename T, typename U, typename Op>
 [[kernel]] void binary_g(
     device const T* a,
@@ -175,7 +156,8 @@ template <typename T, typename U, typename Op>
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
   auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
-  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   auto out = Op()(a[idx.x], b[idx.y]);
   c[out_idx] = out[0];
   d[out_idx] = out[1];
@@ -19,8 +19,6 @@
   instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
   instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
   instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
-  instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
-  instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)
 
 #define instantiate_binary_float(op) \
   instantiate_binary_all(op, float16, half, half) \
@@ -71,20 +71,6 @@ template <typename T, typename U>
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }
 
-template <typename T, typename U, int DIM>
-[[kernel]] void copy_g_nd(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
-  int64_t dst_idx =
-      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
-}
-
 template <typename T, typename U>
 [[kernel]] void copy_g(
     device const T* src [[buffer(0)]],
@@ -136,19 +122,6 @@ template <typename T, typename U>
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }
 
-template <typename T, typename U, int DIM>
-[[kernel]] void copy_gg_nd(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint3 index [[thread_position_in_grid]]) {
-  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
-  auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
-}
-
 template <typename T, typename U>
 [[kernel]] void copy_gg(
     device const T* src [[buffer(0)]],
@@ -16,10 +16,6 @@
   instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
   instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
   instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
-  instantiate_kernel("g4_copy" #tname, copy_g_nd, itype, otype, 4) \
-  instantiate_kernel("g5_copy" #tname, copy_g_nd, itype, otype, 5) \
-  instantiate_kernel("gg4_copy" #tname, copy_gg_nd, itype, otype, 4) \
-  instantiate_kernel("gg5_copy" #tname, copy_gg_nd, itype, otype, 5) \
   instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
   instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype)
 
@@ -69,9 +69,9 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
     device char* out,
     device const bool& odd,
     device const uint& bytes_per_key,
-    device const int& ndim,
-    device const int* key_shape,
-    device const size_t* key_strides,
+    constant const int& ndim,
+    constant const int* key_shape,
+    constant const size_t* key_strides,
     uint2 grid_dim [[threads_per_grid]],
     uint2 index [[thread_position_in_grid]]) {
   auto kidx = 2 * index.x;
@@ -342,9 +342,9 @@ template <
     const constant int& in_stride_sorted_axis [[buffer(3)]],
     const constant int& out_stride_sorted_axis [[buffer(4)]],
     const constant int& nc_dim [[buffer(5)]],
-    const device int* nc_shape [[buffer(6)]],
-    const device size_t* in_nc_strides [[buffer(7)]],
-    const device size_t* out_nc_strides [[buffer(8)]],
+    const constant int* nc_shape [[buffer(6)]],
+    const constant size_t* in_nc_strides [[buffer(7)]],
+    const constant size_t* out_nc_strides [[buffer(8)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 lid [[thread_position_in_threadgroup]]) {
   using sort_kernel =
@@ -485,8 +485,8 @@ template <
     const constant int& size_sorted_axis [[buffer(3)]],
     const constant int& stride_sorted_axis [[buffer(4)]],
     const constant int& nc_dim [[buffer(5)]],
-    const device int* nc_shape [[buffer(6)]],
-    const device size_t* nc_strides [[buffer(7)]],
+    const constant int* nc_shape [[buffer(6)]],
+    const constant size_t* nc_strides [[buffer(7)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 lid [[thread_position_in_threadgroup]]) {
   using sort_kernel = KernelMultiBlockMergeSort<
@@ -52,7 +52,7 @@ template <typename T, typename Op>
   auto a_idx = elem_to_loc_2(index, a_strides);
   auto b_idx = elem_to_loc_2(index, b_strides);
   auto c_idx = elem_to_loc_2(index, c_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
   d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
 }
 
@@ -71,29 +71,10 @@ template <typename T, typename Op>
   auto b_idx = elem_to_loc_3(index, b_strides);
   auto c_idx = elem_to_loc_3(index, c_strides);
   size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
 }
 
-template <typename T, typename Op, int DIM>
-[[kernel]] void ternary_g_nd(
-    device const bool* a,
-    device const T* b,
-    device const T* c,
-    device T* d,
-    constant const int shape[DIM],
-    constant const size_t a_strides[DIM],
-    constant const size_t b_strides[DIM],
-    constant const size_t c_strides[DIM],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto idx =
-      elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
-  size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
-  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
-}
-
 template <typename T, typename Op>
 [[kernel]] void ternary_g(
     device const bool* a,
@@ -109,6 +90,7 @@ template <typename T, typename Op>
     uint3 grid_dim [[threads_per_grid]]) {
   auto idx =
       elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
-  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
 }
@@ -16,8 +16,6 @@
   instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
   instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
   instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op) \
-  instantiate_kernel("g4_" #op #tname, ternary_g_nd, type, op, 4) \
-  instantiate_kernel("g5_" #op #tname, ternary_g_nd, type, op, 5)
 
 #define instantiate_ternary_types(op) \
   instantiate_ternary_all(op, bool_, bool) \
@@ -22,8 +22,8 @@ template <typename T, typename Op>
 [[kernel]] void unary_g(
     device const T* in,
     device T* out,
-    device const int* in_shape,
-    device const size_t* in_strides,
+    constant const int* in_shape,
+    constant const size_t* in_strides,
     device const int& ndim,
     uint index [[thread_position_in_grid]]) {
   auto idx = elem_to_loc(index, in_shape, in_strides, ndim);
@@ -83,20 +83,6 @@ struct Limits<complex64_t> {
 ///////////////////////////////////////////////////////////////////////////////
 // Single Array with generic dims

-template <typename stride_t>
-METAL_FUNC stride_t elem_to_loc(
-    uint elem,
-    device const int* shape,
-    device const stride_t* strides,
-    int ndim) {
-  stride_t loc = 0;
-  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
-    loc += (elem % shape[i]) * strides[i];
-    elem /= shape[i];
-  }
-  return loc;
-}
-
 template <typename stride_t>
 METAL_FUNC stride_t elem_to_loc(
     uint elem,
@@ -111,20 +97,6 @@ METAL_FUNC stride_t elem_to_loc(
   return loc;
 }

-template <typename stride_t>
-METAL_FUNC stride_t elem_to_loc(
-    stride_t elem,
-    device const int* shape,
-    device const stride_t* strides,
-    int ndim) {
-  stride_t loc = 0;
-  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
-    loc += (elem % shape[i]) * strides[i];
-    elem /= shape[i];
-  }
-  return loc;
-}
-
 template <typename stride_t>
 METAL_FUNC stride_t elem_to_loc(
     stride_t elem,
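The deleted overloads above were `device`-pointer duplicates of the `constant`-pointer versions that survive. All of them compute the same thing: map a flat element index back to a strided memory offset by peeling per-dimension coordinates off the index, innermost dimension first. A minimal Python mirror of that loop, for illustration only (not MLX source):

    def elem_to_loc(elem, shape, strides):
        # Treat `elem` as a mixed-radix number whose digits are the
        # per-dimension coordinates, innermost dimension first.
        loc = 0
        for dim, stride in zip(reversed(shape), reversed(strides)):
            loc += (elem % dim) * stride
            elem //= dim
        return loc

    # Element 5 of a transposed 3x4 view: shape (4, 3), strides (1, 4).
    print(elem_to_loc(5, [4, 3], [1, 4]))  # -> 9, i.e. view[1][2]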
@@ -174,78 +146,18 @@ elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
   return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
 }

-template <int NDIM>
-METAL_FUNC size_t elem_to_loc_nd(
-    uint elem,
-    device const int* shape,
-    device const size_t* strides) {
-  size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
-
-  MLX_MTL_PRAGMA_UNROLL
-  for (int d = NDIM - 2; d >= 0; --d) {
-    elem /= shape[d + 1];
-    loc += (elem % shape[d]) * strides[d];
-  }
-
-  return loc;
-}
-
-template <int NDIM>
-METAL_FUNC size_t elem_to_loc_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const size_t strides[NDIM]) {
-  size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
-  for (int d = NDIM - 3; d >= 0; --d) {
-    loc += (elem.z % shape[d]) * strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
-template <int NDIM>
-METAL_FUNC int64_t elem_to_loc_nd(
-    uint elem,
-    constant const int shape[NDIM],
-    constant const int64_t strides[NDIM]) {
-  int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
-
-  MLX_MTL_PRAGMA_UNROLL
-  for (int d = NDIM - 2; d >= 0; --d) {
-    elem /= shape[d + 1];
-    loc += (elem % shape[d]) * strides[d];
-  }
-
-  return loc;
-}
-
-template <int NDIM>
-METAL_FUNC int64_t elem_to_loc_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const int64_t strides[NDIM]) {
-  int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
-  for (int d = NDIM - 3; d >= 0; --d) {
-    loc += (elem.z % shape[d]) * strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
 ///////////////////////////////////////////////////////////////////////////////
 // Multiple Arrays with generic dims

-METAL_FUNC uint2 elem_to_loc_2_nd(
+METAL_FUNC ulong2 elem_to_loc_2_nd(
     uint3 elem,
     constant const int* shape,
     constant const size_t* a_strides,
     constant const size_t* b_strides,
     int ndim) {
-  uint2 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
+  ulong2 loc = {
+      elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
+      elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]};
   for (int d = ndim - 3; d >= 0; --d) {
     uint l = elem.z % shape[d];
     loc.x += l * a_strides[d];
@@ -255,20 +167,17 @@ METAL_FUNC uint2 elem_to_loc_2_nd(
   return loc;
 }

-METAL_FUNC uint3 elem_to_loc_3_nd(
+METAL_FUNC ulong3 elem_to_loc_3_nd(
     uint3 elem,
     constant const int* shape,
     constant const size_t* a_strides,
     constant const size_t* b_strides,
     constant const size_t* c_strides,
     int ndim) {
-  uint3 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
-      static_cast<uint>(
-          elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
+  ulong3 loc = {
+      elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
+      elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2],
+      elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]};
   for (int d = ndim - 3; d >= 0; --d) {
     uint l = elem.z % shape[d];
     loc.x += l * a_strides[d];
@@ -279,53 +188,6 @@ METAL_FUNC uint3 elem_to_loc_3_nd(
   return loc;
 }

-///////////////////////////////////////////////////////////////////////////////
-// Multiple Arrays with fixed N dims
-
-template <int NDIM>
-METAL_FUNC uint2 elem_to_loc_2_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const size_t a_strides[NDIM],
-    constant const size_t b_strides[NDIM]) {
-  uint2 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
-  for (int d = NDIM - 3; d >= 0; --d) {
-    uint l = elem.z % shape[d];
-    loc.x += l * a_strides[d];
-    loc.y += l * b_strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
-template <int NDIM>
-METAL_FUNC uint3 elem_to_loc_3_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const size_t a_strides[NDIM],
-    constant const size_t b_strides[NDIM],
-    constant const size_t c_strides[NDIM]) {
-  uint3 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
-      static_cast<uint>(
-          elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
-  for (int d = NDIM - 3; d >= 0; --d) {
-    uint l = elem.z % shape[d];
-    loc.x += l * a_strides[d];
-    loc.y += l * b_strides[d];
-    loc.z += l * c_strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
 ///////////////////////////////////////////////////////////////////////////////
 // Elem to loc in a loop utils
 ///////////////////////////////////////////////////////////////////////////////
@@ -8,7 +8,7 @@

 namespace mlx::core {

-constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 5;
+constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 3;

 void ternary_op_gpu_inplace(
     const std::vector<array>& inputs,
@@ -26,11 +26,21 @@ void ternary_op_gpu_inplace(
   }

   // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
-  auto& strides_a = strides[0];
-  auto& strides_b = strides[1];
-  auto& strides_c = strides[2];
-  auto& strides_out = strides[3];
+  auto maybe_collapse = [topt, &a, &b, &c, &out]() {
+    if (topt == TernaryOpType::General) {
+      // The size cap here should ideally be `UINT32_MAX` but we are
+      // limited by the shape being an int.
+      auto [shape, strides] = collapse_contiguous_dims(
+          {a, b, c, out},
+          /* size_cap = */ INT32_MAX);
+      return std::make_tuple(
+          shape, strides[0], strides[1], strides[2], strides[3]);
+    } else {
+      std::vector<size_t> e;
+      return std::make_tuple(std::vector<int>{}, e, e, e, e);
+    }
+  };
+  auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();

   bool use_2d = out.data_size() > UINT_MAX;
   std::string kernel_name;
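The new lambda only collapses dimensions in the fully general case, and it caps the collapsed extent at `INT32_MAX` because a fused dimension must still fit in the `int` shape entries the kernels receive. The fusion test itself is the usual stride check; here is a sketch in Python for a single pair of adjacent dimensions, illustrative only and not the MLX implementation:

    INT32_MAX = 2**31 - 1

    def can_fuse(shape, strides, i, size_cap=INT32_MAX):
        # Dims i-1 and i merge when they are laid out back to back and the
        # fused extent stays addressable by a 32-bit signed shape entry.
        contiguous = strides[i] * shape[i] == strides[i - 1]
        return contiguous and shape[i - 1] * shape[i] <= size_cap

    # A row-major (65536, 65536) array: strides (65536, 1).
    print(can_fuse([65536, 65536], [65536, 1], 1))  # False: 2^32 > INT32_MAX
    print(can_fuse([1024, 65536], [65536, 1], 1))   # True: fused extent 2^26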
@@ -88,7 +98,7 @@ void ternary_op_gpu_inplace(
   size_t rest = out.size() / (dim0 * dim1);
   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
   if (thread_group_size != 1024) {
-    throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
+    throw std::runtime_error("[Metal::ternary] Must use 1024 sized block");
   }
   MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
   MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
@@ -6,4 +6,5 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../common/reduce_utils.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp)
+  ${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/../common/utils.cpp)
@@ -243,7 +243,7 @@ void init_fast(nb::module_& parent_module) {
           template_args.emplace_back(name, dtype);
         } else {
           throw std::invalid_argument(
-              "[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
+              "[metal_kernel] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
         }
       }
     }
@@ -271,25 +271,24 @@ void init_fast(nb::module_& parent_module) {
       nb::sig(
           "def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
       R"pbdoc(
         Run the kernel.

        Args:
          inputs (List[array]): The inputs passed to the Metal kernel.
          output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
          output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
          grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
          threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
          template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
            These will be added as template arguments to the kernel definition. Default: ``None``.
          init_value (float, optional): Optional value to use to initialize all of the output arrays.
            By default, output arrays are uninitialized. Default: ``None``.
          verbose (bool, optional): Whether to print the full generated source code of the kernel
            when it is run. Default: ``False``.
          stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.

        Returns:
-          List[array]: The list of output arrays.
-      )pbdoc");
+          List[array]: The list of output arrays.)pbdoc");
     },
     "name"_a,
     "input_names"_a,
@@ -306,16 +305,16 @@ void init_fast(nb::module_& parent_module) {
       input_names (List[str]): The parameter names of the inputs in the
        function signature.
      output_names (List[str]): The parameter names of the outputs in the
        function signature.
      source (str): Source code. This is the body of a function in Metal,
        the function signature will be automatically generated.
      header (str): Header source code to include before the main function.
        Useful for helper functions or includes that should live outside of
        the main function body.
      ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
        before the kernel runs. Default: ``True``.
      atomic_outputs (bool): Whether to use atomic outputs in the function signature
        e.g. ``device atomic<float>``. Default: ``False``.

    Returns:
      Callable ``metal_kernel``.
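Taken together, the two docstrings describe a two-step API: construct a kernel object from Metal source, then call it with inputs, output specs, and launch dimensions. A minimal usage sketch in Python, assuming the `mx.fast.metal_kernel` API as documented above (the kernel name, body, and sizes are illustrative, not from the commit):

    import mlx.core as mx

    # Body only; the signature is generated from input_names/output_names.
    source = """
        uint elem = thread_position_in_grid.x;
        out[elem] = T(2) * inp[elem];
    """

    kernel = mx.fast.metal_kernel(
        name="twice",
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )

    a = mx.arange(16, dtype=mx.float32)
    outputs = kernel(
        inputs=[a],
        template=[("T", mx.float32)],
        grid=(a.size, 1, 1),
        threadgroup=(min(a.size, 256), 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )
    print(outputs[0])  # 0, 2, 4, ...

Passing ``template=[("T", mx.float32)]`` makes ``T`` available as a template parameter inside the generated kernel, matching the ``template`` argument described in the ``__call__`` docstring.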