MLX
 
Loading...
Searching...
No Matches
binary.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include "mlx/allocator.h"
6#include "mlx/array.h"
8
9namespace mlx::core {
10
18
// Picks the BinaryOpType for a binary operation on `a` and `b` by inspecting
// each operand's data_size() (element count of its underlying buffer) and its
// contiguity flags, then returns the chosen value.
//
// NOTE(review): this listing is missing source lines (the number prefixes jump
// 21 -> 23, 27 -> 31, ...). The statement inside every branch and the whole
// condition of the fourth branch are not visible here. Presumably each branch
// assigns one BinaryOpType enumerator to `bopt` -- confirm against the real
// header before relying on the per-branch notes below.
19inline BinaryOpType get_binary_op_type(const array& a, const array& b) {
20 BinaryOpType bopt;
 // Both operands are backed by a single-element buffer.
21 if (a.data_size() == 1 && b.data_size() == 1) {
 // `a` is backed by a single element and `b` is flagged contiguous.
23 } else if (a.data_size() == 1 && b.flags().contiguous) {
 // Mirror of the previous test: `b` is the single element, `a` is contiguous.
25 } else if (b.data_size() == 1 && a.flags().contiguous) {
 // Condition not visible in this listing (source lines 28-30 are missing).
27 } else if (
 // Fallback when none of the tests above matched.
31 } else {
33 }
34 return bopt;
35}
36
// Parameter list and body of set_binary_op_output_data. The declaration line
// itself (source line 37) is absent from this listing.
//
// Prepares the storage of `out` for a binary operation of kind `bopt`. In the
// visible branches `out` either takes over the buffer of a donatable input or
// gets its data assigned through out.set_data(...).
//
//   a, b             - the two inputs of the binary operation.
//   out              - the array whose data is set up; modified in place.
//   bopt             - selects which switch branch runs.
//   donate_with_move - when an input buffer is reused: true calls
//                      move_shared_buffer, false calls copy_shared_buffer.
//
// NOTE(review): the `case` labels, the first argument of every set_data call
// and the body of the final `else` are among the lines missing from this
// listing. The per-branch notes below describe only what is visible; the case
// names are guesses -- confirm against the real header.
38 const array& a,
39 const array& b,
40 array& out,
41 BinaryOpType bopt,
42 bool donate_with_move = false) {
 // Evaluate donatability of both inputs once, before dispatching on `bopt`.
43 bool b_donatable = is_donatable(b, out);
44 bool a_donatable = is_donatable(a, out);
45 switch (bopt) {
 // First case (label missing): set_data is called unconditionally; its
 // arguments are not visible here.
47 out.set_data(
49 break;
 // Second case (label missing; presumably ScalarVector -- TODO confirm):
 // only `b` is considered for buffer reuse.
51 if (b_donatable) {
52 if (donate_with_move) {
53 out.move_shared_buffer(b);
54 } else {
55 out.copy_shared_buffer(b);
56 }
57 } else {
 // No reuse: `out` takes b's data_size, strides and flags. The first
 // argument (source line 59) is missing; presumably a new allocation.
58 out.set_data(
60 b.data_size(),
61 b.strides(),
62 b.flags());
63 }
64 break;
 // Third case (label missing; presumably VectorScalar -- TODO confirm):
 // mirror of the previous case with `a` in place of `b`.
66 if (a_donatable) {
67 if (donate_with_move) {
68 out.move_shared_buffer(a);
69 } else {
70 out.copy_shared_buffer(a);
71 }
72 } else {
73 out.set_data(
75 a.data_size(),
76 a.strides(),
77 a.flags());
78 }
79 break;
 // Fourth case (label missing; presumably VectorVector -- TODO confirm):
 // `a` is tried first, then `b`; otherwise `out` takes a's layout.
81 if (a_donatable) {
82 if (donate_with_move) {
83 out.move_shared_buffer(a);
84 } else {
85 out.copy_shared_buffer(a);
86 }
87 } else if (b_donatable) {
88 if (donate_with_move) {
89 out.move_shared_buffer(b);
90 } else {
91 out.copy_shared_buffer(b);
92 }
93 } else {
94 out.set_data(
96 a.data_size(),
97 a.strides(),
98 a.flags());
99 }
100 break;
 // Fifth case (label missing; presumably General -- TODO confirm): an input
 // is reused only if it is donatable AND row-contiguous AND has the same
 // element count as `out`.
102 if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
103 if (donate_with_move) {
104 out.move_shared_buffer(a);
105 } else {
106 out.copy_shared_buffer(a);
107 }
108 } else if (
109 b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
110 if (donate_with_move) {
111 out.move_shared_buffer(b);
112 } else {
113 out.copy_shared_buffer(b);
114 }
115 } else {
 // Body (source line 116) is missing from this listing.
117 }
118 break;
119 }
120}
121
122} // namespace mlx::core
Definition array.h:24
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:318
const Strides & strides() const
The strides of the array.
Definition array.h:117
size_t nbytes() const
The number of bytes in the array.
Definition array.h:93
size_t size() const
The number of elements in the array.
Definition array.h:88
void copy_shared_buffer(const array &other, const Strides &strides, Flags flags, size_t data_size, size_t offset=0)
void move_shared_buffer(array other, const Strides &strides, Flags flags, size_t data_size, size_t offset=0)
size_t itemsize() const
The size of the array's datatype in bytes.
Definition array.h:83
void set_data(allocator::Buffer buffer, Deleter d=allocator::free)
size_t data_size() const
The size (in elements) of the underlying buffer the array points to.
Definition array.h:332
Buffer malloc_or_wait(size_t size)
Definition allocator.h:7
BinaryOpType get_binary_op_type(const array &a, const array &b)
Definition binary.h:19
BinaryOpType
Definition binary.h:11
@ General
Definition binary.h:16
@ VectorVector
Definition binary.h:15
@ ScalarScalar
Definition binary.h:12
@ VectorScalar
Definition binary.h:14
@ ScalarVector
Definition binary.h:13
void set_binary_op_output_data(const array &a, const array &b, array &out, BinaryOpType bopt, bool donate_with_move=false)
Definition binary.h:37
bool is_donatable(const array &in, const array &out)
Definition utils.h:155
Definition binary.h:40
Definition binary.h:16
Definition binary.h:64
bool row_contiguous
Definition array.h:237
bool col_contiguous
Definition array.h:243
bool contiguous
Definition array.h:231