MLX
 
Loading...
Searching...
No Matches
ternary.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4#include "mlx/allocator.h"
5#include "mlx/array.h"
7
8namespace mlx::core {
9
// TODO: Add support for more combinations of input types.
/// Layout classification of the three inputs of a ternary op.
///
/// Returned by get_ternary_op_type() and switched on by
/// set_ternary_op_output_data() to decide how the output buffer is obtained.
enum class TernaryOpType {
  ScalarScalarScalar, // every input's underlying buffer holds one element
  VectorVectorVector, // all inputs row-contiguous, or all col-contiguous
  General,            // any other combination
};
17inline TernaryOpType
18get_ternary_op_type(const array& a, const array& b, const array& c) {
19 TernaryOpType topt;
20 if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
22 } else if (
24 c.flags().row_contiguous) ||
26 c.flags().col_contiguous)) {
28 } else {
30 }
31 return topt;
32}
33
35 const array& a,
36 const array& b,
37 const array& c,
38 array& out,
39 TernaryOpType topt,
40 bool donate_with_move = false) {
41 auto maybe_donate = [&out, donate_with_move](const array& x) {
42 if (is_donatable(x, out)) {
43 if (donate_with_move) {
44 out.move_shared_buffer(x);
45 } else {
46 out.copy_shared_buffer(x);
47 }
48 return true;
49 }
50 return false;
51 };
52
53 switch (topt) {
55 out.set_data(
57 break;
59 if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
60 out.set_data(
62 b.data_size(),
63 b.strides(),
64 b.flags());
65 }
66 break;
68 // Try to donate an input which is row_contiguous
69 if (!((a.flags().row_contiguous && maybe_donate(a)) ||
70 (b.flags().row_contiguous && maybe_donate(b)) ||
71 (c.flags().row_contiguous && maybe_donate(c)))) {
73 }
74 break;
75 }
76}
77
78} // namespace mlx::core
Definition array.h:24
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:318
const Strides & strides() const
The strides of the array.
Definition array.h:117
size_t nbytes() const
The number of bytes in the array.
Definition array.h:93
void copy_shared_buffer(const array &other, const Strides &strides, Flags flags, size_t data_size, size_t offset=0)
void move_shared_buffer(array other, const Strides &strides, Flags flags, size_t data_size, size_t offset=0)
size_t itemsize() const
The size of the array's datatype in bytes.
Definition array.h:83
void set_data(allocator::Buffer buffer, Deleter d=allocator::free)
size_t data_size() const
The size (in elements) of the underlying buffer the array points to.
Definition array.h:332
Buffer malloc_or_wait(size_t size)
Definition allocator.h:7
@ General
Definition binary.h:16
TernaryOpType get_ternary_op_type(const array &a, const array &b, const array &c)
Definition ternary.h:18
void set_ternary_op_output_data(const array &a, const array &b, const array &c, array &out, TernaryOpType topt, bool donate_with_move=false)
Definition ternary.h:34
TernaryOpType
Definition ternary.h:11
@ ScalarScalarScalar
Definition ternary.h:12
@ General
Definition ternary.h:14
@ VectorVectorVector
Definition ternary.h:13
bool is_donatable(const array &in, const array &out)
Definition utils.h:155
bool row_contiguous
Definition array.h:237
bool col_contiguous
Definition array.h:243