mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Faster cpu ops (#1434)
* faster binary and cleaner copy * use recursive template for other ops * more cleanup * fix from cleanup * more clean * fix binary * use contiguous iterator * add 3d * nits * fix * fix? * fix * fix rebase
This commit is contained in:
@@ -122,19 +122,7 @@ void set_binary_op_output_data(
|
||||
}
|
||||
}
|
||||
|
||||
struct UseDefaultBinaryOp {
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||
// Should we throw? This should normally never be called.
|
||||
assert(false);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
|
||||
// Should we throw? This should normally never be called.
|
||||
assert(false);
|
||||
}
|
||||
};
|
||||
struct UseDefaultBinaryOp {};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultVectorScalar {
|
||||
@@ -150,18 +138,6 @@ struct DefaultVectorScalar {
|
||||
a++;
|
||||
}
|
||||
}
|
||||
|
||||
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
|
||||
T scalar = *b;
|
||||
while (size-- > 0) {
|
||||
auto dst = op(*a, scalar);
|
||||
*dst_a = dst.first;
|
||||
*dst_b = dst.second;
|
||||
dst_a++;
|
||||
dst_b++;
|
||||
a++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
@@ -178,18 +154,6 @@ struct DefaultScalarVector {
|
||||
b++;
|
||||
}
|
||||
}
|
||||
|
||||
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
|
||||
T scalar = *a;
|
||||
while (size-- > 0) {
|
||||
auto dst = op(scalar, *b);
|
||||
*dst_a = dst.first;
|
||||
*dst_b = dst.second;
|
||||
dst_a++;
|
||||
dst_b++;
|
||||
b++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
@@ -206,204 +170,110 @@ struct DefaultVectorVector {
|
||||
b++;
|
||||
}
|
||||
}
|
||||
|
||||
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
|
||||
while (size-- > 0) {
|
||||
auto dst = op(*a, *b);
|
||||
*dst_a = dst.first;
|
||||
*dst_b = dst.second;
|
||||
dst_a++;
|
||||
dst_b++;
|
||||
a++;
|
||||
b++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims1(const array& a, const array& b, array& out, Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
for (size_t i = 0; i < out.size(); ++i) {
|
||||
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
a_idx += a.strides()[0];
|
||||
b_idx += b.strides()[0];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims1(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
template <typename T, typename U, typename Op, int D, bool Strided>
|
||||
void binary_op_dims(
|
||||
const T* a,
|
||||
const T* b,
|
||||
U* out,
|
||||
Op op,
|
||||
int stride) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; i++) {
|
||||
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
|
||||
a_idx += a.strides()[0];
|
||||
b_idx += b.strides()[0];
|
||||
dst += stride;
|
||||
}
|
||||
}
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& a_strides,
|
||||
const std::vector<size_t>& b_strides,
|
||||
const std::vector<size_t>& out_strides,
|
||||
int axis) {
|
||||
auto stride_a = a_strides[axis];
|
||||
auto stride_b = b_strides[axis];
|
||||
auto stride_out = out_strides[axis];
|
||||
auto N = shape[axis];
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims2(const array& a, const array& b, array& out, Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
size_t out_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
a_idx += a.strides()[1];
|
||||
b_idx += b.strides()[1];
|
||||
}
|
||||
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims2(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
int stride) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
|
||||
a_idx += a.strides()[1];
|
||||
b_idx += b.strides()[1];
|
||||
dst += stride;
|
||||
}
|
||||
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims3(const array& a, const array& b, array& out, Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
size_t out_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||
for (size_t k = 0; k < a.shape()[2]; ++k) {
|
||||
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
a_idx += a.strides()[2];
|
||||
b_idx += b.strides()[2];
|
||||
for (int i = 0; i < N; i++) {
|
||||
if constexpr (D > 1) {
|
||||
binary_op_dims<T, U, Op, D - 1, Strided>(
|
||||
a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
|
||||
} else {
|
||||
if constexpr (Strided) {
|
||||
op(a, b, out, stride_out);
|
||||
} else {
|
||||
*out = op(*a, *b);
|
||||
}
|
||||
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
|
||||
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
|
||||
}
|
||||
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||
out += stride_out;
|
||||
a += stride_a;
|
||||
b += stride_b;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dims4(const array& a, const array& b, array& out, Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
size_t a_idx = 0;
|
||||
size_t b_idx = 0;
|
||||
size_t out_idx = 0;
|
||||
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||
for (size_t k = 0; k < a.shape()[2]; ++k) {
|
||||
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
|
||||
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
a_idx += a.strides()[3];
|
||||
b_idx += b.strides()[3];
|
||||
}
|
||||
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
|
||||
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
|
||||
}
|
||||
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
|
||||
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
|
||||
}
|
||||
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op_dispatch_dims(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op) {
|
||||
switch (out.ndim()) {
|
||||
case 1:
|
||||
binary_op_dims1<T, U, Op>(a, b, out, op);
|
||||
return;
|
||||
case 2:
|
||||
binary_op_dims2<T, U, Op>(a, b, out, op);
|
||||
return;
|
||||
case 3:
|
||||
binary_op_dims3<T, U, Op>(a, b, out, op);
|
||||
return;
|
||||
case 4:
|
||||
binary_op_dims4<T, U, Op>(a, b, out, op);
|
||||
return;
|
||||
}
|
||||
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
for (size_t i = 0; i < out.size(); i++) {
|
||||
int a_idx = elem_to_loc(i, a.shape(), a.strides());
|
||||
int b_idx = elem_to_loc(i, b.shape(), b.strides());
|
||||
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <typename T, typename U, bool Strided, typename Op>
|
||||
void binary_op_dispatch_dims(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
int dim,
|
||||
int stride) {
|
||||
// Number of dimensions to loop over for vectorized ops
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& a_strides,
|
||||
const std::vector<size_t>& b_strides,
|
||||
const std::vector<size_t>& out_strides) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
switch (dim) {
|
||||
case 1:
|
||||
binary_op_dims1<T, U, Op>(a, b, out, op, stride);
|
||||
binary_op_dims<T, U, Op, 1, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
case 2:
|
||||
binary_op_dims2<T, U, Op>(a, b, out, op, stride);
|
||||
binary_op_dims<T, U, Op, 2, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
case 3:
|
||||
binary_op_dims<T, U, Op, 3, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
}
|
||||
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* dst = out.data<U>();
|
||||
for (size_t i = 0; i < out.size(); i += stride) {
|
||||
int a_idx = elem_to_loc(i, a.shape(), a.strides());
|
||||
int b_idx = elem_to_loc(i, b.shape(), b.strides());
|
||||
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
|
||||
dst += stride;
|
||||
ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3);
|
||||
ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3);
|
||||
size_t stride = out_strides[dim - 4];
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
binary_op_dims<T, U, Op, 3, Strided>(
|
||||
a_ptr + a_it.loc,
|
||||
b_ptr + b_it.loc,
|
||||
out_ptr + elem,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
dim - 3);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -450,29 +320,33 @@ void binary_op(
|
||||
}
|
||||
|
||||
// General computation so let's try to optimize
|
||||
auto [new_shape, new_strides] = collapse_contiguous_dims(
|
||||
a.shape(), {a.strides(), b.strides(), out.strides()});
|
||||
const auto& a_strides = new_strides[0];
|
||||
const auto& b_strides = new_strides[1];
|
||||
const auto& strides = new_strides[2];
|
||||
|
||||
// Get the left-most dim such that the array is row contiguous after
|
||||
auto& strides = out.strides();
|
||||
auto leftmost_rc_dim = [&strides](const array& arr) {
|
||||
int d = arr.ndim() - 1;
|
||||
for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
|
||||
auto leftmost_rc_dim = [&strides](const std::vector<size_t>& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_rc_dim = leftmost_rc_dim(a);
|
||||
auto b_rc_dim = leftmost_rc_dim(b);
|
||||
auto a_rc_dim = leftmost_rc_dim(a_strides);
|
||||
auto b_rc_dim = leftmost_rc_dim(b_strides);
|
||||
|
||||
// Get the left-most dim such that the array is a broadcasted "scalar" after
|
||||
auto leftmost_s_dim = [](const array& arr) {
|
||||
int d = arr.ndim() - 1;
|
||||
for (; d >= 0 && arr.strides()[d] == 0; d--) {
|
||||
auto leftmost_s_dim = [](const std::vector<size_t>& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == 0; d--) {
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_s_dim = leftmost_s_dim(a);
|
||||
auto b_s_dim = leftmost_s_dim(b);
|
||||
auto a_s_dim = leftmost_s_dim(a_strides);
|
||||
auto b_s_dim = leftmost_s_dim(b_strides);
|
||||
|
||||
auto ndim = out.ndim();
|
||||
auto ndim = new_shape.size();
|
||||
|
||||
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
|
||||
int dim = ndim;
|
||||
@@ -494,27 +368,27 @@ void binary_op(
|
||||
// Can be sure dim > 0 since otherwise we would have used one of the fully
|
||||
// contiguous methods above. Except for the case that the flags do not
|
||||
// correspond to the underlying contiguity.
|
||||
size_t stride;
|
||||
if (dim == 0 || strides[dim - 1] < 16) {
|
||||
stride = 1;
|
||||
bopt = BinaryOpType::General;
|
||||
dim = ndim;
|
||||
} else {
|
||||
stride = strides[dim - 1];
|
||||
}
|
||||
|
||||
switch (bopt) {
|
||||
case BinaryOpType::VectorVector:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
|
||||
break;
|
||||
default:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out, op);
|
||||
binary_op_dispatch_dims<T, U, false>(
|
||||
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -531,9 +405,9 @@ void binary_op(
|
||||
// TODO: The following mess of constexpr evaluations can probably be achieved
|
||||
// with template specializations and overloading. Would it be simpler?
|
||||
|
||||
if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
|
||||
if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
if constexpr (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
|
||||
if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||
if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
@@ -554,7 +428,8 @@ void binary_op(
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
opvv);
|
||||
}
|
||||
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
|
||||
value) {
|
||||
// opsv and opvv were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
@@ -569,7 +444,8 @@ void binary_op(
|
||||
binary_op<T, T>(
|
||||
a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
|
||||
}
|
||||
} else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||
} else if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::
|
||||
value) {
|
||||
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// opvs and opvv were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
@@ -585,7 +461,8 @@ void binary_op(
|
||||
binary_op<T, T>(
|
||||
a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
|
||||
}
|
||||
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
|
||||
value) {
|
||||
// opvv was UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
|
||||
|
||||
Reference in New Issue
Block a user