MLX
 
Loading...
Searching...
No Matches
unary.h
Go to the documentation of this file.
// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include "mlx/allocator.h"
6#include "mlx/array.h"
8#include "mlx/utils.h"
9
10namespace mlx::core {
11
12namespace {
13
14void set_unary_output_data(const array& in, array& out) {
15 if (is_donatable(in, out)) {
16 out.copy_shared_buffer(in);
17 } else {
18 auto size = in.data_size();
19 out.set_data(
20 allocator::malloc_or_wait(size * out.itemsize()),
21 size,
22 in.strides(),
23 in.flags());
24 }
25}
26
// Applies `op` to `shape` elements of `a`, read `stride` elements apart, and
// stores the results densely in out[0 .. shape).
//
// T is the input element type and U the output element type (defaults to T).
// A stride of 0 re-reads the same input element for every output; a shape of 0
// leaves `out` untouched.
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
  for (size_t idx = 0; idx < shape; ++idx) {
    out[idx] = op(a[idx * stride]);
  }
}
34
35template <typename T, typename U = T, typename Op>
36void unary_op(const array& a, array& out, Op op) {
37 const T* a_ptr = a.data<T>();
38 if (a.flags().contiguous) {
39 set_unary_output_data(a, out);
40 U* dst = out.data<U>();
41 for (size_t i = 0; i < a.data_size(); ++i) {
42 dst[i] = op(a_ptr[i]);
43 }
44 } else {
45 out.set_data(allocator::malloc_or_wait(out.nbytes()));
46 U* dst = out.data<U>();
47 size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
48 size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
49 if (a.ndim() <= 1) {
50 unary_op(a_ptr, dst, op, shape, stride);
51 return;
52 }
53 ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1);
54 for (size_t elem = 0; elem < a.size(); elem += shape) {
55 unary_op(a_ptr + it.loc, dst + elem, op, shape, stride);
56 it.step();
57 }
58 }
59}
60
61template <typename Op>
62void unary(const array& a, array& out, Op op) {
63 switch (out.dtype()) {
64 case bool_:
65 unary_op<bool>(a, out, op);
66 break;
67 case uint8:
68 unary_op<uint8_t>(a, out, op);
69 break;
70 case uint16:
71 unary_op<uint16_t>(a, out, op);
72 break;
73 case uint32:
74 unary_op<uint32_t>(a, out, op);
75 break;
76 case uint64:
77 unary_op<uint64_t>(a, out, op);
78 break;
79 case int8:
80 unary_op<int8_t>(a, out, op);
81 break;
82 case int16:
83 unary_op<int16_t>(a, out, op);
84 break;
85 case int32:
86 unary_op<int32_t>(a, out, op);
87 break;
88 case int64:
89 unary_op<int64_t>(a, out, op);
90 break;
91 case float16:
92 unary_op<float16_t>(a, out, op);
93 break;
94 case float32:
95 unary_op<float>(a, out, op);
96 break;
97 case bfloat16:
98 unary_op<bfloat16_t>(a, out, op);
99 break;
100 case complex64:
101 unary_op<complex64_t>(a, out, op);
102 break;
103 }
104}
105
106template <typename Op>
107void unary_fp(const array& a, array& out, Op op) {
108 switch (out.dtype()) {
109 case bfloat16:
110 unary_op<bfloat16_t>(a, out, op);
111 break;
112 case float16:
113 unary_op<float16_t>(a, out, op);
114 break;
115 case float32:
116 unary_op<float>(a, out, op);
117 break;
118 case complex64:
119 unary_op<complex64_t>(a, out, op);
120 break;
121 default:
122 std::ostringstream err;
123 err << "[unary_fp] Does not support " << out.dtype();
124 throw std::runtime_error(err.str());
125 }
126}
127
128} // namespace
129
130} // namespace mlx::core
Definition array.h:24
Buffer malloc_or_wait(size_t size)
const char * unary()
Definition allocator.h:7
constexpr Dtype bool_
Definition dtype.h:67
constexpr Dtype uint64
Definition dtype.h:72
constexpr Dtype uint16
Definition dtype.h:70
constexpr Dtype bfloat16
Definition dtype.h:81
constexpr Dtype int32
Definition dtype.h:76
constexpr Dtype float32
Definition dtype.h:80
constexpr Dtype int16
Definition dtype.h:75
constexpr Dtype int8
Definition dtype.h:74
constexpr Dtype int64
Definition dtype.h:77
constexpr Dtype uint8
Definition dtype.h:69
constexpr Dtype float16
Definition dtype.h:79
constexpr Dtype uint32
Definition dtype.h:71
bool is_donatable(const array &in, const array &out)
Definition utils.h:155
constexpr Dtype complex64
Definition dtype.h:82
void dst(const shape_t &shape, const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1)
Definition pocketfft.h:3416
Definition utils.h:73