diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt
index 6f265b1ad..5a9fcc5ba 100644
--- a/mlx/backend/common/CMakeLists.txt
+++ b/mlx/backend/common/CMakeLists.txt
@@ -48,6 +48,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp
index e89ce7d6c..08dc34037 100644
--- a/mlx/backend/common/common.cpp
+++ b/mlx/backend/common/common.cpp
@@ -250,49 +250,6 @@ void Split::eval(
   }
 }

-std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
-    const array& in) {
-  int64_t data_offset = 0;
-  bool copy_needed = false;
-  std::vector<int64_t> inp_strides(in.ndim(), 0);
-  for (int i = 0; i < in.ndim(); ++i) {
-    data_offset += start_indices_[i] * in.strides()[i];
-    inp_strides[i] = in.strides()[i] * strides_[i];
-
-    copy_needed |= strides_[i] < 0;
-  }
-
-  return std::make_tuple(copy_needed, data_offset, inp_strides);
-}
-
-void Slice::shared_buffer_slice(
-    const array& in,
-    const std::vector<size_t>& out_strides,
-    size_t data_offset,
-    array& out) {
-  // Compute row/col contiguity
-  auto [data_size, is_row_contiguous, is_col_contiguous] =
-      check_contiguity(out.shape(), out_strides);
-
-  auto flags = in.flags();
-  flags.row_contiguous = is_row_contiguous;
-  flags.col_contiguous = is_col_contiguous;
-
-  if (data_size == 1) {
-    // Broadcasted scalar array is contiguous.
-    flags.contiguous = true;
-  } else if (data_size == in.data_size()) {
-    // Means we sliced a broadcasted dimension so leave the "no holes" flag
-    // alone.
-  } else {
-    // We sliced something. So either we are row or col contiguous or we
-    // punched a hole.
-    flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
-  }
-
-  out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
-}
-
 std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
     const array& in) {
   int64_t data_offset = 0;
diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp
index 29c4dcfd0..d6b667c8d 100644
--- a/mlx/backend/common/primitives.cpp
+++ b/mlx/backend/common/primitives.cpp
@@ -11,6 +11,7 @@
 #include "mlx/backend/common/binary.h"
 #include "mlx/backend/common/copy.h"
 #include "mlx/backend/common/ops.h"
+#include "mlx/backend/common/slicing.h"
 #include "mlx/backend/common/threefry.h"
 #include "mlx/backend/common/unary.h"
 #include "mlx/backend/common/utils.h"
@@ -492,7 +493,8 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];

   // Calculate out strides, initial offset and if copy needs to be made
-  auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
+  auto [copy_needed, data_offset, inp_strides] =
+      prepare_slice(in, start_indices_, strides_);

   // Do copy if needed
   if (copy_needed) {
diff --git a/mlx/backend/common/slicing.cpp b/mlx/backend/common/slicing.cpp
new file mode 100644
index 000000000..3425179e0
--- /dev/null
+++ b/mlx/backend/common/slicing.cpp
@@ -0,0 +1,52 @@
+// Copyright © 2024 Apple Inc.
+ +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +std::tuple> prepare_slice( + const array& in, + std::vector& start_indices, + std::vector& strides) { + int64_t data_offset = 0; + bool copy_needed = false; + std::vector inp_strides(in.ndim(), 0); + for (int i = 0; i < in.ndim(); ++i) { + data_offset += start_indices[i] * in.strides()[i]; + inp_strides[i] = in.strides()[i] * strides[i]; + + copy_needed |= strides[i] < 0; + } + + return std::make_tuple(copy_needed, data_offset, inp_strides); +} + +void shared_buffer_slice( + const array& in, + const std::vector& out_strides, + size_t data_offset, + array& out) { + // Compute row/col contiguity + auto [data_size, is_row_contiguous, is_col_contiguous] = + check_contiguity(out.shape(), out_strides); + + auto flags = in.flags(); + flags.row_contiguous = is_row_contiguous; + flags.col_contiguous = is_col_contiguous; + + if (data_size == 1) { + // Broadcasted scalar array is contiguous. + flags.contiguous = true; + } else if (data_size == in.data_size()) { + // Means we sliced a broadcasted dimension so leave the "no holes" flag + // alone. + } else { + // We sliced something. So either we are row or col contiguous or we + // punched a hole. + flags.contiguous &= flags.row_contiguous || flags.col_contiguous; + } + + out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset); +} + +} // namespace mlx::core diff --git a/mlx/backend/common/slicing.h b/mlx/backend/common/slicing.h new file mode 100644 index 000000000..842793783 --- /dev/null +++ b/mlx/backend/common/slicing.h @@ -0,0 +1,20 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +std::tuple> prepare_slice( + const array& in, + std::vector& start_indices, + std::vector& strides); + +void shared_buffer_slice( + const array& in, + const std::vector& out_strides, + size_t data_offset, + array& out); + +} // namespace mlx::core diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 671b45dca..c978fe5e5 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -135,6 +135,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index 4b1e0f01d..06cf734f0 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -10,16 +10,14 @@ namespace mlx::core { constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5; -void binary_op( +void binary_op_gpu_inplace( const std::vector& inputs, std::vector& outputs, - const std::string op) { - assert(inputs.size() == 2); + const std::string op, + const Stream& s) { auto& a = inputs[0]; auto& b = inputs[1]; auto bopt = get_binary_op_type(a, b); - set_binary_op_output_data(a, b, outputs[0], bopt, true); - set_binary_op_output_data(a, b, outputs[1], bopt, true); auto& out = outputs[0]; if (out.size() == 0) { @@ -61,7 +59,6 @@ void binary_op( kernel_name = kname.str(); } - auto& s = out.primitive().stream(); auto& d = metal::device(s.device); auto kernel = get_binary_two_kernel(d, kernel_name, a, outputs[0]); @@ -120,15 +117,36 @@ void binary_op( } } -void binary_op( +void binary_op_gpu( const std::vector& inputs, - array& out, - const std::string op) { + std::vector& outputs, + const std::string op, + 
+    const Stream& s) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
   auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, out, bopt, true);
+  set_binary_op_output_data(a, b, outputs[0], bopt, true);
+  set_binary_op_output_data(a, b, outputs[1], bopt, true);
+  binary_op_gpu_inplace(inputs, outputs, op, s);
+}
+
+void binary_op_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs,
+    const std::string op) {
+  auto& s = outputs[0].primitive().stream();
+  binary_op_gpu(inputs, outputs, op, s);
+}
+
+void binary_op_gpu_inplace(
+    const std::vector<array>& inputs,
+    array& out,
+    const std::string op,
+    const Stream& s) {
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto bopt = get_binary_op_type(a, b);
   if (out.size() == 0) {
     return;
   }
@@ -168,7 +186,6 @@
     kernel_name = kname.str();
   }

-  auto& s = out.primitive().stream();
   auto& d = metal::device(s.device);

   auto kernel = get_binary_kernel(d, kernel_name, a, out);
@@ -221,102 +238,123 @@
   }
 }

+void binary_op_gpu(
+    const std::vector<array>& inputs,
+    array& out,
+    const std::string op,
+    const Stream& s) {
+  assert(inputs.size() == 2);
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, out, bopt, true);
+  binary_op_gpu_inplace(inputs, out, op, s);
+}
+
+void binary_op_gpu(
+    const std::vector<array>& inputs,
+    array& out,
+    const std::string op) {
+  auto& s = out.primitive().stream();
+  binary_op_gpu(inputs, out, op, s);
+}
+
 void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "add");
+  binary_op_gpu(inputs, out, "add");
 }

 void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "arctan2");
+  binary_op_gpu(inputs, out, "arctan2");
 }

 void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
   switch (op_) {
     case BitwiseBinary::And:
-      binary_op(inputs, out, "bitwise_and");
+      binary_op_gpu(inputs, out, "bitwise_and");
       break;
     case BitwiseBinary::Or:
-      binary_op(inputs, out, "bitwise_or");
+      binary_op_gpu(inputs, out, "bitwise_or");
       break;
     case BitwiseBinary::Xor:
-      binary_op(inputs, out, "bitwise_xor");
+      binary_op_gpu(inputs, out, "bitwise_xor");
       break;
     case BitwiseBinary::LeftShift:
-      binary_op(inputs, out, "left_shift");
+      binary_op_gpu(inputs, out, "left_shift");
       break;
     case BitwiseBinary::RightShift:
-      binary_op(inputs, out, "right_shift");
+      binary_op_gpu(inputs, out, "right_shift");
       break;
   }
 }

 void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "div");
+  binary_op_gpu(inputs, out, "div");
 }

 void DivMod::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  binary_op(inputs, outputs, "divmod");
+  binary_op_gpu(inputs, outputs, "divmod");
 }

 void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "rem");
+  binary_op_gpu(inputs, out, "rem");
 }

 void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
+  binary_op_gpu(inputs, out, equal_nan_ ? "naneq" : "eq");
"naneq" : "eq"); } void Greater::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "ge"); + binary_op_gpu(inputs, out, "ge"); } void GreaterEqual::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "geq"); + binary_op_gpu(inputs, out, "geq"); } void Less::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "le"); + binary_op_gpu(inputs, out, "le"); } void LessEqual::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "leq"); + binary_op_gpu(inputs, out, "leq"); } void LogicalAnd::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "land"); + binary_op_gpu(inputs, out, "land"); } void LogicalOr::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "lor"); + binary_op_gpu(inputs, out, "lor"); } void LogAddExp::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "lae"); + binary_op_gpu(inputs, out, "lae"); } void Maximum::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "max"); + binary_op_gpu(inputs, out, "max"); } void Minimum::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "min"); + binary_op_gpu(inputs, out, "min"); } void Multiply::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "mul"); + binary_op_gpu(inputs, out, "mul"); } void NotEqual::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "neq"); + binary_op_gpu(inputs, out, "neq"); } void Power::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "pow"); + binary_op_gpu(inputs, out, "pow"); } void Subtract::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "sub"); + binary_op_gpu(inputs, out, "sub"); } } // namespace mlx::core diff --git a/mlx/backend/metal/binary.h b/mlx/backend/metal/binary.h new file mode 100644 index 000000000..40d414291 --- /dev/null +++ b/mlx/backend/metal/binary.h @@ -0,0 +1,33 @@ +// Copyright © 2024 Apple Inc. 
+
+#pragma once
+
+#include "mlx/array.h"
+
+namespace mlx::core {
+
+void binary_op_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs,
+    const std::string op,
+    const Stream& s);
+
+void binary_op_gpu(
+    const std::vector<array>& inputs,
+    array& out,
+    const std::string op,
+    const Stream& s);
+
+void binary_op_gpu_inplace(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs,
+    const std::string op,
+    const Stream& s);
+
+void binary_op_gpu_inplace(
+    const std::vector<array>& inputs,
+    array& out,
+    const std::string op,
+    const Stream& s);
+
+} // namespace mlx::core
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 3e249a1d4..772b8bc17 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -7,6 +7,7 @@
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
+#include "mlx/backend/metal/slicing.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
@@ -163,30 +164,7 @@ void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
 }

 void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
-  std::vector<int> sizes;
-  sizes.push_back(0);
-  for (auto& p : inputs) {
-    sizes.push_back(p.shape(axis_));
-  }
-  std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
-
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
-  auto strides = out.strides();
-  auto flags = out.flags();
-  flags.row_contiguous = false;
-  flags.col_contiguous = false;
-  flags.contiguous = false;
-  auto& d = metal::device(stream().device);
-  auto& compute_encoder = d.get_command_encoder(stream().index);
-  auto concurrent_ctx = compute_encoder.start_concurrent();
-  for (int i = 0; i < inputs.size(); i++) {
-    array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
-    size_t data_offset = strides[axis_] * sizes[i];
-    out_slice.copy_shared_buffer(
-        out, strides, flags, out_slice.size(), data_offset);
-    copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
-  }
+  concatenate_gpu(inputs, out, axis_, stream());
 }

 void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -238,23 +216,7 @@ void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Padding value, input and output must be of the same type
   assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());

-  // Fill output with val
-  copy_gpu(val, out, CopyType::Scalar, stream());
-
-  // Find offset for start of input values
-  size_t data_offset = 0;
-  for (int i = 0; i < axes_.size(); i++) {
-    auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
-    data_offset += out.strides()[ax] * low_pad_size_[i];
-  }
-
-  // Extract slice from output where input will be pasted
-  array out_slice(in.shape(), out.dtype(), nullptr, {});
-  out_slice.copy_shared_buffer(
-      out, out.strides(), out.flags(), out_slice.size(), data_offset);
-
-  // Copy input values into the slice
-  copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
+  pad_gpu(in, val, out, axes_, low_pad_size_, stream());
 }

 void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -331,28 +293,7 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
   }

   auto& in = inputs[0];
-
-  // Calculate out strides, initial offset and if copy needs to be made
-  auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
-
-  // Do copy if needed
-  if (copy_needed) {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
-    copy_gpu_inplace(
-        /* const array& in = */ in,
-        /* array& out = */ out,
-        /* const std::vector<int>& data_shape = */ out.shape(),
-        /* const std::vector<int64_t>& i_strides = */ inp_strides,
-        /* const std::vector<int64_t>& o_strides = */ ostrides,
-        /* int64_t i_offset = */ data_offset,
-        /* int64_t o_offset = */ 0,
-        /* CopyType ctype = */ CopyType::General,
-        /* const Stream& s = */ stream());
-  } else {
-    std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
-    shared_buffer_slice(in, ostrides, data_offset, out);
-  }
+  slice_gpu(in, out, start_indices_, strides_, stream());
 }

 void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp
new file mode 100644
index 000000000..1057f27b7
--- /dev/null
+++ b/mlx/backend/metal/slicing.cpp
@@ -0,0 +1,98 @@
+// Copyright © 2024 Apple Inc.
+
+#include <numeric>
+
+#include "mlx/backend/common/slicing.h"
+#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/metal/device.h"
+
+namespace mlx::core {
+
+void slice_gpu(
+    const array& in,
+    array& out,
+    std::vector<int> start_indices,
+    std::vector<int> strides,
+    const Stream& s) {
+  // Calculate out strides, initial offset and if copy needs to be made
+  auto [copy_needed, data_offset, inp_strides] =
+      prepare_slice(in, start_indices, strides);
+
+  // Do copy if needed
+  if (copy_needed) {
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
+    copy_gpu_inplace(
+        /* const array& in = */ in,
+        /* array& out = */ out,
+        /* const std::vector<int>& data_shape = */ out.shape(),
+        /* const std::vector<int64_t>& i_strides = */ inp_strides,
+        /* const std::vector<int64_t>& o_strides = */ ostrides,
+        /* int64_t i_offset = */ data_offset,
+        /* int64_t o_offset = */ 0,
+        /* CopyType ctype = */ CopyType::General,
+        /* const Stream& s = */ s);
+  } else {
+    std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
+    shared_buffer_slice(in, ostrides, data_offset, out);
+  }
+}
+
+void concatenate_gpu(
+    const std::vector<array>& inputs,
+    array& out,
+    int axis,
+    const Stream& s) {
+  std::vector<int> sizes;
+  sizes.push_back(0);
+  for (auto& p : inputs) {
+    sizes.push_back(p.shape(axis));
+  }
+  std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
+
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  auto strides = out.strides();
+  auto flags = out.flags();
+  flags.row_contiguous = false;
+  flags.col_contiguous = false;
+  flags.contiguous = false;
+  auto& d = metal::device(s.device);
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto concurrent_ctx = compute_encoder.start_concurrent();
+  for (int i = 0; i < inputs.size(); i++) {
+    array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
+    size_t data_offset = strides[axis] * sizes[i];
+    out_slice.copy_shared_buffer(
+        out, strides, flags, out_slice.size(), data_offset);
+    copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s);
+  }
+}
+
+void pad_gpu(
+    const array& in,
+    const array& val,
+    array& out,
+    std::vector<int> axes,
+    std::vector<int> low_pad_size,
+    const Stream& s) {
+  // Fill output with val
+  copy_gpu(val, out, CopyType::Scalar, s);
+
+  // Find offset for start of input values
+  size_t data_offset = 0;
+  for (int i = 0; i < axes.size(); i++) {
+    auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
+    data_offset += out.strides()[ax] * low_pad_size[i];
+  }
+
+  // Extract slice from output where input will be pasted
+  array out_slice(in.shape(), out.dtype(), nullptr, {});
+  out_slice.copy_shared_buffer(
+      out, out.strides(), out.flags(), out_slice.size(), data_offset);
+
+  // Copy input values into the slice
+  copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/metal/slicing.h b/mlx/backend/metal/slicing.h
new file mode 100644
index 000000000..cf81b7000
--- /dev/null
+++ b/mlx/backend/metal/slicing.h
@@ -0,0 +1,30 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/array.h"
+
+namespace mlx::core {
+
+void slice_gpu(
+    const array& in,
+    array& out,
+    std::vector<int> start_indices,
+    std::vector<int> strides,
+    const Stream& s);
+
+void concatenate_gpu(
+    const std::vector<array>& inputs,
+    array& out,
+    int axis,
+    const Stream& s);
+
+void pad_gpu(
+    const array& in,
+    const array& val,
+    array& out,
+    std::vector<int> axes,
+    std::vector<int> low_pad_size,
+    const Stream& s);
+
+} // namespace mlx::core
diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp
index ed2f4c14a..51c7b9883 100644
--- a/mlx/backend/metal/ternary.cpp
+++ b/mlx/backend/metal/ternary.cpp
@@ -10,16 +10,16 @@ namespace mlx::core {

 constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 5;

-void ternary_op(
+void ternary_op_gpu_inplace(
     const std::vector<array>& inputs,
     array& out,
-    const std::string op) {
+    const std::string op,
+    const Stream& s) {
   assert(inputs.size() == 3);
   auto& a = inputs[0];
   auto& b = inputs[1];
   auto& c = inputs[2];
   TernaryOpType topt = get_ternary_op_type(a, b, c);
-  set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);

   if (out.size() == 0) {
     return;
@@ -47,7 +47,6 @@ void ternary_op(
     kernel_name = kname.str();
   }

-  auto& s = out.primitive().stream();
   auto& d = metal::device(s.device);

   auto kernel = get_ternary_kernel(d, kernel_name, out);
@@ -101,8 +100,29 @@
   }
 }

+void ternary_op_gpu(
+    const std::vector<array>& inputs,
+    array& out,
+    const std::string op,
+    const Stream& s) {
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto& c = inputs[2];
+  TernaryOpType topt = get_ternary_op_type(a, b, c);
+  set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
+  ternary_op_gpu_inplace(inputs, out, op, s);
+}
+
+void ternary_op_gpu(
+    const std::vector<array>& inputs,
+    array& out,
+    const std::string op) {
+  auto& s = out.primitive().stream();
+  ternary_op_gpu(inputs, out, op, s);
+}
+
 void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
-  ternary_op(inputs, out, "select");
+  ternary_op_gpu(inputs, out, "select");
 }

 } // namespace mlx::core
diff --git a/mlx/backend/metal/ternary.h b/mlx/backend/metal/ternary.h
new file mode 100644
index 000000000..0834140b8
--- /dev/null
+++ b/mlx/backend/metal/ternary.h
@@ -0,0 +1,21 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/array.h"
+
+namespace mlx::core {
+
+void ternary_op_gpu(
+    const std::vector<array>& inputs,
+    array& out,
+    const std::string op,
+    const Stream& s);
+
+void ternary_op_gpu_inplace(
+    const std::vector<array>& inputs,
+    array& out,
+    const std::string op,
+    const Stream& s);
+
+} // namespace mlx::core
diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp
index c540f7c06..f4d9721fa 100644
--- a/mlx/backend/metal/unary.cpp
+++ b/mlx/backend/metal/unary.cpp
@@ -7,30 +7,17 @@

 namespace mlx::core {

-void unary_op(
+void unary_op_gpu_inplace(
     const std::vector<array>& inputs,
     array& out,
-    const std::string op) {
+    const std::string op,
+    const Stream& s) {
   auto& in = inputs[0];
   bool contig = in.flags().contiguous;
-  if (contig) {
-    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
-      out.move_shared_buffer(in);
-    } else {
-      out.set_data(
-          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
-          in.data_size(),
-          in.strides(),
-          in.flags());
-    }
-  } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  }
   if (in.size() == 0) {
     return;
   }

-  auto& s = out.primitive().stream();
   auto& d = metal::device(s.device);
   std::string kernel_name = (contig ? "v" : "g") + op + type_to_name(out);
"v" : "g") + op + type_to_name(out); @@ -59,39 +46,70 @@ void unary_op( compute_encoder.dispatchThreads(grid_dims, group_dims); } +void unary_op_gpu( + const std::vector& inputs, + array& out, + const std::string op, + const Stream& s) { + auto& in = inputs[0]; + bool contig = in.flags().contiguous; + if (contig) { + if (in.is_donatable() && in.itemsize() == out.itemsize()) { + out.move_shared_buffer(in); + } else { + out.set_data( + allocator::malloc_or_wait(in.data_size() * out.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + } + } else { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + } + unary_op_gpu_inplace(inputs, out, op, s); +} + +void unary_op_gpu( + const std::vector& inputs, + array& out, + const std::string op) { + auto& s = out.primitive().stream(); + unary_op_gpu(inputs, out, op, s); +} + void Abs::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "abs"); + unary_op_gpu(inputs, out, "abs"); } void ArcCos::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "arccos"); + unary_op_gpu(inputs, out, "arccos"); } void ArcCosh::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "arccosh"); + unary_op_gpu(inputs, out, "arccosh"); } void ArcSin::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "arcsin"); + unary_op_gpu(inputs, out, "arcsin"); } void ArcSinh::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "arcsinh"); + unary_op_gpu(inputs, out, "arcsinh"); } void ArcTan::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "arctan"); + unary_op_gpu(inputs, out, "arctan"); } void ArcTanh::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "arctanh"); + unary_op_gpu(inputs, out, "arctanh"); } void Conjugate::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (out.dtype() == complex64) { - unary_op(inputs, out, "conj"); + unary_op_gpu(inputs, out, "conj"); } else { throw std::invalid_argument( "[conjugate] conjugate must be called on complex input."); @@ -99,68 +117,68 @@ void Conjugate::eval_gpu(const std::vector& inputs, array& out) { } void Cos::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "cos"); + unary_op_gpu(inputs, out, "cos"); } void Cosh::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "cosh"); + unary_op_gpu(inputs, out, "cosh"); } void Erf::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "erf"); + unary_op_gpu(inputs, out, "erf"); } void ErfInv::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "erfinv"); + unary_op_gpu(inputs, out, "erfinv"); } void Exp::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "exp"); + unary_op_gpu(inputs, out, "exp"); } void Expm1::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "expm1"); + unary_op_gpu(inputs, out, "expm1"); } void Log::eval_gpu(const std::vector& inputs, array& out) { switch (base_) { case Base::e: - unary_op(inputs, out, "log"); + unary_op_gpu(inputs, out, "log"); break; case Base::two: - unary_op(inputs, out, "log2"); + unary_op_gpu(inputs, out, "log2"); break; case Base::ten: - unary_op(inputs, out, "log10"); + unary_op_gpu(inputs, out, "log10"); break; } } void Log1p::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "log1p"); + unary_op_gpu(inputs, out, "log1p"); } void LogicalNot::eval_gpu(const 
-  unary_op(inputs, out, "lnot");
+  unary_op_gpu(inputs, out, "lnot");
 }

 void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "floor");
+  unary_op_gpu(inputs, out, "floor");
 }

 void Ceil::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "ceil");
+  unary_op_gpu(inputs, out, "ceil");
 }

 void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "neg");
+  unary_op_gpu(inputs, out, "neg");
 }

 void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
   if (issubdtype(in.dtype(), inexact)) {
-    unary_op(inputs, out, "round");
+    unary_op_gpu(inputs, out, "round");
   } else {
     // No-op integer types
     out.copy_shared_buffer(in);
@@ -168,39 +186,39 @@
 }

 void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "sigmoid");
+  unary_op_gpu(inputs, out, "sigmoid");
 }

 void Sign::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "sign");
+  unary_op_gpu(inputs, out, "sign");
 }

 void Sin::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "sin");
+  unary_op_gpu(inputs, out, "sin");
 }

 void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "sinh");
+  unary_op_gpu(inputs, out, "sinh");
 }

 void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "square");
+  unary_op_gpu(inputs, out, "square");
 }

 void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (recip_) {
-    unary_op(inputs, out, "rsqrt");
+    unary_op_gpu(inputs, out, "rsqrt");
   } else {
-    unary_op(inputs, out, "sqrt");
+    unary_op_gpu(inputs, out, "sqrt");
   }
 }

 void Tan::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "tan");
+  unary_op_gpu(inputs, out, "tan");
 }

 void Tanh::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "tanh");
+  unary_op_gpu(inputs, out, "tanh");
 }

 } // namespace mlx::core
diff --git a/mlx/backend/metal/unary.h b/mlx/backend/metal/unary.h
new file mode 100644
index 000000000..19057076b
--- /dev/null
+++ b/mlx/backend/metal/unary.h
@@ -0,0 +1,21 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/array.h"
+
+namespace mlx::core {
+
+void unary_op_gpu(
+    const std::vector<array>& inputs,
+    array& out,
+    const std::string op,
+    const Stream& s);
+
+void unary_op_gpu_inplace(
+    const std::vector<array>& inputs,
+    array& out,
+    const std::string op,
+    const Stream& s);
+
+} // namespace mlx::core
diff --git a/mlx/backend/no_cpu/CMakeLists.txt b/mlx/backend/no_cpu/CMakeLists.txt
index f3f6d4250..c82e2a457 100644
--- a/mlx/backend/no_cpu/CMakeLists.txt
+++ b/mlx/backend/no_cpu/CMakeLists.txt
@@ -6,4 +6,5 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/../common/common.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp
 )
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 88f5def4b..c72447d4d 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1794,14 +1794,6 @@ class Slice : public UnaryPrimitive {
   std::vector<int> strides_;

   void eval(const std::vector<array>& inputs, array& out);
-
-  std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
-      const array& in);
-  void shared_buffer_slice(
-      const array& in,
-      const std::vector<size_t>& out_strides,
-      size_t data_offset,
-      array& out);
 };

 class SliceUpdate : public UnaryPrimitive {
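
A note on the refactor (not part of the patch itself): every elementwise dispatcher now has an explicit-stream entry point (binary_op_gpu, ternary_op_gpu, unary_op_gpu, and their *_inplace variants), so GPU code outside a primitive's own eval can reuse these kernels without reaching back through out.primitive().stream(). A minimal usage sketch follows; the caller my_eval_gpu is hypothetical and only illustrates the call pattern the patch introduces:

    // Hypothetical caller of the new stream-aware entry points
    // (my_eval_gpu is illustrative; it is not part of this diff).
    #include "mlx/backend/metal/binary.h"

    namespace mlx::core {

    void my_eval_gpu(const std::vector<array>& inputs, array& out) {
      // Any stream can be passed explicitly now; here it still comes from
      // the output's primitive, matching the stream-less overload.
      auto& s = out.primitive().stream();

      // Allocating overload: sets the output buffer via
      // set_binary_op_output_data, then dispatches "add" in place on s.
      binary_op_gpu(inputs, out, "add", s);
    }

    } // namespace mlx::core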