diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 465954d6f..00898e73e 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -47,6 +47,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) if(MLX_BUILD_METAL) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) else() target_sources(mlx diff --git a/mlx/backend/gpu/CMakeLists.txt b/mlx/backend/gpu/CMakeLists.txt new file mode 100644 index 000000000..0396ae03a --- /dev/null +++ b/mlx/backend/gpu/CMakeLists.txt @@ -0,0 +1,5 @@ +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp) diff --git a/mlx/backend/gpu/copy.cpp b/mlx/backend/gpu/copy.cpp new file mode 100644 index 000000000..6127ac921 --- /dev/null +++ b/mlx/backend/gpu/copy.cpp @@ -0,0 +1,49 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/backend/gpu/copy.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core { + +void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { + bool donated = set_copy_output_data(in, out, ctype); + if (donated && in.dtype() == out.dtype()) { + // If the output has the same type as the input then there is nothing to + // copy, just use the buffer. + return; + } + if (ctype == CopyType::GeneralGeneral) { + ctype = CopyType::General; + } + copy_gpu_inplace(in, out, ctype, s); +} + +void copy_gpu(const array& in, array& out, CopyType ctype) { + copy_gpu(in, out, ctype, out.primitive().stream()); +} + +void copy_gpu_inplace( + const array& in, + array& out, + CopyType ctype, + const Stream& s) { + assert(in.shape() == out.shape()); + return copy_gpu_inplace( + in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s); +} + +void copy_gpu_inplace( + const array& in, + array& out, + const Strides& i_strides, + int64_t i_offset, + CopyType ctype, + const Stream& s) { + assert(in.shape() == out.shape()); + return copy_gpu_inplace( + in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/copy.h b/mlx/backend/gpu/copy.h similarity index 98% rename from mlx/backend/metal/copy.h rename to mlx/backend/gpu/copy.h index 37c60df42..020f579e4 100644 --- a/mlx/backend/metal/copy.h +++ b/mlx/backend/gpu/copy.h @@ -5,6 +5,8 @@ #include "mlx/backend/common/copy.h" #include "mlx/stream.h" +#include + namespace mlx::core { // Generic copy inplace diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp new file mode 100644 index 000000000..cd9296075 --- /dev/null +++ b/mlx/backend/gpu/primitives.cpp @@ -0,0 +1,217 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/primitives.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" + +#include + +#define MLX_PROFILER_RANGE(message) + +namespace mlx::core { + +namespace { + +void reshape(const array& in, array& out, Stream s) { + auto [copy_necessary, out_strides] = prepare_reshape(in, out); + if (copy_necessary) { + out.set_data(allocator::malloc(out.nbytes())); + copy_gpu_inplace( + in, + out, + in.shape(), + in.strides(), + make_contiguous_strides(in.shape()), + 0, + 0, + CopyType::General, + s); + } else { + shared_buffer_reshape(in, out_strides, out); + } +} + +} // namespace + +void AsStrided::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("AsStrided::eval_gpu"); + eval(inputs, out); +} + +void AsType::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("AsType::eval_gpu"); + CopyType ctype = + inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General; + copy_gpu(inputs[0], out, ctype); +} + +void Broadcast::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Broadcast::eval_gpu"); + eval(inputs, out); +} + +void BroadcastAxes::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("BroadcastAxes::eval_gpu"); + eval(inputs, out); +} + +void Concatenate::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Concatenate::eval_gpu"); + concatenate_gpu(inputs, out, axis_, stream()); +} + +void Contiguous::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Contiguous::eval_gpu"); + assert(inputs.size() == 1); + auto& in = inputs[0]; + constexpr size_t extra_bytes = 16384; + if (in.buffer_size() <= out.nbytes() + extra_bytes && + (in.flags().row_contiguous || + (allow_col_major_ && in.flags().col_contiguous))) { + out.copy_shared_buffer(in); + } else { + copy_gpu(in, out, CopyType::General); + } +} + +void Copy::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Copy::eval_gpu"); + eval(inputs, out); +} + +void CustomTransforms::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + MLX_PROFILER_RANGE("CustomTransforms::eval_gpu"); + eval(inputs, outputs); +} + +void Depends::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + MLX_PROFILER_RANGE("Depends::eval_gpu"); + eval(inputs, outputs); +} + +void ExpandDims::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("ExpandDims::eval_gpu"); + eval(inputs, out); +} + +void Full::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Full::eval_gpu"); + auto in = inputs[0]; + CopyType ctype; + if (in.data_size() == 1) { + ctype = CopyType::Scalar; + } else if (in.flags().contiguous) { + ctype = CopyType::Vector; + } else { + ctype = CopyType::General; + } + copy_gpu(in, out, ctype); +} + +void Flatten::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Flatten::eval_gpu"); + reshape(inputs[0], out, stream()); +} + +void NumberOfElements::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("NumberOfElements::eval_gpu"); + eval(inputs, out); +} + +void Pad::eval_gpu(const std::vector& inputs, array& out) { + // Inputs must be base input array and scalar val array + assert(inputs.size() == 2); + auto& in = inputs[0]; + auto& val = inputs[1]; + + // Padding value must be a scalar + assert(val.size() == 1); + + // Padding value, input and output must be of the same type + assert(val.dtype() == in.dtype() && in.dtype() == out.dtype()); + + pad_gpu(in, val, out, axes_, low_pad_size_, stream()); +} + +void Reshape::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Reshape::eval_gpu"); + reshape(inputs[0], out, stream()); +} + +void Split::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + MLX_PROFILER_RANGE("Split::eval_gpu"); + eval(inputs, outputs); +} + +void Slice::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Slice::eval_gpu"); + assert(inputs.size() == 1); + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + slice_gpu(in, out, start_indices_, strides_, stream()); +} + +void Squeeze::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Squeeze::eval_gpu"); + eval(inputs, out); +} + +void StopGradient::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("StopGradient::eval_gpu"); + eval(inputs, out); +} + +void Transpose::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Transpose::eval_gpu"); + eval(inputs, out); +} + +void Unflatten::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("Unflatten::eval_gpu"); + reshape(inputs[0], out, stream()); +} + +void View::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("View::eval_gpu"); + auto& in = inputs[0]; + auto ibytes = size_of(in.dtype()); + auto obytes = size_of(out.dtype()); + // Conditions for buffer copying (disjunction): + // - type size is the same + // - type size is smaller and the last axis is contiguous + // - the entire array is row contiguous + if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) || + in.flags().row_contiguous) { + auto strides = in.strides(); + for (int i = 0; i < static_cast(strides.size()) - 1; ++i) { + strides[i] *= ibytes; + strides[i] /= obytes; + } + out.copy_shared_buffer( + in, strides, in.flags(), in.data_size() * ibytes / obytes); + } else { + auto tmp = array(in.shape(), in.dtype(), nullptr, {}); + tmp.set_data(allocator::malloc(tmp.nbytes())); + copy_gpu_inplace(in, tmp, CopyType::General, stream()); + + auto flags = out.flags(); + flags.contiguous = true; + flags.row_contiguous = true; + auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); + flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; + out.copy_shared_buffer(tmp, out.strides(), flags, out.size()); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/gpu/slicing.cpp b/mlx/backend/gpu/slicing.cpp new file mode 100644 index 000000000..fde2a01cd --- /dev/null +++ b/mlx/backend/gpu/slicing.cpp @@ -0,0 +1,44 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/slicing.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" + +namespace mlx::core { + +void slice_gpu( + const array& in, + array& out, + const Shape& start_indices, + const Shape& strides, + const Stream& s) { + slice(in, out, start_indices, strides); +} + +void pad_gpu( + const array& in, + const array& val, + array& out, + const std::vector& axes, + const Shape& low_pad_size, + const Stream& s) { + // Fill output with val + fill_gpu(val, out, s); + + // Find offset for start of input values + size_t data_offset = 0; + for (int i = 0; i < axes.size(); i++) { + auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i]; + data_offset += out.strides()[ax] * low_pad_size[i]; + } + + // Extract slice from output where input will be pasted + array out_slice(in.shape(), out.dtype(), nullptr, {}); + out_slice.copy_shared_buffer( + out, out.strides(), out.flags(), out_slice.size(), data_offset); + + // Copy input values into the slice + copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/slicing.h b/mlx/backend/gpu/slicing.h similarity index 100% rename from mlx/backend/metal/slicing.h rename to mlx/backend/gpu/slicing.h diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 9075ea4c5..ae31a6cff 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -5,7 +5,7 @@ #include #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index ee004359f..8dfe15c11 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -1,35 +1,15 @@ // Copyright © 2023-2024 Apple Inc. -#include - +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" -#include "mlx/primitives.h" namespace mlx::core { constexpr int MAX_COPY_SPECIALIZED_DIMS = 3; -void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { - bool donated = set_copy_output_data(in, out, ctype); - if (donated && in.dtype() == out.dtype()) { - // If the output has the same type as the input then there is nothing to - // copy, just use the buffer. - return; - } - if (ctype == CopyType::GeneralGeneral) { - ctype = CopyType::General; - } - copy_gpu_inplace(in, out, ctype, s); -} - -void copy_gpu(const array& in, array& out, CopyType ctype) { - copy_gpu(in, out, ctype, out.primitive().stream()); -} - void copy_gpu_inplace( const array& in, array& out, @@ -184,28 +164,6 @@ void copy_gpu_inplace( } } -void copy_gpu_inplace( - const array& in, - array& out, - CopyType ctype, - const Stream& s) { - assert(in.shape() == out.shape()); - return copy_gpu_inplace( - in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s); -} - -void copy_gpu_inplace( - const array& in, - array& out, - const Strides& i_strides, - int64_t i_offset, - CopyType ctype, - const Stream& s) { - assert(in.shape() == out.shape()); - return copy_gpu_inplace( - in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s); -} - void fill_gpu(const array& val, array& out, const Stream& s) { if (out.size() == 0) { return; diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 8a672289a..ea4f258cc 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -1,6 +1,6 @@ // Copyright © 2024 Apple Inc. -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" diff --git a/mlx/backend/metal/distributed.cpp b/mlx/backend/metal/distributed.cpp index 82e8fff7d..a800d2e0f 100644 --- a/mlx/backend/metal/distributed.cpp +++ b/mlx/backend/metal/distributed.cpp @@ -4,7 +4,7 @@ #include "mlx/allocator.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/utils.h" #include "mlx/distributed/ops.h" diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 153c62c02..011eb7ebb 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -7,10 +7,10 @@ #include "mlx/3rdparty/pocketfft.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/binary.h" -#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/kernels.h" -#include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/unary.h" #include "mlx/backend/metal/utils.h" #include "mlx/utils.h" diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp index 89b970fce..65a877151 100644 --- a/mlx/backend/metal/hadamard.cpp +++ b/mlx/backend/metal/hadamard.cpp @@ -3,7 +3,7 @@ #include "mlx/backend/common/hadamard.h" #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/kernels.h" diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index d2a263051..cccfd908a 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -2,7 +2,7 @@ #include #include "mlx/backend/common/compiled.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/indexing.h" diff --git a/mlx/backend/metal/logsumexp.cpp b/mlx/backend/metal/logsumexp.cpp index 4901190e1..e53bc58d9 100644 --- a/mlx/backend/metal/logsumexp.cpp +++ b/mlx/backend/metal/logsumexp.cpp @@ -1,7 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index f55d20c9f..71221f8d9 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -7,7 +7,7 @@ #include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index c1d993d2a..21142183e 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -1,7 +1,7 @@ // Copyright © 2024 Apple Inc. #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/reduce.h" diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 6946ffb9e..860e9ddd7 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -7,10 +7,10 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" -#include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/scheduler.h" @@ -25,25 +25,6 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) { enc.set_bytes(step, 1); } -void reshape(const array& in, array& out, Stream s) { - auto [copy_necessary, out_strides] = prepare_reshape(in, out); - if (copy_necessary) { - out.set_data(allocator::malloc(out.nbytes())); - copy_gpu_inplace( - in, - out, - in.shape(), - in.strides(), - make_contiguous_strides(in.shape()), - 0, - 0, - CopyType::General, - s); - } else { - shared_buffer_reshape(in, out_strides, out); - } -} - static array compute_dynamic_offset( const array& indices, const Strides& strides, @@ -226,105 +207,10 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { } } -void AsType::eval_gpu(const std::vector& inputs, array& out) { - CopyType ctype = - inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General; - copy_gpu(inputs[0], out, ctype); -} - -void AsStrided::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Broadcast::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void BroadcastAxes::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Concatenate::eval_gpu(const std::vector& inputs, array& out) { - concatenate_gpu(inputs, out, axis_, stream()); -} - -void Contiguous::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - constexpr size_t extra_bytes = 16384; - if (in.buffer_size() <= out.nbytes() + extra_bytes && - (in.flags().row_contiguous || - (allow_col_major_ && in.flags().col_contiguous))) { - out.copy_shared_buffer(in); - } else { - copy_gpu(in, out, CopyType::General); - } -} - -void Copy::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void CustomTransforms::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - eval(inputs, outputs); -} - -void Depends::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - eval(inputs, outputs); -} - -void Full::eval_gpu(const std::vector& inputs, array& out) { - auto in = inputs[0]; - CopyType ctype; - if (in.data_size() == 1) { - ctype = CopyType::Scalar; - } else if (in.flags().contiguous) { - ctype = CopyType::Vector; - } else { - ctype = CopyType::General; - } - copy_gpu(in, out, ctype); -} - -void ExpandDims::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Flatten::eval_gpu(const std::vector& inputs, array& out) { - reshape(inputs[0], out, stream()); -} - -void Unflatten::eval_gpu(const std::vector& inputs, array& out) { - reshape(inputs[0], out, stream()); -} - void Load::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error("[Load::eval_gpu] Not implemented."); } -void NumberOfElements::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Pad::eval_gpu(const std::vector& inputs, array& out) { - // Inputs must be base input array and scalar val array - assert(inputs.size() == 2); - auto& in = inputs[0]; - auto& val = inputs[1]; - - // Padding value must be a scalar - assert(val.size() == 1); - - // Padding value, input and output must be of the same type - assert(val.dtype() == in.dtype() && in.dtype() == out.dtype()); - - pad_gpu(in, val, out, axes_, low_pad_size_, stream()); -} - void RandomBits::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); @@ -370,27 +256,6 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatch_threads(grid_dims, group_dims); } -void Reshape::eval_gpu(const std::vector& inputs, array& out) { - reshape(inputs[0], out, stream()); -} - -void Split::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - eval(inputs, outputs); -} - -void Slice::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - if (out.size() == 0) { - out.set_data(nullptr); - return; - } - - auto& in = inputs[0]; - slice_gpu(in, out, start_indices_, strides_, stream()); -} - void DynamicSlice::eval_gpu(const std::vector& inputs, array& out) { if (out.size() == 0) { out.set_data(nullptr); @@ -492,18 +357,6 @@ void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { /* const Stream& s = */ stream()); } -void Squeeze::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void StopGradient::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - -void Transpose::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); -} - void QRF::eval_gpu( const std::vector& inputs, std::vector& outputs) { @@ -537,35 +390,4 @@ void LUF::eval_gpu( throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI."); } -void View::eval_gpu(const std::vector& inputs, array& out) { - auto& in = inputs[0]; - auto ibytes = size_of(in.dtype()); - auto obytes = size_of(out.dtype()); - // Conditions for buffer copying (disjunction): - // - type size is the same - // - type size is smaller and the last axis is contiguous - // - the entire array is row contiguous - if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) || - in.flags().row_contiguous) { - auto strides = in.strides(); - for (int i = 0; i < static_cast(strides.size()) - 1; ++i) { - strides[i] *= ibytes; - strides[i] /= obytes; - } - out.copy_shared_buffer( - in, strides, in.flags(), in.data_size() * ibytes / obytes); - } else { - auto tmp = array(in.shape(), in.dtype(), nullptr, {}); - tmp.set_data(allocator::malloc(tmp.nbytes())); - copy_gpu_inplace(in, tmp, CopyType::General, stream()); - - auto flags = out.flags(); - flags.contiguous = true; - flags.row_contiguous = true; - auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); - flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; - out.copy_shared_buffer(tmp, out.strides(), flags, out.size()); - } -} - } // namespace mlx::core diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 6f5807543..11a2355cc 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -4,7 +4,7 @@ #include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/compiled.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/reduce.h" diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index c5650bdd7..8cb55ba58 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -3,7 +3,7 @@ #include #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp index 060758333..d8201afe6 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -1,5 +1,5 @@ // Copyright © 2023-2024 Apple Inc. -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index d75e6d87d..3c7b7ff19 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -2,7 +2,7 @@ #include #include "mlx/backend/common/compiled.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/steel/attn/params.h" diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index b1800fea9..3c4051105 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -3,7 +3,7 @@ #include #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp index 6ab08a108..3e1a8b541 100644 --- a/mlx/backend/metal/slicing.cpp +++ b/mlx/backend/metal/slicing.cpp @@ -2,21 +2,12 @@ #include -#include "mlx/backend/common/slicing.h" -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/device.h" namespace mlx::core { -void slice_gpu( - const array& in, - array& out, - const Shape& start_indices, - const Shape& strides, - const Stream& s) { - slice(in, out, start_indices, strides); -} - void concatenate_gpu( const std::vector& inputs, array& out, @@ -48,30 +39,4 @@ void concatenate_gpu( } } -void pad_gpu( - const array& in, - const array& val, - array& out, - const std::vector& axes, - const Shape& low_pad_size, - const Stream& s) { - // Fill output with val - fill_gpu(val, out, s); - - // Find offset for start of input values - size_t data_offset = 0; - for (int i = 0; i < axes.size(); i++) { - auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i]; - data_offset += out.strides()[ax] * low_pad_size[i]; - } - - // Extract slice from output where input will be pasted - array out_slice(in.shape(), out.dtype(), nullptr, {}); - out_slice.copy_shared_buffer( - out, out.strides(), out.flags(), out_slice.size(), data_offset); - - // Copy input values into the slice - copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s); -} - } // namespace mlx::core diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index 224721a50..59662b05d 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -1,7 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index 543dfd180..3c84022f2 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -2,7 +2,7 @@ #include -#include "mlx/backend/metal/copy.h" +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h"