MLX
Loading...
Searching...
No Matches
unary.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include "mlx/allocator.h"
6#include "mlx/array.h"
8#include "mlx/utils.h"
9
10namespace mlx::core {
11
12namespace {
13
14void set_unary_output_data(const array& in, array& out) {
15 if (is_donatable(in, out)) {
16 out.copy_shared_buffer(in);
17 } else {
18 auto size = in.data_size();
19 out.set_data(
20 allocator::malloc_or_wait(size * out.itemsize()),
21 size,
22 in.strides(),
23 in.flags());
24 }
25}
26
27template <typename T, typename Op>
28void unary_op(const array& a, array& out, Op op) {
29 const T* a_ptr = a.data<T>();
30 if (a.flags().contiguous) {
31 set_unary_output_data(a, out);
32 T* dst = out.data<T>();
33 for (size_t i = 0; i < a.data_size(); ++i) {
34 dst[i] = op(a_ptr[i]);
35 }
36 } else {
37 out.set_data(allocator::malloc_or_wait(out.nbytes()));
38 T* dst = out.data<T>();
39 for (size_t i = 0; i < out.size(); ++i) {
40 // TODO this is super inefficient, need to fix.
41 int a_idx = elem_to_loc(i, a.shape(), a.strides());
42 dst[i] = op(a_ptr[a_idx]);
43 }
44 }
45}
46
47template <typename Op>
48void unary(const array& a, array& out, Op op) {
49 switch (out.dtype()) {
50 case bool_:
51 unary_op<bool>(a, out, op);
52 break;
53 case uint8:
54 unary_op<uint8_t>(a, out, op);
55 break;
56 case uint16:
57 unary_op<uint16_t>(a, out, op);
58 break;
59 case uint32:
60 unary_op<uint32_t>(a, out, op);
61 break;
62 case uint64:
63 unary_op<uint64_t>(a, out, op);
64 break;
65 case int8:
66 unary_op<int8_t>(a, out, op);
67 break;
68 case int16:
69 unary_op<int16_t>(a, out, op);
70 break;
71 case int32:
72 unary_op<int32_t>(a, out, op);
73 break;
74 case int64:
75 unary_op<int64_t>(a, out, op);
76 break;
77 case float16:
78 unary_op<float16_t>(a, out, op);
79 break;
80 case float32:
81 unary_op<float>(a, out, op);
82 break;
83 case bfloat16:
84 unary_op<bfloat16_t>(a, out, op);
85 break;
86 case complex64:
87 unary_op<complex64_t>(a, out, op);
88 break;
89 }
90}
91
92template <typename Op>
93void unary_fp(const array& a, array& out, Op op) {
94 switch (out.dtype()) {
95 case bfloat16:
96 unary_op<bfloat16_t>(a, out, op);
97 break;
98 case float16:
99 unary_op<float16_t>(a, out, op);
100 break;
101 case float32:
102 unary_op<float>(a, out, op);
103 break;
104 case complex64:
105 unary_op<complex64_t>(a, out, op);
106 break;
107 default:
108 std::ostringstream err;
109 err << "[unary_fp] Does not support " << out.dtype();
110 throw std::runtime_error(err.str());
111 }
112}
113
114} // namespace
115
116} // namespace mlx::core
Op op
Definition binary.h:141
Buffer malloc_or_wait(size_t size)
const char * unary()
Definition allocator.h:7
constexpr Dtype bool_
Definition dtype.h:58
constexpr Dtype uint64
Definition dtype.h:63
constexpr Dtype uint16
Definition dtype.h:61
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
constexpr Dtype bfloat16
Definition dtype.h:72
constexpr Dtype int32
Definition dtype.h:67
constexpr Dtype float32
Definition dtype.h:71
constexpr Dtype int16
Definition dtype.h:66
constexpr Dtype int8
Definition dtype.h:65
constexpr Dtype int64
Definition dtype.h:68
constexpr Dtype uint8
Definition dtype.h:60
constexpr Dtype float16
Definition dtype.h:70
constexpr Dtype uint32
Definition dtype.h:62
bool is_donatable(const array &in, const array &out)
Definition utils.h:158
constexpr Dtype complex64
Definition dtype.h:73