MLX
Loading...
Searching...
No Matches
utils.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include "mlx/array.h"
7#include "mlx/primitives.h"
8
9namespace mlx::core {
10
11namespace {
12
13using metal::CommandEncoder;
14
15template <typename T>
16inline void set_vector_bytes(
17 CommandEncoder& enc,
18 const std::vector<T>& vec,
19 size_t nelems,
20 int idx) {
21 enc->setBytes(vec.data(), nelems * sizeof(T), idx);
22}
23
24template <typename T>
25inline void
26set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
27 return set_vector_bytes(enc, vec, vec.size(), idx);
28}
29
30std::string type_to_name(const array& a) {
31 std::string tname;
32 switch (a.dtype()) {
33 case bool_:
34 tname = "bool_";
35 break;
36 case uint8:
37 tname = "uint8";
38 break;
39 case uint16:
40 tname = "uint16";
41 break;
42 case uint32:
43 tname = "uint32";
44 break;
45 case uint64:
46 tname = "uint64";
47 break;
48 case int8:
49 tname = "int8";
50 break;
51 case int16:
52 tname = "int16";
53 break;
54 case int32:
55 tname = "int32";
56 break;
57 case int64:
58 tname = "int64";
59 break;
60 case float16:
61 tname = "float16";
62 break;
63 case float32:
64 tname = "float32";
65 break;
66 case bfloat16:
67 tname = "bfloat16";
68 break;
69 case complex64:
70 tname = "complex64";
71 break;
72 }
73 return tname;
74}
75
76MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
77 int pows[3] = {0, 0, 0};
78 int sum = 0;
79 while (true) {
80 int presum = sum;
81 // Check all the pows
82 if (dim0 >= (1 << (pows[0] + 1))) {
83 pows[0]++;
84 sum++;
85 }
86 if (sum == 10) {
87 break;
88 }
89 if (dim1 >= (1 << (pows[1] + 1))) {
90 pows[1]++;
91 sum++;
92 }
93 if (sum == 10) {
94 break;
95 }
96 if (dim2 >= (1 << (pows[2] + 1))) {
97 pows[2]++;
98 sum++;
99 }
100 if (sum == presum || sum == 10) {
101 break;
102 }
103 }
104 return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
105}
106
107inline NS::String* make_string(std::ostringstream& os) {
108 std::string string = os.str();
109 return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
110}
111
112inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
113#ifdef MLX_METAL_DEBUG
114 std::ostringstream label;
115 label << "Stream " << index;
116 queue->setLabel(make_string(label));
117#endif
118}
119
120inline void debug_set_primitive_buffer_label(
121 MTL::CommandBuffer* command_buffer,
122 Primitive& primitive) {
123#ifdef MLX_METAL_DEBUG
124 std::ostringstream label;
125 if (auto cbuf_label = command_buffer->label(); cbuf_label) {
126 label << cbuf_label->utf8String();
127 }
128 primitive.print(label);
129 command_buffer->setLabel(make_string(label));
130#endif
131}
132
133std::string get_primitive_string(Primitive* primitive) {
134 std::ostringstream op_t;
135 primitive->print(op_t);
136 return op_t.str();
137}
138
139} // namespace
140
141} // namespace mlx::core
array sum(const array &a, bool keepdims, StreamOrDevice s={})
Sums the elements of an array.
Definition allocator.h:7
constexpr Dtype bool_
Definition dtype.h:60
constexpr Dtype uint64
Definition dtype.h:65
constexpr Dtype uint16
Definition dtype.h:63
constexpr Dtype bfloat16
Definition dtype.h:74
constexpr Dtype int32
Definition dtype.h:69
constexpr Dtype float32
Definition dtype.h:73
constexpr Dtype int16
Definition dtype.h:68
constexpr Dtype int8
Definition dtype.h:67
constexpr Dtype int64
Definition dtype.h:70
constexpr Dtype uint8
Definition dtype.h:62
constexpr Dtype float16
Definition dtype.h:72
constexpr Dtype uint32
Definition dtype.h:64
constexpr Dtype complex64
Definition dtype.h:75