diff --git a/mlx/backend/common/binary.h b/mlx/backend/common/binary.h index 3777f4bdd..3898e1d40 100644 --- a/mlx/backend/common/binary.h +++ b/mlx/backend/common/binary.h @@ -122,19 +122,7 @@ void set_binary_op_output_data( } } -struct UseDefaultBinaryOp { - template - void operator()(const T* a, const T* b, U* dst, int size) { - // Should we throw? This should normally never be called. - assert(false); - } - - template - void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) { - // Should we throw? This should normally never be called. - assert(false); - } -}; +struct UseDefaultBinaryOp {}; template struct DefaultVectorScalar { @@ -150,18 +138,6 @@ struct DefaultVectorScalar { a++; } } - - void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) { - T scalar = *b; - while (size-- > 0) { - auto dst = op(*a, scalar); - *dst_a = dst.first; - *dst_b = dst.second; - dst_a++; - dst_b++; - a++; - } - } }; template @@ -178,18 +154,6 @@ struct DefaultScalarVector { b++; } } - - void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) { - T scalar = *a; - while (size-- > 0) { - auto dst = op(scalar, *b); - *dst_a = dst.first; - *dst_b = dst.second; - dst_a++; - dst_b++; - b++; - } - } }; template @@ -206,204 +170,110 @@ struct DefaultVectorVector { b++; } } - - void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) { - while (size-- > 0) { - auto dst = op(*a, *b); - *dst_a = dst.first; - *dst_b = dst.second; - dst_a++; - dst_b++; - a++; - b++; - } - } }; -template -void binary_op_dims1(const array& a, const array& b, array& out, Op op) { - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* dst = out.data(); - size_t a_idx = 0; - size_t b_idx = 0; - for (size_t i = 0; i < out.size(); ++i) { - dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]); - a_idx += a.strides()[0]; - b_idx += b.strides()[0]; - } -} - -template -void binary_op_dims1( - const array& a, - const array& b, - array& out, +template +void binary_op_dims( + const T* a, + const T* b, + U* out, Op op, - int stride) { - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* dst = out.data(); - size_t a_idx = 0; - size_t b_idx = 0; - for (size_t i = 0; i < a.shape()[0]; i++) { - op(a_ptr + a_idx, b_ptr + b_idx, dst, stride); - a_idx += a.strides()[0]; - b_idx += b.strides()[0]; - dst += stride; - } -} + const std::vector& shape, + const std::vector& a_strides, + const std::vector& b_strides, + const std::vector& out_strides, + int axis) { + auto stride_a = a_strides[axis]; + auto stride_b = b_strides[axis]; + auto stride_out = out_strides[axis]; + auto N = shape[axis]; -template -void binary_op_dims2(const array& a, const array& b, array& out, Op op) { - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* dst = out.data(); - size_t a_idx = 0; - size_t b_idx = 0; - size_t out_idx = 0; - for (size_t i = 0; i < a.shape()[0]; ++i) { - for (size_t j = 0; j < a.shape()[1]; ++j) { - dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]); - a_idx += a.strides()[1]; - b_idx += b.strides()[1]; - } - a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; - b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; - } -} - -template -void binary_op_dims2( - const array& a, - const array& b, - array& out, - Op op, - int stride) { - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* dst = out.data(); - size_t a_idx = 0; - size_t b_idx = 0; - for (size_t i = 0; i < a.shape()[0]; ++i) { - for (size_t j = 0; j < a.shape()[1]; ++j) { - op(a_ptr + a_idx, b_ptr + b_idx, dst, stride); - 
a_idx += a.strides()[1]; - b_idx += b.strides()[1]; - dst += stride; - } - a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; - b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; - } -} - -template -void binary_op_dims3(const array& a, const array& b, array& out, Op op) { - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* dst = out.data(); - size_t a_idx = 0; - size_t b_idx = 0; - size_t out_idx = 0; - for (size_t i = 0; i < a.shape()[0]; ++i) { - for (size_t j = 0; j < a.shape()[1]; ++j) { - for (size_t k = 0; k < a.shape()[2]; ++k) { - dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]); - a_idx += a.strides()[2]; - b_idx += b.strides()[2]; + for (int i = 0; i < N; i++) { + if constexpr (D > 1) { + binary_op_dims( + a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1); + } else { + if constexpr (Strided) { + op(a, b, out, stride_out); + } else { + *out = op(*a, *b); } - a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2]; - b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2]; } - a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; - b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; + out += stride_out; + a += stride_a; + b += stride_b; } } -template -void binary_op_dims4(const array& a, const array& b, array& out, Op op) { - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* dst = out.data(); - size_t a_idx = 0; - size_t b_idx = 0; - size_t out_idx = 0; - for (size_t i = 0; i < a.shape()[0]; ++i) { - for (size_t j = 0; j < a.shape()[1]; ++j) { - for (size_t k = 0; k < a.shape()[2]; ++k) { - for (size_t ii = 0; ii < a.shape()[3]; ++ii) { - dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]); - a_idx += a.strides()[3]; - b_idx += b.strides()[3]; - } - a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3]; - b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3]; - } - a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2]; - b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2]; - } - a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; - b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; - } -} - -template -void binary_op_dispatch_dims( - const array& a, - const array& b, - array& out, - Op op) { - switch (out.ndim()) { - case 1: - binary_op_dims1(a, b, out, op); - return; - case 2: - binary_op_dims2(a, b, out, op); - return; - case 3: - binary_op_dims3(a, b, out, op); - return; - case 4: - binary_op_dims4(a, b, out, op); - return; - } - - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* dst = out.data(); - for (size_t i = 0; i < out.size(); i++) { - int a_idx = elem_to_loc(i, a.shape(), a.strides()); - int b_idx = elem_to_loc(i, b.shape(), b.strides()); - dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]); - } -} - -template +template void binary_op_dispatch_dims( const array& a, const array& b, array& out, Op op, int dim, - int stride) { - // Number of dimensions to loop over for vectorized ops + const std::vector& shape, + const std::vector& a_strides, + const std::vector& b_strides, + const std::vector& out_strides) { + const T* a_ptr = a.data(); + const T* b_ptr = b.data(); + U* out_ptr = out.data(); switch (dim) { case 1: - binary_op_dims1(a, b, out, op, stride); + binary_op_dims( + a_ptr, + b_ptr, + out_ptr, + op, + shape, + a_strides, + b_strides, + out_strides, + 0); return; case 2: - binary_op_dims2(a, b, out, op, stride); + binary_op_dims( + a_ptr, + b_ptr, + out_ptr, + op, + shape, + a_strides, + b_strides, + out_strides, + 0); + return; + case 3: + binary_op_dims( + a_ptr, + b_ptr, + out_ptr, + op, + 
shape, + a_strides, + b_strides, + out_strides, + 0); return; } - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* dst = out.data(); - for (size_t i = 0; i < out.size(); i += stride) { - int a_idx = elem_to_loc(i, a.shape(), a.strides()); - int b_idx = elem_to_loc(i, b.shape(), b.strides()); - op(a_ptr + a_idx, b_ptr + b_idx, dst, stride); - dst += stride; + ContiguousIterator a_it(shape, a_strides, dim - 3); + ContiguousIterator b_it(shape, b_strides, dim - 3); + size_t stride = out_strides[dim - 4]; + for (size_t elem = 0; elem < a.size(); elem += stride) { + binary_op_dims( + a_ptr + a_it.loc, + b_ptr + b_it.loc, + out_ptr + elem, + op, + shape, + a_strides, + b_strides, + out_strides, + dim - 3); + a_it.step(); + b_it.step(); } } @@ -450,29 +320,33 @@ void binary_op( } // General computation so let's try to optimize + auto [new_shape, new_strides] = collapse_contiguous_dims( + a.shape(), {a.strides(), b.strides(), out.strides()}); + const auto& a_strides = new_strides[0]; + const auto& b_strides = new_strides[1]; + const auto& strides = new_strides[2]; // Get the left-most dim such that the array is row contiguous after - auto& strides = out.strides(); - auto leftmost_rc_dim = [&strides](const array& arr) { - int d = arr.ndim() - 1; - for (; d >= 0 && arr.strides()[d] == strides[d]; d--) { + auto leftmost_rc_dim = [&strides](const std::vector& arr_strides) { + int d = arr_strides.size() - 1; + for (; d >= 0 && arr_strides[d] == strides[d]; d--) { } return d + 1; }; - auto a_rc_dim = leftmost_rc_dim(a); - auto b_rc_dim = leftmost_rc_dim(b); + auto a_rc_dim = leftmost_rc_dim(a_strides); + auto b_rc_dim = leftmost_rc_dim(b_strides); // Get the left-most dim such that the array is a broadcasted "scalar" after - auto leftmost_s_dim = [](const array& arr) { - int d = arr.ndim() - 1; - for (; d >= 0 && arr.strides()[d] == 0; d--) { + auto leftmost_s_dim = [](const std::vector& arr_strides) { + int d = arr_strides.size() - 1; + for (; d >= 0 && arr_strides[d] == 0; d--) { } return d + 1; }; - auto a_s_dim = leftmost_s_dim(a); - auto b_s_dim = leftmost_s_dim(b); + auto a_s_dim = leftmost_s_dim(a_strides); + auto b_s_dim = leftmost_s_dim(b_strides); - auto ndim = out.ndim(); + auto ndim = new_shape.size(); // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous int dim = ndim; @@ -494,27 +368,27 @@ void binary_op( // Can be sure dim > 0 since otherwise we would have used one of the fully // contiguous methods above. Except for the case that the flags do not // correspond to the underlying contiguity. 
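
Note on the hunks above: the hand-unrolled `binary_op_dims1..4` loops are replaced by a single recursive template. The remaining depth is a template parameter `D`; `if constexpr (D > 1)` recurses into the next axis, and at the innermost level the `Strided` variant hands a contiguous run to the vectorized op (`op(a, b, out, stride_out)`) while the non-strided variant applies the scalar op element by element. Below is a minimal, self-contained sketch of that pattern under simplifying assumptions (plain `int` strides, a hard-coded add instead of `Op`, no `Strided` fast path); the names are illustrative, not the MLX signatures.

```cpp
// Sketch of the recursive compile-time dimension loop (simplified).
#include <cstdio>
#include <vector>

template <typename T, int D>
void elementwise_add_dims(
    const T* a,
    const T* b,
    T* out,
    const std::vector<int>& shape,
    const std::vector<int>& a_strides,
    const std::vector<int>& b_strides,
    const std::vector<int>& out_strides,
    int axis) {
  for (int i = 0; i < shape[axis]; ++i) {
    if constexpr (D > 1) {
      // Recurse into the next axis; D counts the remaining dimensions.
      elementwise_add_dims<T, D - 1>(
          a, b, out, shape, a_strides, b_strides, out_strides, axis + 1);
    } else {
      *out = *a + *b;
    }
    a += a_strides[axis];
    b += b_strides[axis];
    out += out_strides[axis];
  }
}

int main() {
  // Two 2x3 row-major arrays.
  std::vector<float> a{1, 2, 3, 4, 5, 6}, b(6, 10.f), out(6, 0.f);
  std::vector<int> shape{2, 3}, strides{3, 1};
  elementwise_add_dims<float, 2>(
      a.data(), b.data(), out.data(), shape, strides, strides, strides, 0);
  for (float v : out) std::printf("%g ", v); // 11 12 13 14 15 16
}
```

The dispatch above mirrors this: a `switch` on the (collapsed) `ndim` picks a compile-time depth of 1, 2, or 3, and anything deeper iterates the outer axes with `ContiguousIterator` while recursing over the last three.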
- size_t stride; if (dim == 0 || strides[dim - 1] < 16) { - stride = 1; bopt = BinaryOpType::General; dim = ndim; - } else { - stride = strides[dim - 1]; } switch (bopt) { case BinaryOpType::VectorVector: - binary_op_dispatch_dims(a, b, out, opvv, dim, stride); + binary_op_dispatch_dims( + a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides); break; case BinaryOpType::VectorScalar: - binary_op_dispatch_dims(a, b, out, opvs, dim, stride); + binary_op_dispatch_dims( + a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides); break; case BinaryOpType::ScalarVector: - binary_op_dispatch_dims(a, b, out, opsv, dim, stride); + binary_op_dispatch_dims( + a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides); break; default: - binary_op_dispatch_dims(a, b, out, op); + binary_op_dispatch_dims( + a, b, out, op, dim, new_shape, a_strides, b_strides, strides); break; } } @@ -531,9 +405,9 @@ void binary_op( // TODO: The following mess of constexpr evaluations can probably be achieved // with template specializations and overloading. Would it be simpler? - if (std::is_same::value) { - if (std::is_same::value) { - if (std::is_same::value) { + if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { // All ops are UseDefaultBinaryOp (why oh why would someone call that?) binary_op( a, @@ -554,7 +428,8 @@ void binary_op( DefaultVectorScalar(op), opvv); } - } else if (std::is_same::value) { + } else if constexpr (std::is_same:: + value) { // opsv and opvv were UseDefaultBinaryOp binary_op( a, @@ -569,7 +444,8 @@ void binary_op( binary_op( a, b, out, op, DefaultScalarVector(op), opvs, opvv); } - } else if (std::is_same::value) { + } else if constexpr (std::is_same:: + value) { if (std::is_same::value) { // opvs and opvv were UseDefaultBinaryOp binary_op( @@ -585,7 +461,8 @@ void binary_op( binary_op( a, b, out, op, opsv, DefaultVectorScalar(op), opvv); } - } else if (std::is_same::value) { + } else if constexpr (std::is_same:: + value) { // opvv was UseDefaultBinaryOp binary_op( a, b, out, op, opsv, opvs, DefaultVectorVector(op)); diff --git a/mlx/backend/common/binary_two.h b/mlx/backend/common/binary_two.h index 3ce2f7110..e9740f8aa 100644 --- a/mlx/backend/common/binary_two.h +++ b/mlx/backend/common/binary_two.h @@ -9,168 +9,43 @@ namespace mlx::core { namespace { -template -void binary_op_dims1( - const array& a, - const array& b, - array& out_a, - array& out_b, - Op op) { - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* dst_a = out_a.data(); - U* dst_b = out_b.data(); - size_t a_idx = 0; - size_t b_idx = 0; - for (size_t i = 0; i < out_a.size(); ++i) { - auto dst = op(a_ptr[a_idx], b_ptr[b_idx]); - dst_a[i] = dst.first; - dst_b[i] = dst.second; - a_idx += a.strides()[0]; - b_idx += b.strides()[0]; - } -} - -template -void binary_op_dims1( - const array& a, - const array& b, - array& out_a, - array& out_b, +template +void binary_op_dims( + const T* a, + const T* b, + U* out_a, + U* out_b, Op op, - int stride) { - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* dst_a = out_a.data(); - U* dst_b = out_b.data(); - size_t a_idx = 0; - size_t b_idx = 0; - for (size_t i = 0; i < a.shape()[0]; i++) { - op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride); - a_idx += a.strides()[0]; - b_idx += b.strides()[0]; - dst_a += stride; - dst_b += stride; - } -} + const std::vector& shape, + const std::vector& a_strides, + const std::vector& b_strides, + const std::vector& out_strides, + int axis) { + auto 
stride_a = a_strides[axis]; + auto stride_b = b_strides[axis]; + auto stride_out = out_strides[axis]; + auto N = shape[axis]; -template -void binary_op_dims2( - const array& a, - const array& b, - array& out_a, - array& out_b, - Op op) { - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* dst_a = out_a.data(); - U* dst_b = out_b.data(); - size_t a_idx = 0; - size_t b_idx = 0; - size_t out_idx = 0; - for (size_t i = 0; i < a.shape()[0]; ++i) { - for (size_t j = 0; j < a.shape()[1]; ++j) { - auto dst = op(a_ptr[a_idx], b_ptr[b_idx]); - dst_a[out_idx] = dst.first; - dst_b[out_idx++] = dst.second; - a_idx += a.strides()[1]; - b_idx += b.strides()[1]; + for (int i = 0; i < N; i++) { + if constexpr (D > 1) { + binary_op_dims( + a, + b, + out_a, + out_b, + op, + shape, + a_strides, + b_strides, + out_strides, + axis + 1); + } else { + std::tie(*out_a, *out_b) = op(*a, *b); } - a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; - b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; - } -} - -template -void binary_op_dims2( - const array& a, - const array& b, - array& out_a, - array& out_b, - Op op, - int stride) { - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* dst_a = out_a.data(); - U* dst_b = out_b.data(); - size_t a_idx = 0; - size_t b_idx = 0; - for (size_t i = 0; i < a.shape()[0]; ++i) { - for (size_t j = 0; j < a.shape()[1]; ++j) { - op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride); - a_idx += a.strides()[1]; - b_idx += b.strides()[1]; - dst_a += stride; - dst_b += stride; - } - a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; - b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; - } -} - -template -void binary_op_dims3( - const array& a, - const array& b, - array& out_a, - array& out_b, - Op op) { - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* dst_a = out_a.data(); - U* dst_b = out_b.data(); - size_t a_idx = 0; - size_t b_idx = 0; - size_t out_idx = 0; - for (size_t i = 0; i < a.shape()[0]; ++i) { - for (size_t j = 0; j < a.shape()[1]; ++j) { - for (size_t k = 0; k < a.shape()[2]; ++k) { - auto dst = op(a_ptr[a_idx], b_ptr[b_idx]); - dst_a[out_idx] = dst.first; - dst_b[out_idx++] = dst.second; - a_idx += a.strides()[2]; - b_idx += b.strides()[2]; - } - a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2]; - b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2]; - } - a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; - b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; - } -} - -template -void binary_op_dims4( - const array& a, - const array& b, - array& out_a, - array& out_b, - Op op) { - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* dst_a = out_a.data(); - U* dst_b = out_b.data(); - size_t a_idx = 0; - size_t b_idx = 0; - size_t out_idx = 0; - for (size_t i = 0; i < a.shape()[0]; ++i) { - for (size_t j = 0; j < a.shape()[1]; ++j) { - for (size_t k = 0; k < a.shape()[2]; ++k) { - for (size_t ii = 0; ii < a.shape()[3]; ++ii) { - auto dst = op(a_ptr[a_idx], b_ptr[b_idx]); - dst_a[out_idx] = dst.first; - dst_b[out_idx++] = dst.second; - a_idx += a.strides()[3]; - b_idx += b.strides()[3]; - } - a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3]; - b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3]; - } - a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2]; - b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2]; - } - a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; - b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; + a += stride_a; + b += stride_b; + out_a += 
stride_out; + out_b += stride_out; } } @@ -181,352 +56,160 @@ void binary_op_dispatch_dims( array& out_a, array& out_b, Op op) { - switch (out_a.ndim()) { + auto [shape, strides] = collapse_contiguous_dims( + a.shape(), {a.strides(), b.strides(), out_a.strides()}); + const auto& a_strides = strides[0]; + const auto& b_strides = strides[1]; + const auto& out_strides = strides[2]; + const T* a_ptr = a.data(); + const T* b_ptr = b.data(); + U* out_a_ptr = out_a.data(); + U* out_b_ptr = out_b.data(); + + int ndim = shape.size(); + switch (ndim) { case 1: - binary_op_dims1(a, b, out_a, out_b, op); + binary_op_dims( + a_ptr, + b_ptr, + out_a_ptr, + out_b_ptr, + op, + shape, + a_strides, + b_strides, + out_strides, + 0); return; case 2: - binary_op_dims2(a, b, out_a, out_b, op); - return; - case 3: - binary_op_dims3(a, b, out_a, out_b, op); - return; - case 4: - binary_op_dims4(a, b, out_a, out_b, op); + binary_op_dims( + a_ptr, + b_ptr, + out_a_ptr, + out_b_ptr, + op, + shape, + a_strides, + b_strides, + out_strides, + 0); return; } - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* dst_a = out_a.data(); - U* dst_b = out_b.data(); - for (size_t i = 0; i < out_a.size(); i++) { - int a_idx = elem_to_loc(i, a.shape(), a.strides()); - int b_idx = elem_to_loc(i, b.shape(), b.strides()); - std::tie(dst_a[i], dst_b[i]) = op(a_ptr[a_idx], b_ptr[b_idx]); + ContiguousIterator a_it(shape, a_strides, ndim - 2); + ContiguousIterator b_it(shape, b_strides, ndim - 2); + size_t stride = out_strides[ndim - 3]; + for (size_t elem = 0; elem < a.size(); elem += stride) { + binary_op_dims( + a_ptr + a_it.loc, + b_ptr + b_it.loc, + out_a_ptr + elem, + out_b_ptr + elem, + op, + shape, + a_strides, + b_strides, + out_strides, + ndim - 2); + a_it.step(); + b_it.step(); } } -template -void binary_op_dispatch_dims( - const array& a, - const array& b, - array& out_a, - array& out_b, - Op op, - int dim, - int stride) { - // Number of dimensions to loop over for vectorized ops - switch (dim) { - case 1: - binary_op_dims1(a, b, out_a, out_b, op, stride); - return; - case 2: - binary_op_dims2(a, b, out_a, out_b, op, stride); - return; - } - - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* dst_a = out_a.data(); - U* dst_b = out_b.data(); - for (size_t i = 0; i < out_a.size(); i += stride) { - int a_idx = elem_to_loc(i, a.shape(), a.strides()); - int b_idx = elem_to_loc(i, b.shape(), b.strides()); - op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride); - dst_a += stride; - dst_b += stride; - } -} - -template < - typename T, - typename U, - typename Op, - typename OpSV, - typename OpVS, - typename OpVV> +template void binary_op( const array& a, const array& b, - array& out_a, - array& out_b, - Op op, - OpSV opsv, - OpVS opvs, - OpVV opvv) { + std::vector& outputs, + Op op) { auto bopt = get_binary_op_type(a, b); + auto& out_a = outputs[0]; + auto& out_b = outputs[1]; set_binary_op_output_data(a, b, out_a, bopt); set_binary_op_output_data(a, b, out_b, bopt); // The full computation is scalar scalar so call the base op once + if (bopt == BinaryOpType::General) { + binary_op_dispatch_dims(a, b, out_a, out_b, op); + return; + } + + auto a_ptr = a.data(); + auto b_ptr = b.data(); + auto out_a_ptr = out_a.data(); + auto out_b_ptr = out_b.data(); if (bopt == BinaryOpType::ScalarScalar) { - std::tie(*(out_a.data()), *(out_b.data())) = - op(*a.data(), *b.data()); - return; - } - - // The full computation is scalar vector so delegate to the op - if (bopt == BinaryOpType::ScalarVector) { - opsv( - a.data(), - 
b.data(), - out_a.data(), - out_b.data(), - b.data_size()); - return; - } - - // The full computation is vector scalar so delegate to the op - if (bopt == BinaryOpType::VectorScalar) { - opvs( - a.data(), - b.data(), - out_a.data(), - out_b.data(), - a.data_size()); - return; - } - - // The full computation is vector vector so delegate to the op - if (bopt == BinaryOpType::VectorVector) { - opvv( - a.data(), - b.data(), - out_a.data(), - out_b.data(), - out_a.size()); - return; - } - - // General computation so let's try to optimize - - // Get the left-most dim such that the array is row contiguous after - auto& strides = out_a.strides(); - auto leftmost_rc_dim = [&strides](const array& arr) { - int d = arr.ndim() - 1; - for (; d >= 0 && arr.strides()[d] == strides[d]; d--) { + std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); + } else if (bopt == BinaryOpType::ScalarVector) { + for (size_t i = 0; i < b.size(); ++i) { + std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); + out_a_ptr++; + out_b_ptr++; + b_ptr++; } - return d + 1; - }; - auto a_rc_dim = leftmost_rc_dim(a); - auto b_rc_dim = leftmost_rc_dim(b); - - // Get the left-most dim such that the array is a broadcasted "scalar" after - auto leftmost_s_dim = [](const array& arr) { - int d = arr.ndim() - 1; - for (; d >= 0 && arr.strides()[d] == 0; d--) { + } else if (bopt == BinaryOpType::VectorScalar) { + for (size_t i = 0; i < a.size(); ++i) { + std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); + out_a_ptr++; + out_b_ptr++; + a_ptr++; + } + } else { // VectorVector + for (size_t i = 0; i < a.size(); ++i) { + std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); + out_a_ptr++; + out_b_ptr++; + a_ptr++; + b_ptr++; } - return d + 1; - }; - auto a_s_dim = leftmost_s_dim(a); - auto b_s_dim = leftmost_s_dim(b); - - auto ndim = out_a.ndim(); - - // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous - int dim = ndim; - if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) { - bopt = BinaryOpType::VectorVector; - dim = d; - // Case 2: LxM and Fx1 where L and F are broadcastable and M is row - // contiguous - } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) { - bopt = BinaryOpType::VectorScalar; - dim = d; - // Case 3: Lx1 and FxM where L and F are broadcastable and M is row - // contiguous - } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) { - bopt = BinaryOpType::ScalarVector; - dim = d; - } - - // Can be sure dim > 0 since otherwise we would have used one of the fully - // contiguous methods above. Except for the case that the flags do not - // correspond to the underlying contiguity. - size_t stride; - if (dim == 0 || strides[dim - 1] < 16) { - stride = 1; - bopt = BinaryOpType::General; - dim = ndim; - } else { - stride = strides[dim - 1]; - } - - switch (bopt) { - case BinaryOpType::VectorVector: - binary_op_dispatch_dims(a, b, out_a, out_b, opvv, dim, stride); - break; - case BinaryOpType::VectorScalar: - binary_op_dispatch_dims(a, b, out_a, out_b, opvs, dim, stride); - break; - case BinaryOpType::ScalarVector: - binary_op_dispatch_dims(a, b, out_a, out_b, opsv, dim, stride); - break; - default: - binary_op_dispatch_dims(a, b, out_a, out_b, op); - break; } } -template -void binary_op( - const array& a, - const array& b, - std::vector& outputs, - Op op, - OpSV opsv, - OpVS opvs, - OpVV opvv) { - // TODO: The following mess of constexpr evaluations can probably be achieved - // with template specializations and overloading. Would it be simpler? 
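
Note: `binary_two.h` drops the `OpSV`/`OpVS`/`OpVV` plumbing entirely, while the earlier `binary.h` hunk keeps it but upgrades the `std::is_same` checks to `if constexpr`. That change appears to be required now that `UseDefaultBinaryOp` is an empty tag: with a runtime `if`, the untaken branches that forward the tag into the kernels would still be instantiated and would fail to compile, since the tag no longer has an `operator()`. A minimal sketch of the idea, with hypothetical names (not the MLX signatures):

```cpp
// Sketch: if constexpr discards branches that would not compile for a tag type.
#include <cstdio>
#include <type_traits>

struct UseDefault {}; // empty tag, analogous to UseDefaultBinaryOp

template <typename T>
struct DefaultVector {
  void operator()(const T* a, const T* b, T* dst, int n) const {
    for (int i = 0; i < n; ++i) dst[i] = a[i] + b[i];
  }
};

template <typename T, typename OpVV>
void run(const T* a, const T* b, T* dst, int n, OpVV opvv) {
  if constexpr (std::is_same<OpVV, UseDefault>::value) {
    // Substitute a default kernel; the else branch is never instantiated here,
    // which is what keeps this compiling even though UseDefault is not callable.
    DefaultVector<T>{}(a, b, dst, n);
  } else {
    opvv(a, b, dst, n);
  }
}

int main() {
  float a[3]{1, 2, 3}, b[3]{4, 5, 6}, dst[3];
  run(a, b, dst, 3, UseDefault{}); // falls back to DefaultVector
  std::printf("%g %g %g\n", dst[0], dst[1], dst[2]); // 5 7 9
}
```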
- - if (std::is_same::value) { - if (std::is_same::value) { - if (std::is_same::value) { - // All ops are UseDefaultBinaryOp (why oh why would someone call that?) - binary_op( - a, - b, - outputs[0], - outputs[1], - op, - DefaultScalarVector(op), - DefaultVectorScalar(op), - DefaultVectorVector(op)); - } else { - // opsv and opvs were UseDefaultBinaryOp - binary_op( - a, - b, - outputs[0], - outputs[1], - op, - DefaultScalarVector(op), - DefaultVectorScalar(op), - opvv); - } - } else if (std::is_same::value) { - // opsv and opvv were UseDefaultBinaryOp - binary_op( - a, - b, - outputs[0], - outputs[1], - op, - DefaultScalarVector(op), - opvs, - DefaultVectorVector(op)); - } else { - // opsv was UseDefaultBinaryOp - binary_op( - a, - b, - outputs[0], - outputs[1], - op, - DefaultScalarVector(op), - opvs, - opvv); - } - } else if (std::is_same::value) { - if (std::is_same::value) { - // opvs and opvv were UseDefaultBinaryOp - binary_op( - a, - b, - outputs[0], - outputs[1], - op, - opsv, - DefaultVectorScalar(op), - DefaultVectorVector(op)); - } else { - // opvs was UseDefaultBinaryOp - binary_op( - a, - b, - outputs[0], - outputs[1], - op, - opsv, - DefaultVectorScalar(op), - opvv); - } - } else if (std::is_same::value) { - // opvv was UseDefaultBinaryOp - binary_op( - a, - b, - outputs[0], - outputs[1], - op, - opsv, - opvs, - DefaultVectorVector(op)); - } else { - // All ops provided - binary_op(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv); - } -} - -template -void binary_op( - const array& a, - const array& b, - std::vector& outputs, - Op op) { - DefaultScalarVector opsv(op); - DefaultVectorScalar opvs(op); - DefaultVectorVector opvv(op); - binary_op(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv); -} - -template +template void binary( const array& a, const array& b, std::vector& outputs, - Ops... 
ops) { + Op op) { switch (outputs[0].dtype()) { case bool_: - binary_op(a, b, outputs, ops...); + binary_op(a, b, outputs, op); break; case uint8: - binary_op(a, b, outputs, ops...); + binary_op(a, b, outputs, op); break; case uint16: - binary_op(a, b, outputs, ops...); + binary_op(a, b, outputs, op); break; case uint32: - binary_op(a, b, outputs, ops...); + binary_op(a, b, outputs, op); break; case uint64: - binary_op(a, b, outputs, ops...); + binary_op(a, b, outputs, op); break; case int8: - binary_op(a, b, outputs, ops...); + binary_op(a, b, outputs, op); break; case int16: - binary_op(a, b, outputs, ops...); + binary_op(a, b, outputs, op); break; case int32: - binary_op(a, b, outputs, ops...); + binary_op(a, b, outputs, op); break; case int64: - binary_op(a, b, outputs, ops...); + binary_op(a, b, outputs, op); break; case float16: - binary_op(a, b, outputs, ops...); + binary_op(a, b, outputs, op); break; case float32: - binary_op(a, b, outputs, ops...); + binary_op(a, b, outputs, op); break; case bfloat16: - binary_op(a, b, outputs, ops...); + binary_op(a, b, outputs, op); break; case complex64: - binary_op(a, b, outputs, ops...); + binary_op(a, b, outputs, op); break; } } diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp index 6fb0e9edb..a01d94eb3 100644 --- a/mlx/backend/common/common.cpp +++ b/mlx/backend/common/common.cpp @@ -156,8 +156,7 @@ std::pair> Reshape::prepare_reshape( } // Firstly let's collapse all the contiguous dimensions of the input - auto [shape, _strides] = collapse_contiguous_dims(in); - auto& strides = _strides[0]; + auto [shape, strides] = collapse_contiguous_dims(in); // If shapes fit exactly in the contiguous dims then no copy is necessary so // let's check. diff --git a/mlx/backend/common/copy.cpp b/mlx/backend/common/copy.cpp index ff0d00df5..31448e1c6 100644 --- a/mlx/backend/common/copy.cpp +++ b/mlx/backend/common/copy.cpp @@ -26,465 +26,117 @@ void copy_vector(const array& src, array& dst) { std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr); } -template -void copy_general_dim1( - const array& src, - array& dst, - const std::vector& data_shape, - const std::vector& i_strides, - int64_t i_offset) { - const SrcT* src_ptr = src.data(); - DstT* dst_ptr = dst.data(); - stride_t src_idx = i_offset; - stride_t dst_idx = 0; - for (int i = 0; i < data_shape[0]; ++i) { - dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); - src_idx += i_strides[0]; - } -} +template +inline void copy_dims( + const SrcT* src, + DstT* dst, + const std::vector& shape, + const std::vector& i_strides, + const std::vector& o_strides, + int axis) { + auto stride_src = i_strides[axis]; + auto stride_dst = o_strides[axis]; + auto N = shape[axis]; -template -inline void copy_general_dim1(const array& src, array& dst) { - return copy_general_dim1( - src, dst, src.shape(), src.strides(), 0); -} - -template -void copy_general_dim2( - const array& src, - array& dst, - const std::vector& data_shape, - const std::vector& i_strides, - int64_t i_offset) { - const SrcT* src_ptr = src.data(); - DstT* dst_ptr = dst.data(); - stride_t src_idx = i_offset; - stride_t dst_idx = 0; - for (int i = 0; i < data_shape[0]; ++i) { - for (int j = 0; j < data_shape[1]; ++j) { - dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); - src_idx += i_strides[1]; + for (int i = 0; i < N; i++) { + if constexpr (D > 1) { + copy_dims( + src, dst, shape, i_strides, o_strides, axis + 1); + } else { + *dst = static_cast(*src); } - src_idx += i_strides[0] - i_strides[1] * data_shape[1]; + src += 
stride_src; + dst += stride_dst; } } -template -inline void copy_general_dim2(const array& src, array& dst) { - return copy_general_dim2( - src, dst, src.shape(), src.strides(), 0); -} - -template -void copy_general_dim3( - const array& src, - array& dst, - const std::vector& data_shape, - const std::vector& i_strides, - int64_t i_offset) { - const SrcT* src_ptr = src.data(); - DstT* dst_ptr = dst.data(); - stride_t src_idx = i_offset; - stride_t dst_idx = 0; - for (int i = 0; i < data_shape[0]; ++i) { - for (int j = 0; j < data_shape[1]; ++j) { - for (int k = 0; k < data_shape[2]; ++k) { - dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); - src_idx += i_strides[2]; - } - src_idx += i_strides[1] - i_strides[2] * data_shape[2]; - } - src_idx += i_strides[0] - i_strides[1] * data_shape[1]; - } -} - -template -inline void copy_general_dim3(const array& src, array& dst) { - return copy_general_dim3( - src, dst, src.shape(), src.strides(), 0); -} - -template -void copy_general_dim4( - const array& src, - array& dst, - const std::vector& data_shape, - const std::vector& i_strides, - int64_t i_offset) { - const SrcT* src_ptr = src.data(); - DstT* dst_ptr = dst.data(); - stride_t src_idx = i_offset; - stride_t dst_idx = 0; - for (int i = 0; i < data_shape[0]; ++i) { - for (int j = 0; j < data_shape[1]; ++j) { - for (int k = 0; k < data_shape[2]; ++k) { - for (int ii = 0; ii < data_shape[3]; ++ii) { - dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); - src_idx += i_strides[3]; - } - src_idx += i_strides[2] - i_strides[3] * data_shape[3]; - } - src_idx += i_strides[1] - i_strides[2] * data_shape[2]; - } - src_idx += i_strides[0] - i_strides[1] * data_shape[1]; - } -} - -template -inline void copy_general_dim4(const array& src, array& dst) { - return copy_general_dim4( - src, dst, src.shape(), src.strides(), 0); -} - -template -void copy_general_dim5( - const array& src, - array& dst, - const std::vector& data_shape, - const std::vector& i_strides, - int64_t i_offset) { - const SrcT* src_ptr = src.data() + i_offset; - DstT* dst_ptr = dst.data(); - - // Pre-compute loop bounds and strides - const int d0 = data_shape[0], d1 = data_shape[1], d2 = data_shape[2], - d3 = data_shape[3], d4 = data_shape[4]; - const stride_t s0 = i_strides[0], s1 = i_strides[1], s2 = i_strides[2], - s3 = i_strides[3], s4 = i_strides[4]; - - // Pre-compute stride adjustments - const stride_t s3_adj = s3 - s4 * d4; - const stride_t s2_adj = s2 - s3 * d3; - const stride_t s1_adj = s1 - s2 * d2; - const stride_t s0_adj = s0 - s1 * d1; - - stride_t src_idx = 0; - stride_t dst_idx = 0; - - for (int i = 0; i < d0; ++i) { - for (int j = 0; j < d1; ++j) { - for (int k = 0; k < d2; ++k) { - for (int l = 0; l < d3; ++l) { - for (int m = 0; m < d4; ++m) { - dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); - src_idx += s4; - } - src_idx += s3_adj; - } - src_idx += s2_adj; - } - src_idx += s1_adj; - } - src_idx += s0_adj; - } -} - -template -inline void copy_general_dim5(const array& src, array& dst) { - return copy_general_dim5( - src, dst, src.shape(), src.strides(), 0); -} - -template -void copy_general_dim6( - const array& src, - array& dst, - const std::vector& data_shape, - const std::vector& i_strides, - int64_t i_offset) { - const SrcT* src_ptr = src.data() + i_offset; - DstT* dst_ptr = dst.data(); - - // Pre-compute loop bounds and strides - const int d0 = data_shape[0], d1 = data_shape[1], d2 = data_shape[2], - d3 = data_shape[3], d4 = data_shape[4], d5 = data_shape[5]; - const stride_t s0 = i_strides[0], s1 = 
i_strides[1], s2 = i_strides[2], - s3 = i_strides[3], s4 = i_strides[4], s5 = i_strides[5]; - - // Pre-compute stride adjustments - const stride_t s4_adj = s4 - s5 * d5; - const stride_t s3_adj = s3 - s4 * d4; - const stride_t s2_adj = s2 - s3 * d3; - const stride_t s1_adj = s1 - s2 * d2; - const stride_t s0_adj = s0 - s1 * d1; - - stride_t src_idx = 0; - stride_t dst_idx = 0; - - for (int i = 0; i < d0; ++i) { - for (int j = 0; j < d1; ++j) { - for (int k = 0; k < d2; ++k) { - for (int l = 0; l < d3; ++l) { - for (int m = 0; m < d4; ++m) { - for (int n = 0; n < d5; ++n) { - dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); - src_idx += s5; - } - src_idx += s4_adj; - } - src_idx += s3_adj; - } - src_idx += s2_adj; - } - src_idx += s1_adj; - } - src_idx += s0_adj; - } -} - -template -inline void copy_general_dim6(const array& src, array& dst) { - return copy_general_dim6( - src, dst, src.shape(), src.strides(), 0); -} - -template -void copy_general_dim7( - const array& src, - array& dst, - const std::vector& data_shape, - const std::vector& i_strides, - int64_t i_offset) { - const SrcT* src_ptr = src.data() + i_offset; - DstT* dst_ptr = dst.data(); - - // Pre-compute loop bounds and strides - const int d0 = data_shape[0], d1 = data_shape[1], d2 = data_shape[2], - d3 = data_shape[3], d4 = data_shape[4], d5 = data_shape[5], - d6 = data_shape[6]; - const stride_t s0 = i_strides[0], s1 = i_strides[1], s2 = i_strides[2], - s3 = i_strides[3], s4 = i_strides[4], s5 = i_strides[5], - s6 = i_strides[6]; - - // Pre-compute stride adjustments - const stride_t s5_adj = s5 - s6 * d6; - const stride_t s4_adj = s4 - s5 * d5; - const stride_t s3_adj = s3 - s4 * d4; - const stride_t s2_adj = s2 - s3 * d3; - const stride_t s1_adj = s1 - s2 * d2; - const stride_t s0_adj = s0 - s1 * d1; - - stride_t src_idx = 0; - stride_t dst_idx = 0; - - for (int i = 0; i < d0; ++i) { - for (int j = 0; j < d1; ++j) { - for (int k = 0; k < d2; ++k) { - for (int l = 0; l < d3; ++l) { - for (int m = 0; m < d4; ++m) { - for (int n = 0; n < d5; ++n) { - for (int p = 0; p < d6; ++p) { - dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); - src_idx += s6; - } - src_idx += s5_adj; - } - src_idx += s4_adj; - } - src_idx += s3_adj; - } - src_idx += s2_adj; - } - src_idx += s1_adj; - } - src_idx += s0_adj; - } -} - -template -inline void copy_general_dim7(const array& src, array& dst) { - return copy_general_dim7( - src, dst, src.shape(), src.strides(), 0); -} - -template -void copy_general( - const array& src, - array& dst, - const std::vector& data_shape, - const std::vector& i_strides, - int64_t i_offset) { - auto [new_shape, new_strides] = collapse_contiguous_dims( - data_shape, std::vector>{i_strides}); - switch (new_shape.size()) { - case 1: - copy_general_dim1( - src, dst, new_shape, new_strides[0], i_offset); - return; - case 2: - copy_general_dim2( - src, dst, new_shape, new_strides[0], i_offset); - return; - case 3: - copy_general_dim3( - src, dst, new_shape, new_strides[0], i_offset); - return; - case 4: - copy_general_dim4( - src, dst, new_shape, new_strides[0], i_offset); - return; - case 5: - copy_general_dim5( - src, dst, new_shape, new_strides[0], i_offset); - return; - case 6: - copy_general_dim6( - src, dst, new_shape, new_strides[0], i_offset); - return; - case 7: - copy_general_dim7( - src, dst, new_shape, new_strides[0], i_offset); - return; - } - - auto src_ptr = src.data() + i_offset; - auto dst_ptr = dst.data(); - for (size_t i = 0; i < dst.size(); ++i) { - stride_t src_elem = elem_to_loc(i, new_shape, 
new_strides[0]); - dst_ptr[i] = static_cast(src_ptr[src_elem]); - } -} - -template -inline void copy_general(const array& src, array& dst) { - return copy_general( - src, dst, src.shape(), src.strides(), 0); -} - -template -inline void copy_general( - const array& src, - array& dst, - const std::vector& data_shape, - const std::vector& i_strides, - const std::vector& o_strides, - int64_t i_offset, - int64_t o_offset) { - return copy_general( - src, dst, data_shape, i_strides, i_offset); -} - -template -inline void copy_general_general_dims( - const array& src, - array& dst, - const std::vector& data_shape, - const std::vector& i_strides, - const std::vector& o_strides, - int64_t i_offset, - int64_t o_offset) { - if constexpr (D > 1) { - int axis = data_shape.size() - D; - auto stride_src = i_strides[axis]; - auto stride_dst = o_strides[axis]; - auto N = data_shape[axis]; - for (int i = 0; i < N; i++) { - copy_general_general_dims( - src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); - i_offset += stride_src; - o_offset += stride_dst; - } - } else { - int axis = data_shape.size() - 1; - auto stride_src = i_strides[axis]; - auto stride_dst = o_strides[axis]; - auto N = data_shape[axis]; - const SrcT* src_ptr = src.data() + i_offset; - DstT* dst_ptr = dst.data() + o_offset; - for (int i = 0; i < N; i++) { - *dst_ptr = static_cast(*src_ptr); - src_ptr += stride_src; - dst_ptr += stride_dst; - } - } -} - -template +template void copy_general_general( const array& src, array& dst, const std::vector& data_shape, - const std::vector& i_strides, - const std::vector& o_strides, + const std::vector& i_strides, + const std::vector& o_strides, int64_t i_offset, int64_t o_offset) { - auto [new_shape, new_strides] = collapse_contiguous_dims( - data_shape, std::vector>{i_strides, o_strides}); - switch (new_shape.size()) { - case 1: - copy_general_general_dims( - src, - dst, - new_shape, - new_strides[0], - new_strides[1], - i_offset, - o_offset); - return; - case 2: - copy_general_general_dims( - src, - dst, - new_shape, - new_strides[0], - new_strides[1], - i_offset, - o_offset); - return; - case 3: - copy_general_general_dims( - src, - dst, - new_shape, - new_strides[0], - new_strides[1], - i_offset, - o_offset); - return; - case 4: - copy_general_general_dims( - src, - dst, - new_shape, - new_strides[0], - new_strides[1], - i_offset, - o_offset); - return; - case 5: - copy_general_general_dims( - src, - dst, - new_shape, - new_strides[0], - new_strides[1], - i_offset, - o_offset); - return; + if (data_shape.empty()) { + auto val = static_cast(*(src.data() + i_offset)); + auto dst_ptr = dst.data() + o_offset; + *dst_ptr = val; + return; } - - int size = std::accumulate( - new_shape.end() - 5, new_shape.end(), 1, std::multiplies()); - for (int i = 0; i < src.size(); i += size) { - stride_t src_offset = i_offset + elem_to_loc(i, new_shape, new_strides[0]); - stride_t dst_offset = o_offset + elem_to_loc(i, new_shape, new_strides[1]); - copy_general_general_dims( - src, - dst, - new_shape, - new_strides[0], - new_strides[1], - src_offset, - dst_offset); + auto [shape, strides] = collapse_contiguous_dims( + data_shape, std::vector>{i_strides, o_strides}); + auto src_ptr = src.data() + i_offset; + auto dst_ptr = dst.data() + o_offset; + int ndim = shape.size(); + if (ndim == 1) { + copy_dims( + src_ptr, dst_ptr, shape, strides[0], strides[1], 0); + return; + } else if (ndim == 2) { + copy_dims( + src_ptr, dst_ptr, shape, strides[0], strides[1], 0); + return; + } else if (ndim == 3) { + 
copy_dims( + src_ptr, dst_ptr, shape, strides[0], strides[1], 0); + return; + } + ContiguousIterator in(shape, strides[0], ndim - 3); + ContiguousIterator out(shape, strides[1], ndim - 3); + StrideT stride = std::accumulate( + shape.end() - 3, shape.end(), 1, std::multiplies()); + for (StrideT elem = 0; elem < src.size(); elem += stride) { + copy_dims( + src_ptr + in.loc, + dst_ptr + out.loc, + shape, + strides[0], + strides[1], + ndim - 3); + in.step(); + out.step(); } } template inline void copy_general_general(const array& src, array& dst) { - return copy_general_general( + copy_general_general( src, dst, src.shape(), src.strides(), dst.strides(), 0, 0); } +template +void copy_general( + const array& src, + array& dst, + const std::vector& data_shape, + const std::vector& i_strides, + const std::vector&, + int64_t i_offset, + int64_t o_offset) { + copy_general_general( + src, + dst, + data_shape, + i_strides, + make_contiguous_strides(data_shape), + i_offset, + o_offset); +} + +template +inline void copy_general(const array& src, array& dst) { + copy_general_general( + src, + dst, + src.shape(), + src.strides(), + make_contiguous_strides(src.shape()), + 0, + 0); +} + template void copy(const array& src, array& dst, CopyType ctype, Args&&... args) { switch (ctype) { @@ -499,6 +151,7 @@ void copy(const array& src, array& dst, CopyType ctype, Args&&... args) { return; case CopyType::GeneralGeneral: copy_general_general(src, dst, std::forward(args)...); + return; } } @@ -599,7 +252,7 @@ inline void copy_inplace_dispatch( } // namespace void copy_inplace(const array& src, array& dst, CopyType ctype) { - return copy_inplace_dispatch(src, dst, ctype); + copy_inplace_dispatch(src, dst, ctype); } void copy(const array& src, array& dst, CopyType ctype) { @@ -629,20 +282,20 @@ void copy(const array& src, array& dst, CopyType ctype) { copy_inplace(src, dst, ctype); } -template +template void copy_inplace( const array& src, array& dst, const std::vector& data_shape, - const std::vector& i_strides, - const std::vector& o_strides, + const std::vector& i_strides, + const std::vector& o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype) { switch (ctype) { case CopyType::General: case CopyType::GeneralGeneral: - return copy_inplace_dispatch( + copy_inplace_dispatch( src, dst, ctype, @@ -651,10 +304,10 @@ void copy_inplace( o_strides, i_offset, o_offset); - + break; case CopyType::Scalar: case CopyType::Vector: - return copy_inplace_dispatch(src, dst, ctype); + copy_inplace_dispatch(src, dst, ctype); } } diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index 23c5efa19..f015f9995 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -406,16 +406,7 @@ void Reshape::eval(const std::vector& inputs, array& out) { if (copy_necessary) { out.set_data(allocator::malloc_or_wait(out.nbytes())); - auto out_strides = make_contiguous_strides(in.shape()); - copy_inplace( - in, - out, - in.shape(), - in.strides(), - out_strides, - 0, - 0, - CopyType::General); + copy_inplace(in, out, CopyType::General); } else { shared_buffer_reshape(in, out_strides, out); } diff --git a/mlx/backend/common/ternary.h b/mlx/backend/common/ternary.h index f3696c7a5..dcd5a8676 100644 --- a/mlx/backend/common/ternary.h +++ b/mlx/backend/common/ternary.h @@ -71,128 +71,46 @@ void set_ternary_op_output_data( break; } } +template +void ternary_op_dims( + const T1* a, + const T2* b, + const T3* c, + U* out, + Op op, + const std::vector& shape, + const std::vector& 
a_strides, + const std::vector& b_strides, + const std::vector& c_strides, + const std::vector& out_strides, + int axis) { + auto stride_a = a_strides[axis]; + auto stride_b = b_strides[axis]; + auto stride_c = c_strides[axis]; + auto stride_out = out_strides[axis]; + auto N = shape[axis]; -template -void ternary_op_dims1( - const array& a, - const array& b, - const array& c, - array& out, - Op op) { - const T1* a_ptr = a.data(); - const T2* b_ptr = b.data(); - const T3* c_ptr = c.data(); - - U* dst = out.data(); - size_t a_idx = 0; - size_t b_idx = 0; - size_t c_idx = 0; - for (size_t i = 0; i < out.size(); ++i) { - dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]); - a_idx += a.strides()[0]; - b_idx += b.strides()[0]; - c_idx += c.strides()[0]; - } -} - -template -void ternary_op_dims2( - const array& a, - const array& b, - const array& c, - array& out, - Op op) { - const T1* a_ptr = a.data(); - const T2* b_ptr = b.data(); - const T3* c_ptr = c.data(); - - U* dst = out.data(); - size_t a_idx = 0; - size_t b_idx = 0; - size_t c_idx = 0; - size_t out_idx = 0; - for (size_t i = 0; i < a.shape()[0]; ++i) { - for (size_t j = 0; j < a.shape()[1]; ++j) { - dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]); - a_idx += a.strides()[1]; - b_idx += b.strides()[1]; - c_idx += c.strides()[1]; + for (int i = 0; i < N; i++) { + if constexpr (D > 1) { + ternary_op_dims( + a, + b, + c, + out, + op, + shape, + a_strides, + b_strides, + c_strides, + out_strides, + axis + 1); + } else { + *out = op(*a, *b, *c); } - a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; - b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; - c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1]; - } -} - -template -void ternary_op_dims3( - const array& a, - const array& b, - const array& c, - array& out, - Op op) { - const T1* a_ptr = a.data(); - const T2* b_ptr = b.data(); - const T3* c_ptr = c.data(); - U* dst = out.data(); - size_t a_idx = 0; - size_t b_idx = 0; - size_t c_idx = 0; - size_t out_idx = 0; - for (size_t i = 0; i < a.shape()[0]; ++i) { - for (size_t j = 0; j < a.shape()[1]; ++j) { - for (size_t k = 0; k < a.shape()[2]; ++k) { - dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]); - a_idx += a.strides()[2]; - b_idx += b.strides()[2]; - c_idx += c.strides()[2]; - } - a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2]; - b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2]; - c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2]; - } - a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; - b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; - c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1]; - } -} - -template -void ternary_op_dims4( - const array& a, - const array& b, - const array& c, - array& out, - Op op) { - const T1* a_ptr = a.data(); - const T2* b_ptr = b.data(); - const T3* c_ptr = c.data(); - - U* dst = out.data(); - size_t a_idx = 0; - size_t b_idx = 0; - size_t c_idx = 0; - size_t out_idx = 0; - for (size_t i = 0; i < a.shape()[0]; ++i) { - for (size_t j = 0; j < a.shape()[1]; ++j) { - for (size_t k = 0; k < a.shape()[2]; ++k) { - for (size_t ii = 0; ii < a.shape()[3]; ++ii) { - dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]); - a_idx += a.strides()[3]; - b_idx += b.strides()[3]; - c_idx += c.strides()[3]; - } - a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3]; - b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3]; - c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3]; - } - a_idx += a.strides()[1] - 
a.strides()[2] * a.shape()[2]; - b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2]; - c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2]; - } - a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; - b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; - c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1]; + a += stride_a; + b += stride_b; + c += stride_c; + out += stride_out; } } @@ -203,30 +121,69 @@ void ternary_op_dispatch_dims( const array& c, array& out, Op op) { - switch (out.ndim()) { - case 1: - ternary_op_dims1(a, b, c, out, op); - return; - case 2: - ternary_op_dims2(a, b, c, out, op); - return; - case 3: - ternary_op_dims3(a, b, c, out, op); - return; - case 4: - ternary_op_dims4(a, b, c, out, op); - return; - } + auto [shape, strides] = collapse_contiguous_dims( + a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()}); + const auto& a_strides = strides[0]; + const auto& b_strides = strides[1]; + const auto& c_strides = strides[2]; + const auto& out_strides = strides[3]; const T1* a_ptr = a.data(); const T2* b_ptr = b.data(); const T3* c_ptr = c.data(); - U* dst = out.data(); - for (size_t i = 0; i < out.size(); i++) { - int a_idx = elem_to_loc(i, a.shape(), a.strides()); - int b_idx = elem_to_loc(i, b.shape(), b.strides()); - int c_idx = elem_to_loc(i, c.shape(), c.strides()); - dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]); + U* out_ptr = out.data(); + int ndim = shape.size(); + switch (ndim) { + case 1: + ternary_op_dims( + a_ptr, + b_ptr, + c_ptr, + out_ptr, + op, + shape, + a_strides, + b_strides, + c_strides, + out_strides, + 0); + return; + case 2: + ternary_op_dims( + a_ptr, + b_ptr, + c_ptr, + out_ptr, + op, + shape, + a_strides, + b_strides, + c_strides, + out_strides, + 0); + return; + } + + ContiguousIterator a_it(shape, a_strides, ndim - 2); + ContiguousIterator b_it(shape, b_strides, ndim - 2); + ContiguousIterator c_it(shape, c_strides, ndim - 2); + size_t stride = out_strides[ndim - 3]; + for (size_t elem = 0; elem < a.size(); elem += stride) { + ternary_op_dims( + a_ptr + a_it.loc, + b_ptr + b_it.loc, + c_ptr + c_it.loc, + out_ptr + elem, + op, + shape, + a_strides, + b_strides, + c_strides, + out_strides, + ndim - 2); + a_it.step(); + b_it.step(); + c_it.step(); } } @@ -243,10 +200,21 @@ void ternary_op( // The full computation is scalar-scalar-scalar so we call the base op once. 
if (topt == TernaryOpType::ScalarScalarScalar) { *(out.data()) = op(*a.data(), *b.data(), *c.data()); - return; + } else if (topt == TernaryOpType::VectorVectorVector) { + const T1* a_ptr = a.data(); + const T2* b_ptr = b.data(); + const T3* c_ptr = c.data(); + U* out_ptr = out.data(); + for (size_t i = 0; i < out.size(); ++i) { + *out_ptr = op(*a_ptr, *b_ptr, *c_ptr); + a_ptr++; + b_ptr++; + c_ptr++; + out_ptr++; + } + } else { + ternary_op_dispatch_dims(a, b, c, out, op); } - - ternary_op_dispatch_dims(a, b, c, out, op); } } // namespace diff --git a/mlx/backend/common/unary.h b/mlx/backend/common/unary.h index 0dbe374fe..f9e682777 100644 --- a/mlx/backend/common/unary.h +++ b/mlx/backend/common/unary.h @@ -24,6 +24,14 @@ void set_unary_output_data(const array& in, array& out) { } } +template +void unary_op(const T* a, T* out, Op op, size_t shape, size_t stride) { + for (size_t i = 0; i < shape; i += 1) { + out[i] = op(*a); + a += stride; + } +} + template void unary_op(const array& a, array& out, Op op) { const T* a_ptr = a.data(); @@ -36,10 +44,16 @@ void unary_op(const array& a, array& out, Op op) { } else { out.set_data(allocator::malloc_or_wait(out.nbytes())); T* dst = out.data(); - for (size_t i = 0; i < out.size(); ++i) { - // TODO this is super inefficient, need to fix. - int a_idx = elem_to_loc(i, a.shape(), a.strides()); - dst[i] = op(a_ptr[a_idx]); + size_t shape = a.ndim() > 0 ? a.shape(-1) : 1; + size_t stride = a.ndim() > 0 ? a.strides(-1) : 1; + if (a.ndim() <= 1) { + unary_op(a_ptr, dst, op, shape, stride); + return; + } + ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1); + for (size_t elem = 0; elem < a.size(); elem += shape) { + unary_op(a_ptr + it.loc, dst + elem, op, shape, stride); + it.step(); } } } diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp index 30e743a79..05ee47566 100644 --- a/mlx/backend/common/utils.cpp +++ b/mlx/backend/common/utils.cpp @@ -4,12 +4,12 @@ namespace mlx::core { -template -std::tuple, std::vector>> +template +std::tuple, std::vector>> collapse_contiguous_dims_impl( const std::vector& shape, - const std::vector>& strides, - stride_t size_cap) { + const std::vector>& strides, + StrideT size_cap) { // Make a vector that has axes separated with -1. Collapse all axes between // -1. 
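
Note on the collapsing logic above and the single-array overloads added below: two adjacent axes merge when `strides[i] * shape[i]` equals the previous (collapsed) stride, i.e. when stepping the outer axis is the same as walking `shape[i]` steps of the inner one, and the merged extent stays under `size_cap`. For example, shape `{2, 3, 4}` with row-major strides `{12, 4, 1}` collapses to a single axis of 24 with stride 1, while a transposed layout does not merge. A simplified sketch of the single-array collapse (no `size_cap`, no size-1 axis skipping, `int` strides; the helper name is illustrative):

```cpp
// Simplified sketch of single-array dimension collapsing.
#include <cstdio>
#include <utility>
#include <vector>

std::pair<std::vector<int>, std::vector<int>> collapse(
    const std::vector<int>& shape,
    const std::vector<int>& strides) {
  std::vector<int> out_shape, out_strides;
  if (!shape.empty()) {
    out_shape.push_back(shape[0]);
    out_strides.push_back(strides[0]);
    for (size_t i = 1; i < shape.size(); ++i) {
      if (strides[i] * shape[i] == out_strides.back()) {
        // Axis i is contiguous with the axis before it: merge them.
        out_shape.back() *= shape[i];
        out_strides.back() = strides[i];
      } else {
        out_shape.push_back(shape[i]);
        out_strides.push_back(strides[i]);
      }
    }
  }
  return {out_shape, out_strides};
}

int main() {
  auto [s1, st1] = collapse({2, 3, 4}, {12, 4, 1});
  std::printf("%zu axis: %d / %d\n", s1.size(), s1[0], st1[0]); // 1 axis: 24 / 1
  auto [s2, st2] = collapse({2, 3, 4}, {4, 8, 1}); // middle axis transposed
  std::printf("%zu axes\n", s2.size()); // stays 3 axes
}
```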
std::vector to_collapse; @@ -21,7 +21,7 @@ collapse_contiguous_dims_impl( for (int i = 1; i < shape.size(); i++) { bool contiguous = true; size *= shape[i]; - for (const std::vector& st : strides) { + for (const std::vector& st : strides) { if (st[i] * shape[i] != st[i - 1] || size > size_cap) { contiguous = false; size = shape[i]; @@ -39,7 +39,7 @@ collapse_contiguous_dims_impl( } std::vector out_shape; - std::vector> out_strides(strides.size()); + std::vector> out_strides(strides.size()); for (int i = 0;;) { while (i < to_collapse.size() && to_collapse[i] == -1) { ++i; @@ -54,7 +54,7 @@ collapse_contiguous_dims_impl( } out_shape.push_back(current_shape); for (int j = 0; j < strides.size(); j++) { - const std::vector& st = strides[j]; + const std::vector& st = strides[j]; out_strides[j].push_back(st[to_collapse[k - 1]]); } i = k + 1; @@ -85,4 +85,54 @@ collapse_contiguous_dims( return collapse_contiguous_dims_impl(shape, strides, size_cap); } +template +std::pair, std::vector> collapse_contiguous_dims_impl( + const std::vector& shape, + const std::vector& strides, + StrideT size_cap) { + std::vector collapsed_shape; + std::vector collapsed_strides; + + if (shape.size() > 0) { + collapsed_shape.push_back(shape[0]); + collapsed_strides.push_back(strides[0]); + for (int i = 1; i < shape.size(); i++) { + if (shape[i] == 1) { + continue; + } else if ( + strides[i] * shape[i] != collapsed_strides.back() || + collapsed_shape.back() * static_cast(shape[i]) > size_cap) { + collapsed_shape.push_back(shape[i]); + collapsed_strides.push_back(strides[i]); + } else { + collapsed_shape.back() *= shape[i]; + collapsed_strides.back() = strides[i]; + } + } + } + + return std::make_pair(collapsed_shape, collapsed_strides); +} + +std::pair, std::vector> collapse_contiguous_dims( + const std::vector& shape, + const std::vector& strides, + int64_t size_cap /* = std::numeric_limits::max() */) { + return collapse_contiguous_dims_impl(shape, strides, size_cap); +} + +std::pair, std::vector> collapse_contiguous_dims( + const std::vector& shape, + const std::vector& strides, + size_t size_cap /* = std::numeric_limits::max() */) { + return collapse_contiguous_dims_impl(shape, strides, size_cap); +} + +std::pair, std::vector> collapse_contiguous_dims( + const array& a, + size_t size_cap /* = std::numeric_limits::max()*/) { + return collapse_contiguous_dims_impl( + a.shape(), a.strides(), size_cap); +} + } // namespace mlx::core diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 6f57fe11b..b037c309f 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -8,12 +8,12 @@ namespace mlx::core { -template -inline stride_t elem_to_loc( +template +inline StrideT elem_to_loc( int elem, const std::vector& shape, - const std::vector& strides) { - stride_t loc = 0; + const std::vector& strides) { + StrideT loc = 0; for (int i = shape.size() - 1; i >= 0; --i) { auto q_and_r = ldiv(elem, shape[i]); loc += q_and_r.rem * strides[i]; @@ -29,9 +29,9 @@ inline size_t elem_to_loc(int elem, const array& a) { return elem_to_loc(elem, a.shape(), a.strides()); } -template -std::vector make_contiguous_strides(const std::vector& shape) { - std::vector strides(shape.size(), 1); +template +std::vector make_contiguous_strides(const std::vector& shape) { + std::vector strides(shape.size(), 1); for (int i = shape.size() - 1; i > 0; i--) { strides[i - 1] = strides[i] * shape[i]; } @@ -58,7 +58,7 @@ collapse_contiguous_dims( inline std::tuple, std::vector>> collapse_contiguous_dims( const std::vector& xs, 
- size_t size_cap = std::numeric_limits::max()) { + size_t size_cap = std::numeric_limits::max()) { std::vector> strides; for (auto& x : xs) { strides.emplace_back(x.strides()); @@ -73,36 +73,55 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) { } // The single array version of the above. -inline std::tuple, std::vector> -collapse_contiguous_dims( +std::pair, std::vector> collapse_contiguous_dims( const std::vector& shape, - const std::vector& strides) { - std::vector collapsed_shape; - std::vector collapsed_strides; + const std::vector& strides, + int64_t size_cap = std::numeric_limits::max()); +std::pair, std::vector> collapse_contiguous_dims( + const std::vector& shape, + const std::vector& strides, + size_t size_cap = std::numeric_limits::max()); +std::pair, std::vector> collapse_contiguous_dims( + const array& a, + size_t size_cap = std::numeric_limits::max()); - if (shape.size() > 0) { - collapsed_shape.push_back(shape[0]); - collapsed_strides.push_back(strides[0]); - for (int i = 1; i < shape.size(); i++) { - if (strides[i] * shape[i] != collapsed_strides.back() || - collapsed_shape.back() * static_cast(shape[i]) > - std::numeric_limits::max()) { - collapsed_shape.push_back(shape[i]); - collapsed_strides.push_back(strides[i]); - } else { - collapsed_shape.back() *= shape[i]; - collapsed_strides.back() = strides[i]; - } +template +struct ContiguousIterator { + inline void step() { + int i = dims_; + while (pos_[i] == (shape_[i] - 1) && i > 0) { + pos_[i] = 0; + loc -= (shape_[i] - 1) * strides_[i]; + i--; } + pos_[i]++; + loc += strides_[i]; } - return std::make_tuple(collapsed_shape, collapsed_strides); -} + explicit ContiguousIterator( + const std::vector& shape, + const std::vector& strides, + int dims) + : shape_(shape.begin(), shape.begin() + dims), + strides_(strides.begin(), strides.begin() + dims) { + std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_); + dims_ = shape_.size() - 1; + pos_ = std::vector(dims_ + 1, 0); + } -template + StrideT loc{0}; + + private: + std::vector shape_; + std::vector strides_; + std::vector pos_; + int dims_; +}; + +template inline auto check_contiguity( const std::vector& shape, - const std::vector& strides) { + const std::vector& strides) { size_t no_broadcast_data_size = 1; size_t f_stride = 1; size_t b_stride = 1; diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index 248fb526c..101810628 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -73,11 +73,7 @@ void binary_op_gpu_inplace( // Try to collapse contiguous dims auto maybe_collapse = [bopt, &a, &b, &out]() { if (bopt == BinaryOpType::General) { - // The size cap here should ideally be `UINT32_MAX` but we are - // limitied by the shape being an int. - auto [shape, strides] = collapse_contiguous_dims( - {a, b, out}, - /* size_cap = */ INT32_MAX); + auto [shape, strides] = collapse_contiguous_dims(a, b, out); return std::make_tuple(shape, strides[0], strides[1], strides[2]); } else { std::vector e; diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index c70b5e969..357065bdc 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -26,11 +26,7 @@ void ternary_op_gpu_inplace( // Try to collapse contiguous dims auto maybe_collapse = [topt, &a, &b, &c, &out]() { if (topt == TernaryOpType::General) { - // The size cap here should ideally be `UINT32_MAX` but we are - // limitied by the shape being an int. 
-      auto [shape, strides] = collapse_contiguous_dims(
-          {a, b, c, out},
-          /* size_cap = */ INT32_MAX);
+      auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
       return std::make_tuple(
           shape, strides[0], strides[1], strides[2], strides[3]);
     } else {
diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp
index 666739d3a..eb4af03ec 100644
--- a/mlx/backend/metal/unary.cpp
+++ b/mlx/backend/metal/unary.cpp
@@ -28,10 +28,7 @@ void unary_op_gpu_inplace(
   auto maybe_collapse = [contig, &in, &out]() {
     if (!contig) {
-      auto [shape, strides] = collapse_contiguous_dims(
-          {in, out},
-          /* size_cap = */ INT32_MAX);
-      return std::make_pair(shape, strides[0]);
+      return collapse_contiguous_dims(in);
     } else {
       return std::make_pair(std::vector<int>{}, std::vector<size_t>{});
     }
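
Note on `ContiguousIterator` (added to `utils.h` above and used by the binary, ternary, and copy dispatch paths): it is an odometer over only the outer axes. `loc` holds the flat offset of the current outer position, `step()` advances it, and the caller handles the remaining inner axes with one of the fixed-depth loops. A minimal sketch under simplifying assumptions (`int` strides, no dimension collapsing inside the iterator, hypothetical name `Iter`):

```cpp
// Sketch of the ContiguousIterator idea: odometer over the first `dims` axes.
#include <cstdio>
#include <vector>

struct Iter {
  Iter(const std::vector<int>& shape, const std::vector<int>& strides, int dims)
      : shape_(shape.begin(), shape.begin() + dims),
        strides_(strides.begin(), strides.begin() + dims),
        pos_(dims, 0) {}

  void step() {
    int i = static_cast<int>(shape_.size()) - 1;
    // Carry over any axis that has reached its end, undoing its offset.
    while (i > 0 && pos_[i] == shape_[i] - 1) {
      pos_[i] = 0;
      loc -= (shape_[i] - 1) * strides_[i];
      --i;
    }
    ++pos_[i];
    loc += strides_[i];
  }

  int loc{0};

 private:
  std::vector<int> shape_, strides_, pos_;
};

int main() {
  // A 2x3 row-major buffer viewed transposed as shape {3, 2}, strides {1, 3}.
  int src[6] = {0, 1, 2, 3, 4, 5};
  std::vector<int> shape{3, 2}, strides{1, 3};
  Iter it(shape, strides, /*dims=*/1); // iterate the outer axis only
  for (int outer = 0; outer < shape[0]; ++outer) {
    for (int inner = 0; inner < shape[1]; ++inner) {
      std::printf("%d ", src[it.loc + inner * strides[1]]);
    }
    it.step();
  }
  // Prints the transposed order: 0 3 1 4 2 5
}
```

In the patch, the same pattern appears with an iterator per strided input and a plain `elem += stride` walk for the contiguous output.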