Kernel generation (#614)

Generate reusable element-wise kernels given a computation graph.
This commit is contained in:
Angelos Katharopoulos
2024-02-07 13:15:59 -08:00
committed by GitHub
parent 5fd11c347d
commit 28eac18571
19 changed files with 1302 additions and 459 deletions

View File

@@ -117,16 +117,18 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>> strides) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
if (xs[0].ndim() > 0) {
if (shape.size() > 0) {
to_collapse.push_back(0);
for (int i = 1; i < xs[0].ndim(); i++) {
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
for (auto& x : xs) {
if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) {
for (const std::vector<size_t>& st : strides) {
if (st[i] * shape[i] != st[i - 1]) {
contiguous = false;
}
if (!contiguous) {
@@ -142,21 +144,31 @@ collapse_contiguous_dims(const std::vector<array>& xs) {
}
std::vector<int> out_shape;
std::vector<std::vector<size_t>> out_strides(xs.size());
std::vector<std::vector<size_t>> out_strides(strides.size());
for (int i = 0; i < to_collapse.size(); i++) {
int current_shape = xs[0].shape()[to_collapse[i]];
int current_shape = shape[to_collapse[i]];
while (to_collapse[++i] != -1) {
current_shape *= xs[0].shape()[to_collapse[i]];
current_shape *= shape[to_collapse[i]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < xs.size(); j++) {
out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]);
for (int j = 0; j < strides.size(); j++) {
const std::vector<size_t>& st = strides[j];
out_strides[j].push_back(st[to_collapse[i - 1]]);
}
}
return std::make_tuple(out_shape, out_strides);
}
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
std::vector<std::vector<size_t>> strides;
for (auto& x : xs) {
strides.emplace_back(x.strides());
}
return collapse_contiguous_dims(xs[0].shape(), strides);
}
template <typename... Arrays>
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(Arrays... xs) {