More fixes for arrays with large sizes (#1405)

* compile works for big arrays when contiguous

* style

* nits in docs

* a bunch more stuff

* update jit

* update jit

* use constant for shapes and strides and remove elem_to_loc overload

* use kernel instantiation

* docs nits

* update binary and ternary

* comments
Awni Hannun 2024-09-17 12:46:31 -07:00 committed by GitHub
parent c6739ba7f3
commit 4f46e9c997
26 changed files with 325 additions and 611 deletions

View File

@ -13,6 +13,7 @@ simple functions.
:template: nn-module-template.rst
elu
celu
gelu
gelu_approx
gelu_fast_approx

View File

@ -13,6 +13,7 @@ Layers
AvgPool1d
AvgPool2d
BatchNorm
CELU
Conv1d
Conv2d
Conv3d
@ -23,6 +24,7 @@ Layers
Dropout2d
Dropout3d
Embedding
ELU
GELU
GLU
GroupNorm
@ -34,6 +36,8 @@ Layers
LayerNorm
LeakyReLU
Linear
LogSigmoid
LogSoftmax
LSTM
MaxPool1d
MaxPool2d
@ -49,6 +53,7 @@ Layers
RoPE
SELU
Sequential
Sigmoid
SiLU
SinusoidalPositionalEncoding
Softmin

View File

@ -51,6 +51,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
if(IOS)

View File

@ -0,0 +1,88 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
template <typename stride_t>
std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
collapse_contiguous_dims_impl(
const std::vector<int>& shape,
const std::vector<std::vector<stride_t>>& strides,
stride_t size_cap) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
if (shape.size() > 0) {
if (shape[0] != 1) {
to_collapse.push_back(0);
}
size_t size = shape[0];
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
size *= shape[i];
for (const std::vector<stride_t>& st : strides) {
if (st[i] * shape[i] != st[i - 1] || size > size_cap) {
contiguous = false;
size = shape[i];
break;
}
}
if (!contiguous) {
to_collapse.push_back(-1);
}
if (shape[i] != 1) {
to_collapse.push_back(i);
}
}
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<stride_t>> out_strides(strides.size());
for (int i = 0;;) {
while (i < to_collapse.size() && to_collapse[i] == -1) {
++i;
};
if (i == to_collapse.size()) {
break;
}
int current_shape = shape[to_collapse[i]];
int k = i;
while (to_collapse[++k] != -1) {
current_shape *= shape[to_collapse[k]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<stride_t>& st = strides[j];
out_strides[j].push_back(st[to_collapse[k - 1]]);
}
i = k + 1;
}
if (!shape.empty() && out_shape.empty()) {
out_shape.push_back(1);
for (auto& out_stride : out_strides) {
out_stride.push_back(0);
}
}
return std::make_tuple(out_shape, out_strides);
}
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<int64_t>>& strides,
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl(shape, strides, size_cap);
}
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>>& strides,
size_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl(shape, strides, size_cap);
}
} // namespace mlx::core
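
For illustration only (not part of the change): a minimal host-side sketch of how the size-capped collapse above behaves, calling the int64_t overload declared in mlx/backend/common/utils.h; the shape and cap value are made up for the example.

// Sketch: collapse a row-contiguous {2, 3, 4} array with and without a size cap.
#include <cstdint>
#include <iostream>
#include <vector>
#include "mlx/backend/common/utils.h"

int main() {
  std::vector<int> shape = {2, 3, 4};
  std::vector<std::vector<int64_t>> strides = {{12, 4, 1}}; // row-contiguous
  // Default cap: all three axes fuse into one -> shape {24}, strides {1}.
  auto [s1, st1] = mlx::core::collapse_contiguous_dims(shape, strides);
  // Cap of 6: the running size 2 * 3 * 4 = 24 exceeds the cap at the last
  // axis, so the collapse restarts there -> shape {6, 4}, strides {4, 1}.
  auto [s2, st2] = mlx::core::collapse_contiguous_dims(shape, strides, int64_t(6));
  std::cout << s1.size() << " " << s2.size() << std::endl; // prints "1 2"
  return 0;
}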

View File

@ -44,58 +44,26 @@ std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
template <typename stride_t>
inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<stride_t>> strides) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
if (shape.size() > 0) {
to_collapse.push_back(0);
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
for (const std::vector<stride_t>& st : strides) {
if (st[i] * shape[i] != st[i - 1]) {
contiguous = false;
}
if (!contiguous) {
break;
}
}
if (!contiguous) {
to_collapse.push_back(-1);
}
to_collapse.push_back(i);
}
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<stride_t>> out_strides(strides.size());
for (int i = 0; i < to_collapse.size(); i++) {
int current_shape = shape[to_collapse[i]];
while (to_collapse[++i] != -1) {
current_shape *= shape[to_collapse[i]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<stride_t>& st = strides[j];
out_strides[j].push_back(st[to_collapse[i - 1]]);
}
}
return std::make_tuple(out_shape, out_strides);
}
const std::vector<std::vector<int64_t>>& strides,
int64_t size_cap = std::numeric_limits<int32_t>::max());
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>>& strides,
size_t size_cap = std::numeric_limits<int32_t>::max());
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
collapse_contiguous_dims(
const std::vector<array>& xs,
size_t size_cap = std::numeric_limits<size_t>::max()) {
std::vector<std::vector<size_t>> strides;
for (auto& x : xs) {
strides.emplace_back(x.strides());
}
return collapse_contiguous_dims(xs[0].shape(), strides);
return collapse_contiguous_dims(xs[0].shape(), strides, size_cap);
}
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
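
A further illustrative note on the multi-array overloads above (not part of the diff): all inputs share one shape, and an axis pair only collapses when it is contiguous in every input, so a broadcast stride of 0 keeps the shape uncollapsed.

// Sketch: a row-contiguous {3, 4} input plus one broadcast along the first axis.
#include <cstddef>
#include <vector>
#include "mlx/backend/common/utils.h"

int main() {
  std::vector<int> shape = {3, 4};
  std::vector<std::vector<size_t>> strides = {
      {4, 1}, // row-contiguous input
      {0, 1}, // input broadcast along axis 0
  };
  // The 0 stride breaks contiguity between the two axes for the second input,
  // so the result keeps shape {3, 4} and both stride vectors unchanged.
  auto [out_shape, out_strides] = mlx::core::collapse_contiguous_dims(shape, strides);
  return out_shape.size() == 2 ? 0 : 1; // stays 2-D
}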

View File

@ -19,7 +19,7 @@
namespace mlx::core {
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 3;
std::string get_kernel_name(
BinaryOpType bopt,
@ -69,46 +69,61 @@ void binary_op_gpu_inplace(
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
auto maybe_collapse = [bopt, &a, &b, &out]() {
if (bopt == BinaryOpType::General) {
// The size cap here should ideally be `UINT32_MAX` but we are
// limited by the shape being an int.
auto [shape, strides] = collapse_contiguous_dims(
{a, b, out},
/* size_cap = */ INT32_MAX);
return std::make_tuple(shape, strides[0], strides[1], strides[2]);
} else {
std::vector<size_t> e;
return std::make_tuple(std::vector<int>{}, e, e, e);
}
};
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
bool use_2d = out.data_size() > UINT32_MAX;
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
auto& d = metal::device(s.device);
auto kernel =
get_binary_two_kernel(d, kernel_name, a.dtype(), outputs[0].dtype(), op);
auto kernel = outputs.size() == 2
? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
: get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated
// otherwise it goes to the second output
// otherwise it goes to the second output.
// - If there is only one output only one of a and b will be donated.
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
int arg_idx = 0;
compute_encoder.set_input_array(donate_a ? outputs[0] : a, arg_idx++);
compute_encoder.set_input_array(
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
compute_encoder.set_output_array(outputs[0], 2);
compute_encoder.set_output_array(outputs[1], 3);
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, arg_idx++);
compute_encoder.set_output_array(outputs[0], arg_idx++);
if (outputs.size() == 2) {
compute_encoder.set_output_array(outputs[1], arg_idx++);
}
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
compute_encoder->setBytes(
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 7);
compute_encoder->setBytes(
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
}
// Launch up to 3D grid of threads
@ -125,9 +140,8 @@ void binary_op_gpu_inplace(
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
@ -164,72 +178,8 @@ void binary_op_gpu_inplace(
array& out,
const std::string& op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
bool use_2d = out.data_size() > UINT32_MAX;
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
auto& d = metal::device(s.device);
auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? out : a, 0);
compute_encoder.set_input_array(donate_b ? out : b, 1);
compute_encoder.set_output_array(out, 2);
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 6);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
std::vector<array> outputs = {out};
binary_op_gpu_inplace(inputs, outputs, op, s);
}
void binary_op_gpu(

View File

@ -22,7 +22,8 @@ inline void build_kernel(
const std::unordered_set<uintptr_t>& constant_ids,
bool contiguous,
int ndim,
bool dynamic_dims) {
bool dynamic_dims,
bool use_big_index = false) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
@ -84,9 +85,15 @@ inline void build_kernel(
// The thread index in the whole grid
os << " uint3 pos [[thread_position_in_grid]]," << std::endl
<< " uint3 grid [[threads_per_grid]]) {" << std::endl
<< " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);"
<< std::endl;
<< " uint3 grid [[threads_per_grid]]) {" << std::endl;
if (use_big_index) {
// This is only used for contiguous kernels which don't have
// a third grid dimension
os << " size_t index = pos.x + grid.x * size_t(pos.y);";
} else {
os << " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);";
}
os << std::endl;
// Extract the indices per axis to individual uints if we have arrays that
// are broadcasted or transposed
@ -212,6 +219,17 @@ void Compiled::eval_gpu(
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false);
build_kernel(
kernel,
kernel_lib_ + "_contiguous_big",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false,
/* use_big_index = */ true);
for (int i = 1; i < 8; i++) {
build_kernel(
kernel,
@ -285,7 +303,16 @@ void Compiled::eval_gpu(
initial_strides.push_back(std::move(xstrides));
}
std::tie(shape, strides) =
collapse_contiguous_dims(output_shape, initial_strides);
collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
}
bool use_2d = false;
if (contiguous) {
size_t max_size = 0;
for (auto& in : inputs) {
max_size = std::max(max_size, in.data_size());
}
use_2d = (max_size > UINT32_MAX);
}
// Get the kernel from the lib
@ -298,6 +325,8 @@ void Compiled::eval_gpu(
} else {
kernel_name += std::to_string(shape.size());
}
} else if (use_2d) {
kernel_name += "_big";
}
auto kernel = d.get_kernel(kernel_name, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
@ -348,8 +377,10 @@ void Compiled::eval_gpu(
// Launch the kernel
if (contiguous) {
size_t nthreads = outputs[0].size();
MTL::Size grid_dims(nthreads, 1, 1);
size_t nthreads = outputs[0].data_size();
MTL::Size grid_dims = use_2d
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
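
The contiguous "_big" variant above indexes with size_t index = pos.x + grid.x * size_t(pos.y), so outputs with more than UINT32_MAX elements are launched on a 2-D grid. A rough sketch of that idea follows; the helper name and factoring strategy are assumptions for illustration, not the library's get_2d_grid_dims.

// Split a shape's element count into an (x, y) grid where each dimension fits
// in 32 bits; any factorization with x * y equal to the total works, because
// the kernel rebuilds the flat index as pos.x + grid.x * size_t(pos.y).
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

std::pair<uint32_t, uint32_t> grid_dims_2d(const std::vector<int>& shape) {
  uint64_t x = 1, y = 1;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    if (x * static_cast<uint64_t>(shape[i]) <= UINT32_MAX) {
      x *= shape[i]; // keep filling x while it still fits in 32 bits
    } else {
      y *= shape[i]; // overflow axes go to the second grid dimension
    }
  }
  if (y > UINT32_MAX) {
    throw std::runtime_error("shape too large for a 2-D grid");
  }
  return {static_cast<uint32_t>(x), static_cast<uint32_t>(y)};
}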

View File

@ -10,7 +10,7 @@
namespace mlx::core {
constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
if (ctype == CopyType::Vector) {
@ -59,10 +59,20 @@ void copy_gpu_inplace(
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(
data_shape, std::vector{strides_in_pre, strides_out_pre});
auto& strides_in_ = strides[0];
auto& strides_out_ = strides[1];
auto maybe_collapse =
[ctype, &data_shape, &strides_in_pre, &strides_out_pre]() {
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
auto [shape, strides] = collapse_contiguous_dims(
data_shape,
std::vector{strides_in_pre, strides_out_pre},
/* size_cap = */ INT32_MAX);
return std::make_tuple(shape, strides[0], strides[1]);
} else {
std::vector<stride_t> e;
return std::make_tuple(std::vector<int>{}, e, e);
}
};
auto [shape, strides_in_, strides_out_] = maybe_collapse();
bool use_2d = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device);

View File

@ -1,100 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view copy_kernels = R"(
template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]);
template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 4>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg4_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 4>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 5>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg5_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 5>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg1_{0}")]] [[kernel]] void
copy_gg_nd1<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]);
template [[host_name("gg2_{0}")]] [[kernel]] void
copy_gg_nd2<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]);
template [[host_name("gg3_{0}")]] [[kernel]] void
copy_gg_nd3<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]);
)";

View File

@ -1,9 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <map>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/scan.h"
@ -67,7 +65,7 @@ void add_binary_kernels(
Dtype out_type,
const std::string op,
std::ostringstream& kernel_source) {
const std::map<std::string, std::string> kernel_types = {
const std::array<std::pair<std::string, std::string>, 11> kernel_types = {{
{"ss", "binary_ss"},
{"vs", "binary_vs"},
{"sv", "binary_sv"},
@ -78,29 +76,16 @@ void add_binary_kernels(
{"g1", "binary_g_nd1"},
{"g2", "binary_g_nd2"},
{"g3", "binary_g_nd3"},
{"g4", "binary_g_nd"},
{"g5", "binary_g_nd"},
{"gn", "binary_g"},
};
for (auto [name, func] : kernel_types) {
}};
for (auto& [name, func] : kernel_types) {
std::string template_def;
if (name == "g4" || name == "g5") {
int dim = std::stoi(name.substr(1));
template_def = get_template_definition(
name + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op,
dim);
} else {
template_def = get_template_definition(
name + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
}
template_def = get_template_definition(
name + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
kernel_source << template_def;
}
}
@ -149,27 +134,19 @@ MTL::ComputePipelineState* get_ternary_kernel(
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
const std::map<std::string, std::string> kernel_types = {
const std::array<std::pair<std::string, std::string>, 6> kernel_types = {{
{"v", "ternary_v"},
{"v2", "ternary_v2"},
{"g", "ternary_g"},
{"g1", "ternary_g_nd1"},
{"g2", "ternary_g_nd2"},
{"g3", "ternary_g_nd3"},
{"g4", "ternary_g_nd"},
{"g5", "ternary_g_nd"},
};
}};
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
for (auto [name, func] : kernel_types) {
for (auto& [name, func] : kernel_types) {
std::string template_def;
if (name == "g4" || name == "g5") {
int dim = std::stoi(name.substr(1));
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op, dim);
} else {
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op);
}
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op);
kernel_source << template_def;
}
lib = d.get_library(lib_name, kernel_source.str());
@ -186,12 +163,27 @@ MTL::ComputePipelineState* get_copy_kernel(
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::copy()
<< fmt::format(
copy_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()));
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source
<< metal::utils() << metal::copy()
<< get_template_definition("s_" + lib_name, "copy_s", in_type, out_type)
<< get_template_definition("v_" + lib_name, "copy_v", in_type, out_type)
<< get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
<< get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
<< get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
<< get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
<< get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
<< get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
<< get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
<< get_template_definition(
"gg_" + lib_name, "copy_gg", in_type, out_type);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@ -296,11 +288,11 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort();
std::vector<std::pair<std::string, std::string>> kernel_types = {
{"sort_", "mb_block_sort"},
{"partition_", "mb_block_partition"},
{"merge_", "mb_block_merge"}};
for (auto [name, func] : kernel_types) {
std::array<std::pair<std::string, std::string>, 3> kernel_types = {
{{"sort_", "mb_block_sort"},
{"partition_", "mb_block_partition"},
{"merge_", "mb_block_merge"}}};
for (auto& [name, func] : kernel_types) {
kernel_source << get_template_definition(
name + lib_name,
func,

View File

@ -70,16 +70,16 @@ IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)};
}
template <typename T, typename Op, int N_READS>
template <typename T, typename Op, int N_READS = 4>
[[kernel]] void arg_reduce_general(
const device T* in [[buffer(0)]],
device uint32_t* out [[buffer(1)]],
const device int* shape [[buffer(2)]],
const device size_t* in_strides [[buffer(3)]],
const device size_t* out_strides [[buffer(4)]],
const device size_t& ndim [[buffer(5)]],
const device size_t& axis_stride [[buffer(6)]],
const device size_t& axis_size [[buffer(7)]],
const constant int* shape [[buffer(2)]],
const constant size_t* in_strides [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& ndim [[buffer(5)]],
const constant size_t& axis_stride [[buffer(6)]],
const constant size_t& axis_size [[buffer(7)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
@ -159,28 +159,12 @@ template <typename T, typename Op, int N_READS>
}
}
#define instantiate_arg_reduce_helper(name, itype, op) \
template [[host_name(name)]] [[kernel]] void \
arg_reduce_general<itype, op<itype>, 4>( \
const device itype* in [[buffer(0)]], \
device uint32_t* out [[buffer(1)]], \
const device int* shape [[buffer(2)]], \
const device size_t* in_strides [[buffer(3)]], \
const device size_t* out_strides [[buffer(4)]], \
const device size_t& ndim [[buffer(5)]], \
const device size_t& axis_stride [[buffer(6)]], \
const device size_t& axis_size [[buffer(7)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
// clang-format off
#define instantiate_arg_reduce(name, itype) \
instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
instantiate_kernel( \
"argmin_" #name, arg_reduce_general, itype, ArgMin<itype>) \
instantiate_kernel( \
"argmax_" #name, arg_reduce_general, itype, ArgMax<itype>)
instantiate_arg_reduce(bool_, bool)
instantiate_arg_reduce(uint8, uint8_t)

View File

@ -93,7 +93,7 @@ template <typename T, typename U, typename Op>
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
@ -109,26 +109,10 @@ template <typename T, typename U, typename Op>
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
device const T* a,
device const T* b,
device U* c,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g(
device const T* a,
@ -141,6 +125,7 @@ template <typename T, typename U, typename Op>
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}

View File

@ -21,8 +21,6 @@
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)
#define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \

View File

@ -118,7 +118,7 @@ template <typename T, typename U, typename Op>
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
@ -137,31 +137,12 @@ template <typename T, typename U, typename Op>
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g(
device const T* a,
@ -175,7 +156,8 @@ template <typename T, typename U, typename Op>
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
d[out_idx] = out[1];

View File

@ -19,8 +19,6 @@
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)
#define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \

View File

@ -71,20 +71,6 @@ template <typename T, typename U>
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g(
device const T* src [[buffer(0)]],
@ -136,19 +122,6 @@ template <typename T, typename U>
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg(
device const T* src [[buffer(0)]],

View File

@ -16,10 +16,6 @@
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("g4_copy" #tname, copy_g_nd, itype, otype, 4) \
instantiate_kernel("g5_copy" #tname, copy_g_nd, itype, otype, 5) \
instantiate_kernel("gg4_copy" #tname, copy_gg_nd, itype, otype, 4) \
instantiate_kernel("gg5_copy" #tname, copy_gg_nd, itype, otype, 5) \
instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype)

View File

@ -69,9 +69,9 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
device char* out,
device const bool& odd,
device const uint& bytes_per_key,
device const int& ndim,
device const int* key_shape,
device const size_t* key_strides,
constant const int& ndim,
constant const int* key_shape,
constant const size_t* key_strides,
uint2 grid_dim [[threads_per_grid]],
uint2 index [[thread_position_in_grid]]) {
auto kidx = 2 * index.x;

View File

@ -342,9 +342,9 @@ template <
const constant int& in_stride_sorted_axis [[buffer(3)]],
const constant int& out_stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* in_nc_strides [[buffer(7)]],
const device size_t* out_nc_strides [[buffer(8)]],
const constant int* nc_shape [[buffer(6)]],
const constant size_t* in_nc_strides [[buffer(7)]],
const constant size_t* out_nc_strides [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel =
@ -485,8 +485,8 @@ template <
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides [[buffer(7)]],
const constant int* nc_shape [[buffer(6)]],
const constant size_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<

View File

@ -52,7 +52,7 @@ template <typename T, typename Op>
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
auto c_idx = elem_to_loc_2(index, c_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
@ -71,29 +71,10 @@ template <typename T, typename Op>
auto b_idx = elem_to_loc_3(index, b_strides);
auto c_idx = elem_to_loc_3(index, c_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op, int DIM>
[[kernel]] void ternary_g_nd(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
constant const size_t c_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx =
elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}
template <typename T, typename Op>
[[kernel]] void ternary_g(
device const bool* a,
@ -109,6 +90,7 @@ template <typename T, typename Op>
uint3 grid_dim [[threads_per_grid]]) {
auto idx =
elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}

View File

@ -16,8 +16,6 @@
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("g4_" #op #tname, ternary_g_nd, type, op, 4) \
instantiate_kernel("g5_" #op #tname, ternary_g_nd, type, op, 5)
#define instantiate_ternary_types(op) \
instantiate_ternary_all(op, bool_, bool) \

View File

@ -22,8 +22,8 @@ template <typename T, typename Op>
[[kernel]] void unary_g(
device const T* in,
device T* out,
device const int* in_shape,
device const size_t* in_strides,
constant const int* in_shape,
constant const size_t* in_strides,
device const int& ndim,
uint index [[thread_position_in_grid]]) {
auto idx = elem_to_loc(index, in_shape, in_strides, ndim);

View File

@ -83,20 +83,6 @@ struct Limits<complex64_t> {
///////////////////////////////////////////////////////////////////////////////
// Single Array with generic dims
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
uint elem,
device const int* shape,
device const stride_t* strides,
int ndim) {
stride_t loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
return loc;
}
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
uint elem,
@ -111,20 +97,6 @@ METAL_FUNC stride_t elem_to_loc(
return loc;
}
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
stride_t elem,
device const int* shape,
device const stride_t* strides,
int ndim) {
stride_t loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
return loc;
}
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
stride_t elem,
@ -174,78 +146,18 @@ elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
}
template <int NDIM>
METAL_FUNC size_t elem_to_loc_nd(
uint elem,
device const int* shape,
device const size_t* strides) {
size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
MLX_MTL_PRAGMA_UNROLL
for (int d = NDIM - 2; d >= 0; --d) {
elem /= shape[d + 1];
loc += (elem % shape[d]) * strides[d];
}
return loc;
}
template <int NDIM>
METAL_FUNC size_t elem_to_loc_nd(
uint3 elem,
constant const int shape[NDIM],
constant const size_t strides[NDIM]) {
size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
for (int d = NDIM - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * strides[d];
elem.z /= shape[d];
}
return loc;
}
template <int NDIM>
METAL_FUNC int64_t elem_to_loc_nd(
uint elem,
constant const int shape[NDIM],
constant const int64_t strides[NDIM]) {
int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
MLX_MTL_PRAGMA_UNROLL
for (int d = NDIM - 2; d >= 0; --d) {
elem /= shape[d + 1];
loc += (elem % shape[d]) * strides[d];
}
return loc;
}
template <int NDIM>
METAL_FUNC int64_t elem_to_loc_nd(
uint3 elem,
constant const int shape[NDIM],
constant const int64_t strides[NDIM]) {
int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
for (int d = NDIM - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * strides[d];
elem.z /= shape[d];
}
return loc;
}
///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with generic dims
METAL_FUNC uint2 elem_to_loc_2_nd(
METAL_FUNC ulong2 elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
int ndim) {
uint2 loc = {
static_cast<uint>(
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
static_cast<uint>(
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
ulong2 loc = {
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
@ -255,20 +167,17 @@ METAL_FUNC uint2 elem_to_loc_2_nd(
return loc;
}
METAL_FUNC uint3 elem_to_loc_3_nd(
METAL_FUNC ulong3 elem_to_loc_3_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
int ndim) {
uint3 loc = {
static_cast<uint>(
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
static_cast<uint>(
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
static_cast<uint>(
elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
ulong3 loc = {
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2],
elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
@ -279,53 +188,6 @@ METAL_FUNC uint3 elem_to_loc_3_nd(
return loc;
}
///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with fixed N dims
template <int NDIM>
METAL_FUNC uint2 elem_to_loc_2_nd(
uint3 elem,
constant const int shape[NDIM],
constant const size_t a_strides[NDIM],
constant const size_t b_strides[NDIM]) {
uint2 loc = {
static_cast<uint>(
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
static_cast<uint>(
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
for (int d = NDIM - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
elem.z /= shape[d];
}
return loc;
}
template <int NDIM>
METAL_FUNC uint3 elem_to_loc_3_nd(
uint3 elem,
constant const int shape[NDIM],
constant const size_t a_strides[NDIM],
constant const size_t b_strides[NDIM],
constant const size_t c_strides[NDIM]) {
uint3 loc = {
static_cast<uint>(
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
static_cast<uint>(
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
static_cast<uint>(
elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
for (int d = NDIM - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
loc.z += l * c_strides[d];
elem.z /= shape[d];
}
return loc;
}
///////////////////////////////////////////////////////////////////////////////
// Elem to loc in a loop utils
///////////////////////////////////////////////////////////////////////////////
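
For reference (not part of the change), the elem_to_loc loops kept above map a flat row-major element index to a strided memory offset; a host-side equivalent of that loop, with a worked example:

// Unravel a flat row-major element index into a memory offset for an array
// with arbitrary strides (host-side mirror of the Metal elem_to_loc loop).
#include <cstddef>
#include <vector>

size_t elem_to_loc(size_t elem,
                   const std::vector<int>& shape,
                   const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

// Example: shape {2, 3} with transposed strides {1, 2}; element 4
// (row 1, col 1) maps to offset 1 * 1 + 1 * 2 = 3.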

View File

@ -8,7 +8,7 @@
namespace mlx::core {
constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 5;
constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 3;
void ternary_op_gpu_inplace(
const std::vector<array>& inputs,
@ -26,11 +26,21 @@ void ternary_op_gpu_inplace(
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_c = strides[2];
auto& strides_out = strides[3];
auto maybe_collapse = [topt, &a, &b, &c, &out]() {
if (topt == TernaryOpType::General) {
// The size cap here should ideally be `UINT32_MAX` but we are
// limited by the shape being an int.
auto [shape, strides] = collapse_contiguous_dims(
{a, b, c, out},
/* size_cap = */ INT32_MAX);
return std::make_tuple(
shape, strides[0], strides[1], strides[2], strides[3]);
} else {
std::vector<size_t> e;
return std::make_tuple(std::vector<int>{}, e, e, e, e);
}
};
auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();
bool use_2d = out.data_size() > UINT_MAX;
std::string kernel_name;
@ -88,7 +98,7 @@ void ternary_op_gpu_inplace(
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
throw std::runtime_error("[Metal::ternary] Must use 1024 sized block");
}
MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);

View File

@ -6,4 +6,5 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../common/reduce_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp)
${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../common/utils.cpp)

View File

@ -243,7 +243,7 @@ void init_fast(nb::module_& parent_module) {
template_args.emplace_back(name, dtype);
} else {
throw std::invalid_argument(
"[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
"[metal_kernel] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
}
}
}
@ -271,25 +271,24 @@ void init_fast(nb::module_& parent_module) {
nb::sig(
"def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
R"pbdoc(
Run the kernel.
Run the kernel.
Args:
inputs (List[array]): The inputs passed to the Metal kernel.
output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
These will be added as template arguments to the kernel definition. Default: ``None``.
init_value (float, optional): Optional value to use to initialize all of the output arrays.
By default, output arrays are uninitialized. Default: ``None``.
verbose (bool, optional): Whether to print the full generated source code of the kernel
when it is run. Default: ``False``.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
Args:
inputs (List[array]): The inputs passed to the Metal kernel.
output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
These will be added as template arguments to the kernel definition. Default: ``None``.
init_value (float, optional): Optional value to use to initialize all of the output arrays.
By default, output arrays are uninitialized. Default: ``None``.
verbose (bool, optional): Whether to print the full generated source code of the kernel
when it is run. Default: ``False``.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
Returns:
List[array]: The list of output arrays.
)pbdoc");
Returns:
List[array]: The list of output arrays.)pbdoc");
},
"name"_a,
"input_names"_a,
@ -306,16 +305,16 @@ void init_fast(nb::module_& parent_module) {
input_names (List[str]): The parameter names of the inputs in the
function signature.
output_names (List[str]): The parameter names of the outputs in the
function signature.
function signature.
source (str): Source code. This is the body of a function in Metal,
the function signature will be automatically generated.
the function signature will be automatically generated.
header (str): Header source code to include before the main function.
Useful for helper functions or includes that should live outside of
the main function body.
Useful for helper functions or includes that should live outside of
the main function body.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``True``.
before the kernel runs. Default: ``True``.
atomic_outputs (bool): Whether to use atomic outputs in the function signature
e.g. ``device atomic<float>``. Default: ``False``.
e.g. ``device atomic<float>``. Default: ``False``.
Returns:
Callable ``metal_kernel``.