Faster cpu ops (#1434)

* faster binary and cleaner copy

* use recursive template for other ops

* more cleanup

* fix from cleanup

* more clean

* fix binary

* use contiguous iterator

* add 3d

* nits

* fix

* fix?

* fix

* fix rebase
This commit is contained in:
Awni Hannun
2024-09-26 09:19:13 -07:00
committed by GitHub
parent 0b4a58699e
commit 5b6f38df2b
12 changed files with 590 additions and 1347 deletions

View File

@@ -71,128 +71,46 @@ void set_ternary_op_output_data(
break;
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op, int D>
void ternary_op_dims(
const T1* a,
const T2* b,
const T3* c,
U* out,
Op op,
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& c_strides,
const std::vector<size_t>& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
auto stride_c = c_strides[axis];
auto stride_out = out_strides[axis];
auto N = shape[axis];
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims1(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
for (size_t i = 0; i < out.size(); ++i) {
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
c_idx += c.strides()[0];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims2(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
c_idx += c.strides()[1];
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
ternary_op_dims<T1, T2, T3, U, Op, D - 1>(
a,
b,
c,
out,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
axis + 1);
} else {
*out = op(*a, *b, *c);
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims3(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[2];
b_idx += b.strides()[2];
c_idx += c.strides()[2];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims4(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[3];
b_idx += b.strides()[3];
c_idx += c.strides()[3];
}
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
a += stride_a;
b += stride_b;
c += stride_c;
out += stride_out;
}
}
@@ -203,30 +121,69 @@ void ternary_op_dispatch_dims(
const array& c,
array& out,
Op op) {
switch (out.ndim()) {
case 1:
ternary_op_dims1<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 2:
ternary_op_dims2<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 3:
ternary_op_dims3<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 4:
ternary_op_dims4<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
}
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& c_strides = strides[2];
const auto& out_strides = strides[3];
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
for (size_t i = 0; i < out.size(); i++) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
int c_idx = elem_to_loc(i, c.shape(), c.strides());
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
U* out_ptr = out.data<T3>();
int ndim = shape.size();
switch (ndim) {
case 1:
ternary_op_dims<T1, T2, T3, U, Op, 1>(
a_ptr,
b_ptr,
c_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
0);
return;
case 2:
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr,
b_ptr,
c_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
0);
return;
}
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
ContiguousIterator<size_t> c_it(shape, c_strides, ndim - 2);
size_t stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
c_ptr + c_it.loc,
out_ptr + elem,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
ndim - 2);
a_it.step();
b_it.step();
c_it.step();
}
}
@@ -243,10 +200,21 @@ void ternary_op(
// The full computation is scalar-scalar-scalar so we call the base op once.
if (topt == TernaryOpType::ScalarScalarScalar) {
*(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
return;
} else if (topt == TernaryOpType::VectorVectorVector) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<U>();
for (size_t i = 0; i < out.size(); ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++;
b_ptr++;
c_ptr++;
out_ptr++;
}
} else {
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
}
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
}
} // namespace