Add some internal GPU APIs (#1177)

* Add unary/binary/ternary/slice/concat internal GPU ops

* Add pad internal op

* formatting + no_cpu fix
Alex Barron 2024-06-04 09:24:26 -07:00 committed by GitHub
parent ea9090bbc4
commit 375a8bbdcc
17 changed files with 449 additions and 203 deletions

View File

@ -48,6 +48,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp

View File

@ -250,49 +250,6 @@ void Split::eval(
}
}
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
const array& in) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
copy_needed |= strides_[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
}
void Slice::shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out) {
// Compute row/col contiguity
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(out.shape(), out_strides);
auto flags = in.flags();
flags.row_contiguous = is_row_contiguous;
flags.col_contiguous = is_col_contiguous;
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in.data_size()) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
const array& in) {
int64_t data_offset = 0;

View File

@ -11,6 +11,7 @@
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/ops.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/threefry.h"
#include "mlx/backend/common/unary.h"
#include "mlx/backend/common/utils.h"
@ -492,7 +493,8 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
auto [copy_needed, data_offset, inp_strides] =
prepare_slice(in, start_indices_, strides_);
// Do copy if needed
if (copy_needed) {

View File

@ -0,0 +1,52 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in,
std::vector<int>& start_indices,
std::vector<int>& strides) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i];
copy_needed |= strides[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
}
void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out) {
// Compute row/col contiguity
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(out.shape(), out_strides);
auto flags = in.flags();
flags.row_contiguous = is_row_contiguous;
flags.col_contiguous = is_col_contiguous;
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in.data_size()) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
} // namespace mlx::core
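
For illustration, a worked trace of what prepare_slice computes on hypothetical inputs (the numbers below are not from the commit):

// Slicing a row-major 4x6 array: in.strides() == {6, 1},
// start_indices == {1, 2}, strides == {2, -1}.
//
//   data_offset = 1 * 6 + 2 * 1       = 8
//   inp_strides = {6 * 2, 1 * (-1)}   = {12, -1}
//   copy_needed = (2 < 0) || (-1 < 0) = true
//
// The negative stride forces a copy; with all-positive strides the caller
// can alias the input buffer through shared_buffer_slice instead.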

View File

@ -0,0 +1,20 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in,
std::vector<int>& start_indices,
std::vector<int>& strides);
void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out);
} // namespace mlx::core

View File

@ -135,6 +135,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp

View File

@ -10,16 +10,14 @@ namespace mlx::core {
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
void binary_op(
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op) {
assert(inputs.size() == 2);
const std::string op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt, true);
set_binary_op_output_data(a, b, outputs[1], bopt, true);
auto& out = outputs[0];
if (out.size() == 0) {
@ -61,7 +59,6 @@ void binary_op(
kernel_name = kname.str();
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = get_binary_two_kernel(d, kernel_name, a, outputs[0]);
@ -120,15 +117,36 @@ void binary_op(
}
}
void binary_op(
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op) {
std::vector<array>& outputs,
const std::string op,
const Stream& s) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt, true);
set_binary_op_output_data(a, b, outputs[0], bopt, true);
set_binary_op_output_data(a, b, outputs[1], bopt, true);
binary_op_gpu_inplace(inputs, outputs, op, s);
}
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op) {
auto& s = outputs[0].primitive().stream();
binary_op_gpu(inputs, outputs, op, s);
}
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
if (out.size() == 0) {
return;
}
@ -168,7 +186,6 @@ void binary_op(
kernel_name = kname.str();
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = get_binary_kernel(d, kernel_name, a, out);
@ -221,102 +238,123 @@ void binary_op(
}
}
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const Stream& s) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt, true);
binary_op_gpu_inplace(inputs, out, op, s);
}
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op) {
auto& s = out.primitive().stream();
binary_op_gpu(inputs, out, op, s);
}
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "add");
binary_op_gpu(inputs, out, "add");
}
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "arctan2");
binary_op_gpu(inputs, out, "arctan2");
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (op_) {
case BitwiseBinary::And:
binary_op(inputs, out, "bitwise_and");
binary_op_gpu(inputs, out, "bitwise_and");
break;
case BitwiseBinary::Or:
binary_op(inputs, out, "bitwise_or");
binary_op_gpu(inputs, out, "bitwise_or");
break;
case BitwiseBinary::Xor:
binary_op(inputs, out, "bitwise_xor");
binary_op_gpu(inputs, out, "bitwise_xor");
break;
case BitwiseBinary::LeftShift:
binary_op(inputs, out, "left_shift");
binary_op_gpu(inputs, out, "left_shift");
break;
case BitwiseBinary::RightShift:
binary_op(inputs, out, "right_shift");
binary_op_gpu(inputs, out, "right_shift");
break;
}
}
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "div");
binary_op_gpu(inputs, out, "div");
}
void DivMod::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
binary_op(inputs, outputs, "divmod");
binary_op_gpu(inputs, outputs, "divmod");
}
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "rem");
binary_op_gpu(inputs, out, "rem");
}
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
binary_op_gpu(inputs, out, equal_nan_ ? "naneq" : "eq");
}
void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "ge");
binary_op_gpu(inputs, out, "ge");
}
void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "geq");
binary_op_gpu(inputs, out, "geq");
}
void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "le");
binary_op_gpu(inputs, out, "le");
}
void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "leq");
binary_op_gpu(inputs, out, "leq");
}
void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "land");
binary_op_gpu(inputs, out, "land");
}
void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "lor");
binary_op_gpu(inputs, out, "lor");
}
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "lae");
binary_op_gpu(inputs, out, "lae");
}
void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "max");
binary_op_gpu(inputs, out, "max");
}
void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "min");
binary_op_gpu(inputs, out, "min");
}
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "mul");
binary_op_gpu(inputs, out, "mul");
}
void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "neq");
binary_op_gpu(inputs, out, "neq");
}
void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "pow");
binary_op_gpu(inputs, out, "pow");
}
void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "sub");
binary_op_gpu(inputs, out, "sub");
}
} // namespace mlx::core

View File

@ -0,0 +1,33 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op,
const Stream& s);
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const Stream& s);
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op,
const Stream& s);
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const Stream& s);
} // namespace mlx::core
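
A minimal sketch of how another primitive might call these entry points. The FusedMax and MyDivMod primitives are hypothetical stand-ins; the signatures and the "max"/"divmod" op strings come from this commit, and the declarations above are assumed to be in scope:

// Hypothetical callers, illustrative only.
namespace mlx::core {

void FusedMax::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Forward this primitive's stream explicitly instead of letting the
  // helper recover it from out.primitive().stream().
  binary_op_gpu(inputs, out, "max", stream());
}

void MyDivMod::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  // Two-output ops use the std::vector<array> overload, as DivMod does.
  binary_op_gpu(inputs, outputs, "divmod", outputs[0].primitive().stream());
}

} // namespace mlx::core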

View File

@ -7,6 +7,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@ -163,30 +164,7 @@ void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
}
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
std::vector<int> sizes;
sizes.push_back(0);
for (auto& p : inputs) {
sizes.push_back(p.shape(axis_));
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto strides = out.strides();
auto flags = out.flags();
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
auto& d = metal::device(stream().device);
auto& compute_encoder = d.get_command_encoder(stream().index);
auto concurrent_ctx = compute_encoder.start_concurrent();
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
}
concatenate_gpu(inputs, out, axis_, stream());
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -238,23 +216,7 @@ void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
// Padding value, input and output must be of the same type
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
// Fill output with val
copy_gpu(val, out, CopyType::Scalar, stream());
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes_.size(); i++) {
auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
data_offset += out.strides()[ax] * low_pad_size_[i];
}
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
pad_gpu(in, val, out, axes_, low_pad_size_, stream());
}
void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -331,28 +293,7 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
}
auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
// Do copy if needed
if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
copy_gpu_inplace(
/* const array& in = */ in,
/* array& out = */ out,
/* const std::vector<int>& data_shape = */ out.shape(),
/* const std::vector<stride_t>& i_strides = */ inp_strides,
/* const std::vector<stride_t>& o_strides = */ ostrides,
/* int64_t i_offset = */ data_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General,
/* const Stream& s = */ stream());
} else {
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, out);
}
slice_gpu(in, out, start_indices_, strides_, stream());
}
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@ -0,0 +1,98 @@
// Copyright © 2024 Apple Inc.
#include <numeric>
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
namespace mlx::core {
void slice_gpu(
const array& in,
array& out,
std::vector<int> start_indices,
std::vector<int> strides,
const Stream& s) {
// Calculate out strides, initial offset and if copy needs to be made
auto [copy_needed, data_offset, inp_strides] =
prepare_slice(in, start_indices, strides);
// Do copy if needed
if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
copy_gpu_inplace(
/* const array& in = */ in,
/* array& out = */ out,
/* const std::vector<int>& data_shape = */ out.shape(),
/* const std::vector<stride_t>& i_strides = */ inp_strides,
/* const std::vector<stride_t>& o_strides = */ ostrides,
/* int64_t i_offset = */ data_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General,
/* const Stream& s = */ s);
} else {
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, out);
}
}
void concatenate_gpu(
const std::vector<array>& inputs,
array& out,
int axis,
const Stream& s) {
std::vector<int> sizes;
sizes.push_back(0);
for (auto& p : inputs) {
sizes.push_back(p.shape(axis));
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto strides = out.strides();
auto flags = out.flags();
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
auto& d = metal::device(s.device);
auto& compute_encoder = d.get_command_encoder(s.index);
auto concurrent_ctx = compute_encoder.start_concurrent();
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s);
}
}
void pad_gpu(
const array& in,
const array& val,
array& out,
std::vector<int> axes,
std::vector<int> low_pad_size,
const Stream& s) {
// Fill output with val
copy_gpu(val, out, CopyType::Scalar, s);
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes.size(); i++) {
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
data_offset += out.strides()[ax] * low_pad_size[i];
}
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
}
} // namespace mlx::core
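
Two worked traces on hypothetical shapes (not from the commit) show the offset arithmetic:

// concatenate_gpu along axis 1 with input shapes {2,3}, {2,5}, {2,2}:
//   sizes after the push_backs:  {0, 3, 5, 2}
//   sizes after partial_sum:     {0, 3, 8, 10}
// Input i is pasted at element offset out.strides()[1] * sizes[i], so the
// three slices land back to back in the {2,10} output.
//
// pad_gpu with in.shape() == {4}, axes == {0}, low_pad_size == {2}, and
// out.shape() == {7} (2 low + 1 high padding):
//   data_offset = out.strides()[0] * 2 = 2
// so after the scalar fill, `in` is copied into out[2..5].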

View File

@ -0,0 +1,30 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void slice_gpu(
const array& in,
array& out,
std::vector<int> start_indices,
std::vector<int> strides,
const Stream& s);
void concatenate_gpu(
const std::vector<array>& inputs,
array& out,
int axis,
const Stream& s);
void pad_gpu(
const array& in,
const array& val,
array& out,
std::vector<int> axes,
std::vector<int> low_pad_size,
const Stream& s);
} // namespace mlx::core

View File

@ -10,16 +10,16 @@ namespace mlx::core {
constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 5;
void ternary_op(
void ternary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op) {
const std::string op,
const Stream& s) {
assert(inputs.size() == 3);
auto& a = inputs[0];
auto& b = inputs[1];
auto& c = inputs[2];
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
if (out.size() == 0) {
return;
@ -47,7 +47,6 @@ void ternary_op(
kernel_name = kname.str();
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = get_ternary_kernel(d, kernel_name, out);
@ -101,8 +100,29 @@ void ternary_op(
}
}
void ternary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto& c = inputs[2];
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
ternary_op_gpu_inplace(inputs, out, op, s);
}
void ternary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op) {
auto& s = out.primitive().stream();
ternary_op_gpu(inputs, out, op, s);
}
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
ternary_op(inputs, out, "select");
ternary_op_gpu(inputs, out, "select");
}
} // namespace mlx::core
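
Select keeps elementwise where-semantics: out[i] = condition[i] ? a[i] : b[i]. A hypothetical caller could reuse the same entry point for a masked merge (illustrative, not from this commit; "select" is the real op string used by Select::eval_gpu above):

// Hypothetical masked-merge primitive, illustrative only.
void MaskedMerge::eval_gpu(const std::vector<array>& inputs, array& out) {
  // inputs = {condition, values_if_true, values_if_false}
  ternary_op_gpu(inputs, out, "select", stream());
}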

View File

@ -0,0 +1,21 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void ternary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const Stream& s);
void ternary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const Stream& s);
} // namespace mlx::core

View File

@ -7,30 +7,17 @@
namespace mlx::core {
void unary_op(
void unary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op) {
const std::string op,
const Stream& s) {
auto& in = inputs[0];
bool contig = in.flags().contiguous;
if (contig) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
if (in.size() == 0) {
return;
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
std::string kernel_name = (contig ? "v" : "g") + op + type_to_name(out);
@ -59,39 +46,70 @@ void unary_op(
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const Stream& s) {
auto& in = inputs[0];
bool contig = in.flags().contiguous;
if (contig) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
unary_op_gpu_inplace(inputs, out, op, s);
}
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op) {
auto& s = out.primitive().stream();
unary_op_gpu(inputs, out, op, s);
}
void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "abs");
unary_op_gpu(inputs, out, "abs");
}
void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arccos");
unary_op_gpu(inputs, out, "arccos");
}
void ArcCosh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arccosh");
unary_op_gpu(inputs, out, "arccosh");
}
void ArcSin::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arcsin");
unary_op_gpu(inputs, out, "arcsin");
}
void ArcSinh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arcsinh");
unary_op_gpu(inputs, out, "arcsinh");
}
void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arctan");
unary_op_gpu(inputs, out, "arctan");
}
void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arctanh");
unary_op_gpu(inputs, out, "arctanh");
}
void Conjugate::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == complex64) {
unary_op(inputs, out, "conj");
unary_op_gpu(inputs, out, "conj");
} else {
throw std::invalid_argument(
"[conjugate] conjugate must be called on complex input.");
@ -99,68 +117,68 @@ void Conjugate::eval_gpu(const std::vector<array>& inputs, array& out) {
}
void Cos::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "cos");
unary_op_gpu(inputs, out, "cos");
}
void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "cosh");
unary_op_gpu(inputs, out, "cosh");
}
void Erf::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "erf");
unary_op_gpu(inputs, out, "erf");
}
void ErfInv::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "erfinv");
unary_op_gpu(inputs, out, "erfinv");
}
void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "exp");
unary_op_gpu(inputs, out, "exp");
}
void Expm1::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "expm1");
unary_op_gpu(inputs, out, "expm1");
}
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (base_) {
case Base::e:
unary_op(inputs, out, "log");
unary_op_gpu(inputs, out, "log");
break;
case Base::two:
unary_op(inputs, out, "log2");
unary_op_gpu(inputs, out, "log2");
break;
case Base::ten:
unary_op(inputs, out, "log10");
unary_op_gpu(inputs, out, "log10");
break;
}
}
void Log1p::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "log1p");
unary_op_gpu(inputs, out, "log1p");
}
void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "lnot");
unary_op_gpu(inputs, out, "lnot");
}
void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "floor");
unary_op_gpu(inputs, out, "floor");
}
void Ceil::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "ceil");
unary_op_gpu(inputs, out, "ceil");
}
void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "neg");
unary_op_gpu(inputs, out, "neg");
}
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_op(inputs, out, "round");
unary_op_gpu(inputs, out, "round");
} else {
// No-op integer types
out.copy_shared_buffer(in);
@ -168,39 +186,39 @@ void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
}
void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sigmoid");
unary_op_gpu(inputs, out, "sigmoid");
}
void Sign::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sign");
unary_op_gpu(inputs, out, "sign");
}
void Sin::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sin");
unary_op_gpu(inputs, out, "sin");
}
void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sinh");
unary_op_gpu(inputs, out, "sinh");
}
void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "square");
unary_op_gpu(inputs, out, "square");
}
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
if (recip_) {
unary_op(inputs, out, "rsqrt");
unary_op_gpu(inputs, out, "rsqrt");
} else {
unary_op(inputs, out, "sqrt");
unary_op_gpu(inputs, out, "sqrt");
}
}
void Tan::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "tan");
unary_op_gpu(inputs, out, "tan");
}
void Tanh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "tanh");
unary_op_gpu(inputs, out, "tanh");
}
} // namespace mlx::core

mlx/backend/metal/unary.h Normal file
View File

@ -0,0 +1,21 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const Stream& s);
void unary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const Stream& s);
} // namespace mlx::core
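
The overloads without a Stream (defined in unary.cpp above) are thin wrappers that recover the stream from the output, so the two calls below are equivalent; the explicit form exists so composite ops can run on a stream they already hold. A minimal sketch, using the real "sqrt" op string:

// Equivalent calls when `out` already has a primitive attached.
unary_op_gpu(inputs, out, "sqrt");
unary_op_gpu(inputs, out, "sqrt", out.primitive().stream());

// unary_op_gpu also decides allocation: it takes over a contiguous input's
// buffer when in.is_donatable() and the itemsizes match, and allocates
// otherwise. unary_op_gpu_inplace assumes out's data is already set.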

View File

@ -6,4 +6,5 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/../common/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp
)

View File

@ -1794,14 +1794,6 @@ class Slice : public UnaryPrimitive {
std::vector<int> strides_;
void eval(const std::vector<array>& inputs, array& out);
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in);
void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out);
};
class SliceUpdate : public UnaryPrimitive {