Mirror of https://github.com/ml-explore/mlx.git
Move common gpu primitives to backend/gpu (#2145)
This commit is contained in:
parent af705590ac
commit 1683975acf
@@ -47,6 +47,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
 
 if(MLX_BUILD_METAL)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
 else()
   target_sources(mlx
mlx/backend/gpu/CMakeLists.txt (new file, 5 lines):

target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)
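The shared sources above are attached directly to the main mlx target. As a sketch of how the same pattern could extend to another GPU backend (the backend/mygpu directory and file names below are assumptions for illustration, not part of this commit), its own CMakeLists.txt would contribute only device-specific files alongside the shared backend/gpu ones:

# Hypothetical mlx/backend/mygpu/CMakeLists.txt, mirroring the pattern above:
# backend/gpu supplies the generic copy/primitives/slicing glue, and the
# backend adds its device-specific sources to the same mlx target.
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/kernels.cpp)

The top-level CMakeLists.txt hunk above would then gain a matching add_subdirectory call for that backend.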
mlx/backend/gpu/copy.cpp (new file, 49 lines):

// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h"

#include <cassert>

namespace mlx::core {

void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
  bool donated = set_copy_output_data(in, out, ctype);
  if (donated && in.dtype() == out.dtype()) {
    // If the output has the same type as the input then there is nothing to
    // copy, just use the buffer.
    return;
  }
  if (ctype == CopyType::GeneralGeneral) {
    ctype = CopyType::General;
  }
  copy_gpu_inplace(in, out, ctype, s);
}

void copy_gpu(const array& in, array& out, CopyType ctype) {
  copy_gpu(in, out, ctype, out.primitive().stream());
}

void copy_gpu_inplace(
    const array& in,
    array& out,
    CopyType ctype,
    const Stream& s) {
  assert(in.shape() == out.shape());
  return copy_gpu_inplace(
      in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
}

void copy_gpu_inplace(
    const array& in,
    array& out,
    const Strides& i_strides,
    int64_t i_offset,
    CopyType ctype,
    const Stream& s) {
  assert(in.shape() == out.shape());
  return copy_gpu_inplace(
      in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
}

} // namespace mlx::core
@@ -5,6 +5,8 @@
 #include "mlx/backend/common/copy.h"
 #include "mlx/stream.h"
+
+#include <optional>
 
 namespace mlx::core {
 
 // Generic copy inplace
mlx/backend/gpu/primitives.cpp (new file, 217 lines):

// Copyright © 2025 Apple Inc.

#include "mlx/primitives.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"

#include <cassert>

#define MLX_PROFILER_RANGE(message)

namespace mlx::core {

namespace {

void reshape(const array& in, array& out, Stream s) {
  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
  if (copy_necessary) {
    out.set_data(allocator::malloc(out.nbytes()));
    copy_gpu_inplace(
        in,
        out,
        in.shape(),
        in.strides(),
        make_contiguous_strides(in.shape()),
        0,
        0,
        CopyType::General,
        s);
  } else {
    shared_buffer_reshape(in, out_strides, out);
  }
}

} // namespace

void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("AsStrided::eval_gpu");
  eval(inputs, out);
}

void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("AsType::eval_gpu");
  CopyType ctype =
      inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
  copy_gpu(inputs[0], out, ctype);
}

void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("Broadcast::eval_gpu");
  eval(inputs, out);
}

void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("BroadcastAxes::eval_gpu");
  eval(inputs, out);
}

void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("Concatenate::eval_gpu");
  concatenate_gpu(inputs, out, axis_, stream());
}

void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("Contiguous::eval_gpu");
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  constexpr size_t extra_bytes = 16384;
  if (in.buffer_size() <= out.nbytes() + extra_bytes &&
      (in.flags().row_contiguous ||
       (allow_col_major_ && in.flags().col_contiguous))) {
    out.copy_shared_buffer(in);
  } else {
    copy_gpu(in, out, CopyType::General);
  }
}

void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("Copy::eval_gpu");
  eval(inputs, out);
}

void CustomTransforms::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  MLX_PROFILER_RANGE("CustomTransforms::eval_gpu");
  eval(inputs, outputs);
}

void Depends::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  MLX_PROFILER_RANGE("Depends::eval_gpu");
  eval(inputs, outputs);
}

void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("ExpandDims::eval_gpu");
  eval(inputs, out);
}

void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("Full::eval_gpu");
  auto in = inputs[0];
  CopyType ctype;
  if (in.data_size() == 1) {
    ctype = CopyType::Scalar;
  } else if (in.flags().contiguous) {
    ctype = CopyType::Vector;
  } else {
    ctype = CopyType::General;
  }
  copy_gpu(in, out, ctype);
}

void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("Flatten::eval_gpu");
  reshape(inputs[0], out, stream());
}

void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("NumberOfElements::eval_gpu");
  eval(inputs, out);
}

void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Inputs must be base input array and scalar val array
  assert(inputs.size() == 2);
  auto& in = inputs[0];
  auto& val = inputs[1];

  // Padding value must be a scalar
  assert(val.size() == 1);

  // Padding value, input and output must be of the same type
  assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());

  pad_gpu(in, val, out, axes_, low_pad_size_, stream());
}

void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("Reshape::eval_gpu");
  reshape(inputs[0], out, stream());
}

void Split::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  MLX_PROFILER_RANGE("Split::eval_gpu");
  eval(inputs, outputs);
}

void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("Slice::eval_gpu");
  assert(inputs.size() == 1);
  if (out.size() == 0) {
    out.set_data(nullptr);
    return;
  }

  auto& in = inputs[0];
  slice_gpu(in, out, start_indices_, strides_, stream());
}

void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("Squeeze::eval_gpu");
  eval(inputs, out);
}

void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("StopGradient::eval_gpu");
  eval(inputs, out);
}

void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("Transpose::eval_gpu");
  eval(inputs, out);
}

void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("Unflatten::eval_gpu");
  reshape(inputs[0], out, stream());
}

void View::eval_gpu(const std::vector<array>& inputs, array& out) {
  MLX_PROFILER_RANGE("View::eval_gpu");
  auto& in = inputs[0];
  auto ibytes = size_of(in.dtype());
  auto obytes = size_of(out.dtype());
  // Conditions for buffer copying (disjunction):
  // - type size is the same
  // - type size is smaller and the last axis is contiguous
  // - the entire array is row contiguous
  if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
      in.flags().row_contiguous) {
    auto strides = in.strides();
    for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
      strides[i] *= ibytes;
      strides[i] /= obytes;
    }
    out.copy_shared_buffer(
        in, strides, in.flags(), in.data_size() * ibytes / obytes);
  } else {
    auto tmp = array(in.shape(), in.dtype(), nullptr, {});
    tmp.set_data(allocator::malloc(tmp.nbytes()));
    copy_gpu_inplace(in, tmp, CopyType::General, stream());

    auto flags = out.flags();
    flags.contiguous = true;
    flags.row_contiguous = true;
    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
    out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
  }
}

} // namespace mlx::core
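As a worked example of the stride rescaling in View::eval_gpu above, here is a standalone sketch (independent of MLX; the shape and dtypes are assumed for illustration): viewing a row-contiguous (2, 3) float32 array as int16 halves the element size, so every stride except the innermost is scaled by ibytes / obytes = 4 / 2 = 2.

// Standalone illustration of the View stride rescaling shown above.
// A (2, 3) float32 array has strides {3, 1}; viewed as int16 (obytes = 2),
// all but the last stride scale by ibytes / obytes = 2, giving {6, 1} over
// a (2, 6) int16 view of the same underlying bytes.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<long> strides = {3, 1}; // row-contiguous (2, 3) float32 input
  const long ibytes = 4;              // size of float32
  const long obytes = 2;              // size of int16
  for (std::size_t i = 0; i + 1 < strides.size(); ++i) {
    strides[i] = strides[i] * ibytes / obytes;
  }
  std::printf("new strides: {%ld, %ld}\n", strides[0], strides[1]); // {6, 1}
  return 0;
}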
mlx/backend/gpu/slicing.cpp (new file, 44 lines):

// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/slicing.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"

namespace mlx::core {

void slice_gpu(
    const array& in,
    array& out,
    const Shape& start_indices,
    const Shape& strides,
    const Stream& s) {
  slice(in, out, start_indices, strides);
}

void pad_gpu(
    const array& in,
    const array& val,
    array& out,
    const std::vector<int>& axes,
    const Shape& low_pad_size,
    const Stream& s) {
  // Fill output with val
  fill_gpu(val, out, s);

  // Find offset for start of input values
  size_t data_offset = 0;
  for (int i = 0; i < axes.size(); i++) {
    auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
    data_offset += out.strides()[ax] * low_pad_size[i];
  }

  // Extract slice from output where input will be pasted
  array out_slice(in.shape(), out.dtype(), nullptr, {});
  out_slice.copy_shared_buffer(
      out, out.strides(), out.flags(), out_slice.size(), data_offset);

  // Copy input values into the slice
  copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
}

} // namespace mlx::core
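As a worked example of the data_offset computation in pad_gpu above, here is a standalone sketch (the shapes are assumed for illustration, not taken from the commit): padding into a row-contiguous output of shape (4, 7), so strides {7, 1}, with low-side padding {1, 2} along axes {0, 1} places the input block at linear offset 7*1 + 1*2 = 9.

// Standalone illustration of the pad_gpu offset arithmetic shown above.
// Assumed example: row-contiguous output with strides {7, 1}, low-side
// padding {1, 2} along axes {0, 1}.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<long> out_strides = {7, 1};
  std::vector<int> axes = {0, 1};
  std::vector<int> low_pad_size = {1, 2};
  std::size_t data_offset = 0;
  for (std::size_t i = 0; i < axes.size(); ++i) {
    data_offset += out_strides[axes[i]] * low_pad_size[i];
  }
  std::printf("data_offset = %zu\n", data_offset); // 7*1 + 1*2 = 9
  return 0;
}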
@@ -5,7 +5,7 @@
 #include <numeric>
 #include <sstream>
 
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/kernels/defines.h"
@@ -1,35 +1,15 @@
 // Copyright © 2023-2024 Apple Inc.
 
-#include <sstream>
-
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/common/utils.h"
-#include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
-#include "mlx/primitives.h"
 
 namespace mlx::core {
 
 constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
 
-void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
-  bool donated = set_copy_output_data(in, out, ctype);
-  if (donated && in.dtype() == out.dtype()) {
-    // If the output has the same type as the input then there is nothing to
-    // copy, just use the buffer.
-    return;
-  }
-  if (ctype == CopyType::GeneralGeneral) {
-    ctype = CopyType::General;
-  }
-  copy_gpu_inplace(in, out, ctype, s);
-}
-
-void copy_gpu(const array& in, array& out, CopyType ctype) {
-  copy_gpu(in, out, ctype, out.primitive().stream());
-}
-
 void copy_gpu_inplace(
     const array& in,
     array& out,
@@ -184,28 +164,6 @@ void copy_gpu_inplace(
   }
 }
 
-void copy_gpu_inplace(
-    const array& in,
-    array& out,
-    CopyType ctype,
-    const Stream& s) {
-  assert(in.shape() == out.shape());
-  return copy_gpu_inplace(
-      in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
-}
-
-void copy_gpu_inplace(
-    const array& in,
-    array& out,
-    const Strides& i_strides,
-    int64_t i_offset,
-    CopyType ctype,
-    const Stream& s) {
-  assert(in.shape() == out.shape());
-  return copy_gpu_inplace(
-      in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
-}
-
 void fill_gpu(const array& val, array& out, const Stream& s) {
   if (out.size() == 0) {
     return;
@@ -1,6 +1,6 @@
 // Copyright © 2024 Apple Inc.
 
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
@@ -4,7 +4,7 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/utils.h"
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/distributed/ops.h"
@@ -7,10 +7,10 @@
 
 #include "mlx/3rdparty/pocketfft.h"
 #include "mlx/backend/common/utils.h"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/backend/gpu/slicing.h"
 #include "mlx/backend/metal/binary.h"
-#include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/kernels.h"
-#include "mlx/backend/metal/slicing.h"
 #include "mlx/backend/metal/unary.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/utils.h"
@@ -3,7 +3,7 @@
 #include "mlx/backend/common/hadamard.h"
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/common/utils.h"
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/kernels.h"
@@ -2,7 +2,7 @@
 #include <fmt/format.h>
 
 #include "mlx/backend/common/compiled.h"
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/jit/indexing.h"
@@ -1,7 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.
 #include <algorithm>
 
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
@@ -7,7 +7,7 @@
 
 #include "mlx/backend/common/broadcasting.h"
 #include "mlx/backend/common/utils.h"
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/kernels/defines.h"
@@ -1,7 +1,7 @@
 // Copyright © 2024 Apple Inc.
 #include <algorithm>
 
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels/defines.h"
 #include "mlx/backend/metal/reduce.h"
@@ -7,10 +7,10 @@
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/common/slicing.h"
 #include "mlx/backend/common/utils.h"
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/backend/gpu/slicing.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
-#include "mlx/backend/metal/slicing.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"
 #include "mlx/scheduler.h"
@@ -25,25 +25,6 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
   enc.set_bytes(step, 1);
 }
 
-void reshape(const array& in, array& out, Stream s) {
-  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
-  if (copy_necessary) {
-    out.set_data(allocator::malloc(out.nbytes()));
-    copy_gpu_inplace(
-        in,
-        out,
-        in.shape(),
-        in.strides(),
-        make_contiguous_strides(in.shape()),
-        0,
-        0,
-        CopyType::General,
-        s);
-  } else {
-    shared_buffer_reshape(in, out_strides, out);
-  }
-}
-
 static array compute_dynamic_offset(
     const array& indices,
     const Strides& strides,
@@ -226,105 +207,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 }
 
-void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
-  CopyType ctype =
-      inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
-  copy_gpu(inputs[0], out, ctype);
-}
-
-void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
-  eval(inputs, out);
-}
-
-void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
-  eval(inputs, out);
-}
-
-void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
-  eval(inputs, out);
-}
-
-void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
-  concatenate_gpu(inputs, out, axis_, stream());
-}
-
-void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  auto& in = inputs[0];
-  constexpr size_t extra_bytes = 16384;
-  if (in.buffer_size() <= out.nbytes() + extra_bytes &&
-      (in.flags().row_contiguous ||
-       (allow_col_major_ && in.flags().col_contiguous))) {
-    out.copy_shared_buffer(in);
-  } else {
-    copy_gpu(in, out, CopyType::General);
-  }
-}
-
-void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
-  eval(inputs, out);
-}
-
-void CustomTransforms::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  eval(inputs, outputs);
-}
-
-void Depends::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  eval(inputs, outputs);
-}
-
-void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
-  auto in = inputs[0];
-  CopyType ctype;
-  if (in.data_size() == 1) {
-    ctype = CopyType::Scalar;
-  } else if (in.flags().contiguous) {
-    ctype = CopyType::Vector;
-  } else {
-    ctype = CopyType::General;
-  }
-  copy_gpu(in, out, ctype);
-}
-
-void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
-  eval(inputs, out);
-}
-
-void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
-  reshape(inputs[0], out, stream());
-}
-
-void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
-  reshape(inputs[0], out, stream());
-}
-
 void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
   throw std::runtime_error("[Load::eval_gpu] Not implemented.");
 }
 
-void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
-  eval(inputs, out);
-}
-
-void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
-  // Inputs must be base input array and scalar val array
-  assert(inputs.size() == 2);
-  auto& in = inputs[0];
-  auto& val = inputs[1];
-
-  // Padding value must be a scalar
-  assert(val.size() == 1);
-
-  // Padding value, input and output must be of the same type
-  assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
-
-  pad_gpu(in, val, out, axes_, low_pad_size_, stream());
-}
-
 void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
 
@@ -370,27 +256,6 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
 
-void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
-  reshape(inputs[0], out, stream());
-}
-
-void Split::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  eval(inputs, outputs);
-}
-
-void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  if (out.size() == 0) {
-    out.set_data(nullptr);
-    return;
-  }
-
-  auto& in = inputs[0];
-  slice_gpu(in, out, start_indices_, strides_, stream());
-}
-
 void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (out.size() == 0) {
     out.set_data(nullptr);
@@ -492,18 +357,6 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
       /* const Stream& s = */ stream());
 }
 
-void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
-  eval(inputs, out);
-}
-
-void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
-  eval(inputs, out);
-}
-
-void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
-  eval(inputs, out);
-}
-
 void QRF::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
@@ -537,35 +390,4 @@ void LUF::eval_gpu(
   throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
 }
 
-void View::eval_gpu(const std::vector<array>& inputs, array& out) {
-  auto& in = inputs[0];
-  auto ibytes = size_of(in.dtype());
-  auto obytes = size_of(out.dtype());
-  // Conditions for buffer copying (disjunction):
-  // - type size is the same
-  // - type size is smaller and the last axis is contiguous
-  // - the entire array is row contiguous
-  if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
-      in.flags().row_contiguous) {
-    auto strides = in.strides();
-    for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
-      strides[i] *= ibytes;
-      strides[i] /= obytes;
-    }
-    out.copy_shared_buffer(
-        in, strides, in.flags(), in.data_size() * ibytes / obytes);
-  } else {
-    auto tmp = array(in.shape(), in.dtype(), nullptr, {});
-    tmp.set_data(allocator::malloc(tmp.nbytes()));
-    copy_gpu_inplace(in, tmp, CopyType::General, stream());
-
-    auto flags = out.flags();
-    flags.contiguous = true;
-    flags.row_contiguous = true;
-    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
-    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
-    out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
-  }
-}
-
 } // namespace mlx::core
@@ -4,7 +4,7 @@
 
 #include "mlx/backend/common/broadcasting.h"
 #include "mlx/backend/common/compiled.h"
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/reduce.h"
@@ -3,7 +3,7 @@
 #include <algorithm>
 #include <cassert>
 
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/kernels/defines.h"
@@ -1,5 +1,5 @@
 // Copyright © 2023-2024 Apple Inc.
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
 
@@ -2,7 +2,7 @@
 #include <sstream>
 
 #include "mlx/backend/common/compiled.h"
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 
 #include "mlx/backend/metal/kernels/steel/attn/params.h"
@@ -3,7 +3,7 @@
 #include <cassert>
 #include <sstream>
 
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
@@ -2,21 +2,12 @@
 
 #include <numeric>
 
-#include "mlx/backend/common/slicing.h"
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/backend/gpu/slicing.h"
 #include "mlx/backend/metal/device.h"
 
 namespace mlx::core {
 
-void slice_gpu(
-    const array& in,
-    array& out,
-    const Shape& start_indices,
-    const Shape& strides,
-    const Stream& s) {
-  slice(in, out, start_indices, strides);
-}
-
 void concatenate_gpu(
     const std::vector<array>& inputs,
     array& out,
@@ -48,30 +39,4 @@ void concatenate_gpu(
   }
 }
 
-void pad_gpu(
-    const array& in,
-    const array& val,
-    array& out,
-    const std::vector<int>& axes,
-    const Shape& low_pad_size,
-    const Stream& s) {
-  // Fill output with val
-  fill_gpu(val, out, s);
-
-  // Find offset for start of input values
-  size_t data_offset = 0;
-  for (int i = 0; i < axes.size(); i++) {
-    auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
-    data_offset += out.strides()[ax] * low_pad_size[i];
-  }
-
-  // Extract slice from output where input will be pasted
-  array out_slice(in.shape(), out.dtype(), nullptr, {});
-  out_slice.copy_shared_buffer(
-      out, out.strides(), out.flags(), out_slice.size(), data_offset);
-
-  // Copy input values into the slice
-  copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
-}
-
 } // namespace mlx::core
@@ -1,7 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.
 #include <algorithm>
 
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/kernels/defines.h"
@@ -2,7 +2,7 @@
 
 #include <algorithm>
 
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"