// Copyright © 2023 Apple Inc.

#pragma once

#include <algorithm>
#include <cassert>
#include <type_traits>

#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"

namespace mlx::core {

namespace {

enum class BinaryOpType {
  ScalarScalar,
  ScalarVector,
  VectorScalar,
  VectorVector,
  General,
};

BinaryOpType get_binary_op_type(const array& a, const array& b) {
  BinaryOpType bopt;
  if (a.data_size() == 1 && b.data_size() == 1) {
    bopt = BinaryOpType::ScalarScalar;
  } else if (a.data_size() == 1 && b.flags().contiguous) {
    bopt = BinaryOpType::ScalarVector;
  } else if (b.data_size() == 1 && a.flags().contiguous) {
    bopt = BinaryOpType::VectorScalar;
  } else if (
      (a.flags().row_contiguous && b.flags().row_contiguous) ||
      (a.flags().col_contiguous && b.flags().col_contiguous)) {
    bopt = BinaryOpType::VectorVector;
  } else {
    bopt = BinaryOpType::General;
  }
  return bopt;
}

void set_binary_op_output_data(
    const array& a,
    const array& b,
    array& out,
    BinaryOpType bopt,
    bool donate_with_move = false) {
  switch (bopt) {
    case BinaryOpType::ScalarScalar:
      out.set_data(
          allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
      break;
    case BinaryOpType::ScalarVector:
      if (b.is_donatable() && b.itemsize() == out.itemsize()) {
        if (donate_with_move) {
          out.move_shared_buffer(b);
        } else {
          out.copy_shared_buffer(b);
        }
      } else {
        out.set_data(
            allocator::malloc_or_wait(b.data_size() * out.itemsize()),
            b.data_size(),
            b.strides(),
            b.flags());
      }
      break;
    case BinaryOpType::VectorScalar:
      if (a.is_donatable() && a.itemsize() == out.itemsize()) {
        if (donate_with_move) {
          out.move_shared_buffer(a);
        } else {
          out.copy_shared_buffer(a);
        }
      } else {
        out.set_data(
            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
            a.data_size(),
            a.strides(),
            a.flags());
      }
      break;
    case BinaryOpType::VectorVector:
      if (a.is_donatable() && a.itemsize() == out.itemsize()) {
        if (donate_with_move) {
          out.move_shared_buffer(a);
        } else {
          out.copy_shared_buffer(a);
        }
      } else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
        if (donate_with_move) {
          out.move_shared_buffer(b);
        } else {
          out.copy_shared_buffer(b);
        }
      } else {
        out.set_data(
            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
            a.data_size(),
            a.strides(),
            a.flags());
      }
      break;
    case BinaryOpType::General:
      if (a.is_donatable() && a.flags().row_contiguous &&
          a.itemsize() == out.itemsize() && a.size() == out.size()) {
        if (donate_with_move) {
          out.move_shared_buffer(a);
        } else {
          out.copy_shared_buffer(a);
        }
      } else if (
          b.is_donatable() && b.flags().row_contiguous &&
          b.itemsize() == out.itemsize() && b.size() == out.size()) {
        if (donate_with_move) {
          out.move_shared_buffer(b);
        } else {
          out.copy_shared_buffer(b);
        }
      } else {
        out.set_data(allocator::malloc_or_wait(out.nbytes()));
      }
      break;
  }
}

struct UseDefaultBinaryOp {
  template <typename T, typename U>
  void operator()(const T* a, const T* b, U* dst, int size) {
    // Should we throw? This should normally never be called.
    assert(false);
  }

  template <typename T, typename U>
  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
    // Should we throw? This should normally never be called.
    assert(false);
  }
};
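// A rough usage sketch (assumed names, not part of this header): backends
// pass UseDefaultBinaryOp as a placeholder when they have no specialized
// kernel for a given contiguity case, and the dispatch below substitutes the
// Default* loops. For classification, assuming `ones` builds a row-contiguous
// array:
//
//   array s(2.0f);              // data_size() == 1
//   array v = ones({4, 8});     // row contiguous
//   get_binary_op_type(s, v);   // -> BinaryOpType::ScalarVector
//   get_binary_op_type(v, v);   // -> BinaryOpType::VectorVector
//
// Any strided or broadcast combination not covered above (e.g. a transposed
// operand) falls through to BinaryOpType::General.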
template <typename T, typename U, typename Op>
struct DefaultVectorScalar {
  Op op;

  DefaultVectorScalar(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *b;
    while (size-- > 0) {
      *dst = op(*a, scalar);
      dst++;
      a++;
    }
  }

  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
    T scalar = *b;
    while (size-- > 0) {
      auto dst = op(*a, scalar);
      *dst_a = dst.first;
      *dst_b = dst.second;
      dst_a++;
      dst_b++;
      a++;
    }
  }
};

template <typename T, typename U, typename Op>
struct DefaultScalarVector {
  Op op;

  DefaultScalarVector(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *a;
    while (size-- > 0) {
      *dst = op(scalar, *b);
      dst++;
      b++;
    }
  }

  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
    T scalar = *a;
    while (size-- > 0) {
      auto dst = op(scalar, *b);
      *dst_a = dst.first;
      *dst_b = dst.second;
      dst_a++;
      dst_b++;
      b++;
    }
  }
};

template <typename T, typename U, typename Op>
struct DefaultVectorVector {
  Op op;

  DefaultVectorVector(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    while (size-- > 0) {
      *dst = op(*a, *b);
      dst++;
      a++;
      b++;
    }
  }

  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
    while (size-- > 0) {
      auto dst = op(*a, *b);
      *dst_a = dst.first;
      *dst_b = dst.second;
      dst_a++;
      dst_b++;
      a++;
      b++;
    }
  }
};

template <typename T, typename U, typename Op>
void binary_op_dims1(const array& a, const array& b, array& out, Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  for (size_t i = 0; i < out.size(); ++i) {
    dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
    a_idx += a.strides()[0];
    b_idx += b.strides()[0];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims1(
    const array& a,
    const array& b,
    array& out,
    Op op,
    int stride) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; i++) {
    op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
    a_idx += a.strides()[0];
    b_idx += b.strides()[0];
    dst += stride;
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims2(const array& a, const array& b, array& out, Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
      a_idx += a.strides()[1];
      b_idx += b.strides()[1];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims2(
    const array& a,
    const array& b,
    array& out,
    Op op,
    int stride) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
      a_idx += a.strides()[1];
      b_idx += b.strides()[1];
      dst += stride;
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims3(const array& a, const array& b, array& out, Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      for (size_t k = 0; k < a.shape()[2]; ++k) {
        dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
        a_idx += a.strides()[2];
        b_idx += b.strides()[2];
      }
      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}
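// Worked example (sketch) of the stride walking above: take `a` with shape
// (2, 3) broadcast from (1, 3), so a.strides() == {0, 1}. In binary_op_dims2
// the inner loop advances a_idx by 1 three times, then the outer correction
// a.strides()[0] - a.strides()[1] * a.shape()[1] == 0 - 1 * 3 rewinds a_idx
// to 0, so the same row of `a` is re-read for every row of the output.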
template <typename T, typename U, typename Op>
void binary_op_dims4(const array& a, const array& b, array& out, Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      for (size_t k = 0; k < a.shape()[2]; ++k) {
        for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
          dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
          a_idx += a.strides()[3];
          b_idx += b.strides()[3];
        }
        a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
        b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
      }
      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
    const array& a,
    const array& b,
    array& out,
    Op op) {
  switch (out.ndim()) {
    case 1:
      binary_op_dims1<T, U, Op>(a, b, out, op);
      return;
    case 2:
      binary_op_dims2<T, U, Op>(a, b, out, op);
      return;
    case 3:
      binary_op_dims3<T, U, Op>(a, b, out, op);
      return;
    case 4:
      binary_op_dims4<T, U, Op>(a, b, out, op);
      return;
  }

  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  for (size_t i = 0; i < out.size(); i++) {
    int a_idx = elem_to_loc(i, a.shape(), a.strides());
    int b_idx = elem_to_loc(i, b.shape(), b.strides());
    dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
  }
}

template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
    const array& a,
    const array& b,
    array& out,
    Op op,
    int dim,
    int stride) {
  // Number of dimensions to loop over for vectorized ops
  switch (dim) {
    case 1:
      binary_op_dims1<T, U, Op>(a, b, out, op, stride);
      return;
    case 2:
      binary_op_dims2<T, U, Op>(a, b, out, op, stride);
      return;
  }

  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  for (size_t i = 0; i < out.size(); i += stride) {
    int a_idx = elem_to_loc(i, a.shape(), a.strides());
    int b_idx = elem_to_loc(i, b.shape(), b.strides());
    op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
    dst += stride;
  }
}
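// For reference, a sketch of what elem_to_loc (declared in
// mlx/backend/common/utils.h) computes in the generic fallbacks above: it
// converts a flat row-major element index into a strided memory offset.
// Roughly:
//
//   size_t elem_to_loc(int elem, const std::vector<int>& shape,
//                      const std::vector<size_t>& strides) {
//     size_t loc = 0;
//     for (int i = shape.size() - 1; i >= 0; --i) {
//       loc += (elem % shape[i]) * strides[i];
//       elem /= shape[i];
//     }
//     return loc;
//   }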
template <
    typename T,
    typename U,
    typename Op,
    typename OpSV,
    typename OpVS,
    typename OpVV>
void binary_op(
    const array& a,
    const array& b,
    array& out,
    Op op,
    OpSV opsv,
    OpVS opvs,
    OpVV opvv) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);

  // The full computation is scalar scalar so call the base op once
  if (bopt == BinaryOpType::ScalarScalar) {
    *(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
    return;
  }

  // The full computation is scalar vector so delegate to the op
  if (bopt == BinaryOpType::ScalarVector) {
    opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
    return;
  }

  // The full computation is vector scalar so delegate to the op
  if (bopt == BinaryOpType::VectorScalar) {
    opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
    return;
  }

  // The full computation is vector vector so delegate to the op
  if (bopt == BinaryOpType::VectorVector) {
    opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
    return;
  }

  // General computation so let's try to optimize

  // Get the left-most dim such that the array is row contiguous after
  auto& strides = out.strides();
  auto leftmost_rc_dim = [&strides](const array& arr) {
    int d = arr.ndim() - 1;
    for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
    }
    return d + 1;
  };
  auto a_rc_dim = leftmost_rc_dim(a);
  auto b_rc_dim = leftmost_rc_dim(b);

  // Get the left-most dim such that the array is a broadcasted "scalar" after
  auto leftmost_s_dim = [](const array& arr) {
    int d = arr.ndim() - 1;
    for (; d >= 0 && arr.strides()[d] == 0; d--) {
    }
    return d + 1;
  };
  auto a_s_dim = leftmost_s_dim(a);
  auto b_s_dim = leftmost_s_dim(b);

  auto ndim = out.ndim();

  // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
  int dim = ndim;
  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
    bopt = BinaryOpType::VectorVector;
    dim = d;
    // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
    bopt = BinaryOpType::VectorScalar;
    dim = d;
    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
    bopt = BinaryOpType::ScalarVector;
    dim = d;
  }

  // Can be sure dim > 0 since otherwise we would have used one of the fully
  // contiguous methods above. Except for the case that the flags do not
  // correspond to the underlying contiguity.
  size_t stride;
  if (dim == 0 || strides[dim - 1] < 16) {
    stride = 1;
    bopt = BinaryOpType::General;
    dim = ndim;
  } else {
    stride = strides[dim - 1];
  }

  switch (bopt) {
    case BinaryOpType::VectorVector:
      binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
      break;
    case BinaryOpType::VectorScalar:
      binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
      break;
    case BinaryOpType::ScalarVector:
      binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
      break;
    default:
      binary_op_dispatch_dims<T, U>(a, b, out, op);
      break;
  }
}

template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
void binary_op(
    const array& a,
    const array& b,
    array& out,
    Op op,
    OpSV opsv,
    OpVS opvs,
    OpVV opvv) {
  // TODO: The following mess of constexpr evaluations can probably be achieved
  //       with template specializations and overloading. Would it be simpler?

  if (std::is_same<OpSV, UseDefaultBinaryOp>::value) {
    if (std::is_same<OpVS, UseDefaultBinaryOp>::value) {
      if (std::is_same<OpVV, UseDefaultBinaryOp>::value) {
        // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
        binary_op<T, T>(
            a,
            b,
            out,
            op,
            DefaultScalarVector<T, T, Op>(op),
            DefaultVectorScalar<T, T, Op>(op),
            DefaultVectorVector<T, T, Op>(op));
      } else {
        // opsv and opvs were UseDefaultBinaryOp
        binary_op<T, T>(
            a,
            b,
            out,
            op,
            DefaultScalarVector<T, T, Op>(op),
            DefaultVectorScalar<T, T, Op>(op),
            opvv);
      }
    } else if (std::is_same<OpVV, UseDefaultBinaryOp>::value) {
      // opsv and opvv were UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          out,
          op,
          DefaultScalarVector<T, T, Op>(op),
          opvs,
          DefaultVectorVector<T, T, Op>(op));
    } else {
      // opsv was UseDefaultBinaryOp
      binary_op<T, T>(
          a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
    }
  } else if (std::is_same<OpVS, UseDefaultBinaryOp>::value) {
    if (std::is_same<OpVV, UseDefaultBinaryOp>::value) {
      // opvs and opvv were UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          out,
          op,
          opsv,
          DefaultVectorScalar<T, T, Op>(op),
          DefaultVectorVector<T, T, Op>(op));
    } else {
      // opvs was UseDefaultBinaryOp
      binary_op<T, T>(
          a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
    }
  } else if (std::is_same<OpVV, UseDefaultBinaryOp>::value) {
    // opvv was UseDefaultBinaryOp
    binary_op<T, T>(
        a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
  } else {
    // All ops provided
    binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
  }
}

template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
  DefaultScalarVector<T, T, Op> opsv(op);
  DefaultVectorScalar<T, T, Op> opvs(op);
  DefaultVectorVector<T, T, Op> opvv(op);
  binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
}
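// Usage sketch (assumed call site): a backend with a fast vector-vector
// kernel for one dtype can pass it explicitly and leave the remaining cases
// to the defaults via the UseDefaultBinaryOp placeholder. The kernel name
// below is hypothetical:
//
//   binary_op<float>(
//       a,
//       b,
//       out,
//       [](auto x, auto y) { return x + y; },
//       UseDefaultBinaryOp(),
//       UseDefaultBinaryOp(),
//       my_fast_float_add_vv);  // hypothetical vectorized kernel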
template <typename... Ops>
void binary(const array& a, const array& b, array& out, Ops... ops) {
  switch (out.dtype()) {
    case bool_:
      binary_op<bool>(a, b, out, ops...);
      break;
    case uint8:
      binary_op<uint8_t>(a, b, out, ops...);
      break;
    case uint16:
      binary_op<uint16_t>(a, b, out, ops...);
      break;
    case uint32:
      binary_op<uint32_t>(a, b, out, ops...);
      break;
    case uint64:
      binary_op<uint64_t>(a, b, out, ops...);
      break;
    case int8:
      binary_op<int8_t>(a, b, out, ops...);
      break;
    case int16:
      binary_op<int16_t>(a, b, out, ops...);
      break;
    case int32:
      binary_op<int32_t>(a, b, out, ops...);
      break;
    case int64:
      binary_op<int64_t>(a, b, out, ops...);
      break;
    case float16:
      binary_op<float16_t>(a, b, out, ops...);
      break;
    case float32:
      binary_op<float>(a, b, out, ops...);
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, out, ops...);
      break;
    case complex64:
      binary_op<complex64_t>(a, b, out, ops...);
      break;
  }
}

} // namespace

} // namespace mlx::core
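// Example primitive implementation (sketch; the backend sources, e.g.
// mlx/backend/common/binary.cpp, are the authoritative reference for how
// binary() is driven):
//
//   void Add::eval(const std::vector<array>& inputs, array& out) {
//     auto& a = inputs[0];
//     auto& b = inputs[1];
//     binary(a, b, out, [](auto x, auto y) { return x + y; });
//   }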