Faster cpu ops (#1434)

* faster binary and cleaner copy

* use recursive template for other ops

* more cleanup

* fix from cleanup

* more clean

* fix binary

* use contiguous iterator

* add 3d

* nits

* fix

* fix?

* fix

* fix rebase
This commit is contained in:
Awni Hannun
2024-09-26 09:19:13 -07:00
committed by GitHub
parent 0b4a58699e
commit 5b6f38df2b
12 changed files with 590 additions and 1347 deletions

View File

@@ -73,11 +73,7 @@ void binary_op_gpu_inplace(
// Try to collapse contiguous dims
auto maybe_collapse = [bopt, &a, &b, &out]() {
if (bopt == BinaryOpType::General) {
// The size cap here should ideally be `UINT32_MAX` but we are
// limited by the shape being an int.
auto [shape, strides] = collapse_contiguous_dims(
{a, b, out},
/* size_cap = */ INT32_MAX);
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
return std::make_tuple(shape, strides[0], strides[1], strides[2]);
} else {
std::vector<size_t> e;

View File

@@ -26,11 +26,7 @@ void ternary_op_gpu_inplace(
// Try to collapse contiguous dims
auto maybe_collapse = [topt, &a, &b, &c, &out]() {
if (topt == TernaryOpType::General) {
// The size cap here should ideally be `UINT32_MAX` but we are
// limited by the shape being an int.
auto [shape, strides] = collapse_contiguous_dims(
{a, b, c, out},
/* size_cap = */ INT32_MAX);
auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
return std::make_tuple(
shape, strides[0], strides[1], strides[2], strides[3]);
} else {

View File

@@ -28,10 +28,7 @@ void unary_op_gpu_inplace(
auto maybe_collapse = [contig, &in, &out]() {
if (!contig) {
auto [shape, strides] = collapse_contiguous_dims(
{in, out},
/* size_cap = */ INT32_MAX);
return std::make_pair(shape, strides[0]);
return collapse_contiguous_dims(in);
} else {
return std::make_pair(std::vector<int>{}, std::vector<size_t>{});
}