MLX
 
Loading...
Searching...
No Matches
unary.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include "mlx/allocator.h"
6#include "mlx/array.h"
9#include "mlx/utils.h"
10
11namespace mlx::core {
12
13void set_unary_output_data(const array& in, array& out) {
14 if (is_donatable(in, out)) {
15 out.copy_shared_buffer(in);
16 } else {
17 auto size = in.data_size();
18 out.set_data(
20 size,
21 in.strides(),
22 in.flags());
23 }
24}
25
// Strided-to-dense kernel: reads `shape` elements of `a` spaced `stride`
// elements apart and writes `op` applied to each one into consecutive slots
// of `out`, in order. A stride of 0 re-reads the same input element each time.
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
  for (size_t idx = 0; idx != shape; ++idx) {
    out[idx] = op(a[idx * stride]);
  }
}
33
34template <typename T, typename U = T, typename Op>
35void unary_op(const array& a, array& out, Op op) {
36 const T* a_ptr = a.data<T>();
37 if (a.flags().contiguous) {
39 U* dst = out.data<U>();
40 constexpr int N = simd::max_size<T>;
41 size_t size = a.data_size();
42 while (size >= N) {
43 simd::store(dst, op(simd::load<T, N>(a_ptr)));
44 size -= N;
45 a_ptr += N;
46 dst += N;
47 }
48 while (size > 0) {
49 *dst = op(*a_ptr);
50 size--;
51 dst++;
52 a_ptr++;
53 }
54 } else {
56 U* dst = out.data<U>();
57 size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
58 size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
59 if (a.ndim() <= 1) {
60 unary_op(a_ptr, dst, op, shape, stride);
61 return;
62 }
63 ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1);
64 for (size_t elem = 0; elem < a.size(); elem += shape) {
65 unary_op(a_ptr + it.loc, dst + elem, op, shape, stride);
66 it.step();
67 }
68 }
69}
70
71template <typename Op>
72void unary(const array& a, array& out, Op op) {
73 switch (out.dtype()) {
74 case bool_:
75 unary_op<bool>(a, out, op);
76 break;
77 case uint8:
78 unary_op<uint8_t>(a, out, op);
79 break;
80 case uint16:
81 unary_op<uint16_t>(a, out, op);
82 break;
83 case uint32:
84 unary_op<uint32_t>(a, out, op);
85 break;
86 case uint64:
87 unary_op<uint64_t>(a, out, op);
88 break;
89 case int8:
90 unary_op<int8_t>(a, out, op);
91 break;
92 case int16:
93 unary_op<int16_t>(a, out, op);
94 break;
95 case int32:
96 unary_op<int32_t>(a, out, op);
97 break;
98 case int64:
99 unary_op<int64_t>(a, out, op);
100 break;
101 case float16:
102 unary_op<float16_t>(a, out, op);
103 break;
104 case float32:
105 unary_op<float>(a, out, op);
106 break;
107 case float64:
108 unary_op<double>(a, out, op);
109 break;
110 case bfloat16:
111 unary_op<bfloat16_t>(a, out, op);
112 break;
113 case complex64:
114 unary_op<complex64_t>(a, out, op);
115 break;
116 }
117}
118
119template <typename Op>
120void unary_fp(const array& a, array& out, Op op) {
121 switch (out.dtype()) {
122 case bfloat16:
123 unary_op<bfloat16_t>(a, out, op);
124 break;
125 case float16:
126 unary_op<float16_t>(a, out, op);
127 break;
128 case float32:
129 unary_op<float>(a, out, op);
130 break;
131 case float64:
132 unary_op<double>(a, out, op);
133 break;
134 case complex64:
135 unary_op<complex64_t>(a, out, op);
136 break;
137 default:
138 std::ostringstream err;
139 err << "[unary_fp] Does not support " << out.dtype();
140 throw std::runtime_error(err.str());
141 }
142}
143
144template <typename Op>
145void unary_int(const array& a, array& out, Op op) {
146 switch (out.dtype()) {
147 case uint8:
148 unary_op<uint8_t>(a, out, op);
149 break;
150 case uint16:
151 unary_op<uint16_t>(a, out, op);
152 break;
153 case uint32:
154 unary_op<uint32_t>(a, out, op);
155 break;
156 case uint64:
157 unary_op<uint64_t>(a, out, op);
158 break;
159 case int8:
160 unary_op<int8_t>(a, out, op);
161 break;
162 case int16:
163 unary_op<int16_t>(a, out, op);
164 break;
165 case int32:
166 unary_op<int32_t>(a, out, op);
167 break;
168 case int64:
169 unary_op<int64_t>(a, out, op);
170 break;
171 default:
172 std::ostringstream err;
173 err << "[unary_int] Does not support " << out.dtype();
174 throw std::runtime_error(err.str());
175 }
176}
177
178} // namespace mlx::core
Definition array.h:24
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:318
const Shape & shape() const
The shape of the array as a vector of integers.
Definition array.h:103
const Strides & strides() const
The strides of the array.
Definition array.h:117
size_t nbytes() const
The number of bytes in the array.
Definition array.h:93
size_t ndim() const
The number of dimensions of the array.
Definition array.h:98
size_t size() const
The number of elements in the array.
Definition array.h:88
T * data()
Definition array.h:354
void copy_shared_buffer(const array &other, const Strides &strides, Flags flags, size_t data_size, size_t offset=0)
Dtype dtype() const
Get the arrays data type.
Definition array.h:131
size_t itemsize() const
The size of the array's datatype in bytes.
Definition array.h:83
void set_data(allocator::Buffer buffer, Deleter d=allocator::free)
size_t data_size() const
The size (in elements) of the underlying buffer the array points to.
Definition array.h:332
Buffer malloc_or_wait(size_t size)
Simd< T, N > load(const T *x)
Definition base_simd.h:28
static constexpr int max_size
Definition base_simd.h:14
void store(T *dst, Simd< T, N > x)
Definition base_simd.h:33
Definition allocator.h:7
void unary_int(const array &a, array &out, Op op)
Definition unary.h:145
constexpr Dtype bool_
Definition dtype.h:68
constexpr Dtype uint64
Definition dtype.h:73
void unary_op(const T *a, U *out, Op op, size_t shape, size_t stride)
Definition unary.h:27
constexpr Dtype uint16
Definition dtype.h:71
constexpr Dtype float64
Definition dtype.h:82
void set_unary_output_data(const array &in, array &out)
Definition unary.h:13
constexpr Dtype bfloat16
Definition dtype.h:83
constexpr Dtype int32
Definition dtype.h:77
constexpr Dtype float32
Definition dtype.h:81
void unary(const array &a, array &out, Op op)
Definition unary.h:72
constexpr Dtype int16
Definition dtype.h:76
void unary_fp(const array &a, array &out, Op op)
Definition unary.h:120
constexpr Dtype int8
Definition dtype.h:75
constexpr Dtype int64
Definition dtype.h:78
constexpr Dtype uint8
Definition dtype.h:70
constexpr Dtype float16
Definition dtype.h:80
constexpr Dtype uint32
Definition dtype.h:72
bool is_donatable(const array &in, const array &out)
Definition utils.h:155
constexpr Dtype complex64
Definition dtype.h:84
Definition utils.h:73
int64_t loc
Definition utils.h:126
void step()
Definition utils.h:74
bool contiguous
Definition array.h:231