Files
mlx/mlx/backend/common/binary.h

98 lines
2.6 KiB
C
Raw Normal View History

2023-11-30 11:12:53 -08:00
// Copyright © 2023 Apple Inc.
2023-11-29 10:42:59 -08:00
#pragma once
2023-11-29 10:42:59 -08:00
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
/**
 * Layout category of the two inputs of a binary op, as assigned by
 * get_binary_op_type(). Enumerator order is unchanged.
 */
enum class BinaryOpType {
  /// Both inputs report data_size() == 1.
  ScalarScalar,
  /// First input has data_size() == 1; second is flagged contiguous.
  ScalarVector,
  /// Second input has data_size() == 1; first is flagged contiguous.
  VectorScalar,
  /// Both inputs are row-contiguous, or both are column-contiguous.
  VectorVector,
  /// Any other combination.
  General,
};
/**
 * Classify the input pair (a, b) of a binary op into a BinaryOpType.
 *
 * The checks run in order and the first match wins:
 *   1. both data_size() == 1                         -> ScalarScalar
 *   2. a.data_size() == 1 and b flagged contiguous   -> ScalarVector
 *   3. b.data_size() == 1 and a flagged contiguous   -> VectorScalar
 *   4. both row_contiguous, or both col_contiguous   -> VectorVector
 *   5. anything else                                 -> General
 *
 * @param a First input.
 * @param b Second input.
 * @return The category matched by the first rule above.
 */
inline BinaryOpType get_binary_op_type(const array& a, const array& b) {
  if (a.data_size() == 1 && b.data_size() == 1) {
    return BinaryOpType::ScalarScalar;
  }
  if (a.data_size() == 1 && b.flags().contiguous) {
    return BinaryOpType::ScalarVector;
  }
  if (b.data_size() == 1 && a.flags().contiguous) {
    return BinaryOpType::VectorScalar;
  }
  // The two inputs must agree on the same contiguity kind; a row-contiguous
  // input paired with a column-contiguous one falls through to General.
  if ((a.flags().row_contiguous && b.flags().row_contiguous) ||
      (a.flags().col_contiguous && b.flags().col_contiguous)) {
    return BinaryOpType::VectorVector;
  }
  return BinaryOpType::General;
}
inline void set_binary_op_output_data(
2023-11-29 10:42:59 -08:00
const array& a,
const array& b,
array& out,
BinaryOpType bopt) {
bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out);
2023-11-29 10:42:59 -08:00
switch (bopt) {
case BinaryOpType::ScalarScalar:
2023-11-29 10:42:59 -08:00
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b_donatable) {
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
}
2023-11-29 10:42:59 -08:00
break;
case BinaryOpType::VectorScalar:
if (a_donatable) {
out.copy_shared_buffer(a);
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
}
break;
case BinaryOpType::VectorVector:
if (a_donatable) {
out.copy_shared_buffer(a);
} else if (b_donatable) {
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
}
2023-11-29 10:42:59 -08:00
break;
case BinaryOpType::General:
if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
out.copy_shared_buffer(a);
} else if (
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
2023-11-29 10:42:59 -08:00
break;
}
}
} // namespace mlx::core