MLX
 
Loading...
Searching...
No Matches
binary.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include "mlx/allocator.h"
6#include "mlx/array.h"
8
9namespace mlx::core {
10
18
19inline BinaryOpType get_binary_op_type(const array& a, const array& b) {
20 BinaryOpType bopt;
21 if (a.data_size() == 1 && b.data_size() == 1) {
23 } else if (a.data_size() == 1 && b.flags().contiguous) {
25 } else if (b.data_size() == 1 && a.flags().contiguous) {
27 } else if (
31 } else {
33 }
34 return bopt;
35}
36
38 const array& a,
39 const array& b,
40 array& out,
41 BinaryOpType bopt) {
42 bool b_donatable = is_donatable(b, out);
43 bool a_donatable = is_donatable(a, out);
44 switch (bopt) {
46 out.set_data(
48 break;
50 if (b_donatable) {
51 out.copy_shared_buffer(b);
52 } else {
53 out.set_data(
55 b.data_size(),
56 b.strides(),
57 b.flags());
58 }
59 break;
61 if (a_donatable) {
62 out.copy_shared_buffer(a);
63 } else {
64 out.set_data(
66 a.data_size(),
67 a.strides(),
68 a.flags());
69 }
70 break;
72 if (a_donatable) {
73 out.copy_shared_buffer(a);
74 } else if (b_donatable) {
75 out.copy_shared_buffer(b);
76 } else {
77 out.set_data(
79 a.data_size(),
80 a.strides(),
81 a.flags());
82 }
83 break;
85 if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
86 out.copy_shared_buffer(a);
87 } else if (
88 b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
89 out.copy_shared_buffer(b);
90 } else {
92 }
93 break;
94 }
95}
96
97} // namespace mlx::core
Definition array.h:24
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:313
const Strides & strides() const
The strides of the array.
Definition array.h:117
size_t nbytes() const
The number of bytes in the array.
Definition array.h:93
size_t size() const
The number of elements in the array.
Definition array.h:88
void copy_shared_buffer(const array &other, const Strides &strides, Flags flags, size_t data_size, size_t offset=0)
size_t itemsize() const
The size of the array's datatype in bytes.
Definition array.h:83
void set_data(allocator::Buffer buffer, Deleter d=allocator::free)
size_t data_size() const
The size (in elements) of the underlying buffer the array points to.
Definition array.h:327
Buffer malloc_or_wait(size_t size)
Definition allocator.h:7
BinaryOpType get_binary_op_type(const array &a, const array &b)
Definition binary.h:19
BinaryOpType
Definition binary.h:11
@ General
Definition binary.h:16
@ VectorVector
Definition binary.h:15
@ ScalarScalar
Definition binary.h:12
@ VectorScalar
Definition binary.h:14
@ ScalarVector
Definition binary.h:13
void set_binary_op_output_data(const array &a, const array &b, array &out, BinaryOpType bopt)
Definition binary.h:37
bool is_donatable(const array &in, const array &out)
Definition utils.h:155
Definition binary.h:35
Definition binary.h:15
Definition binary.h:55
bool row_contiguous
Definition array.h:244
bool col_contiguous
Definition array.h:250
bool contiguous
Definition array.h:238