MLX
Loading...
Searching...
No Matches
utils.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include "mlx/array.h"
7#include "mlx/primitives.h"
8
9namespace mlx::core {
10
11std::string type_to_name(const Dtype& t);
12std::string type_to_name(const array& a);
13
14// Compute the thread block dimensions which fit the given
15// input dimensions.
16// - The thread block dimensions will be powers of two
17// - The thread block size will be less than 2^pow2
18MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
19
20// Computes a 2D grid where each element is < UINT_MAX
21// Assumes:
22// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
23// - shape and strides correspond to a contiguous (no holes) but
24// possibly broadcasted array
26 const std::vector<int>& shape,
27 const std::vector<size_t>& strides);
28
29// Same as above, but with an implicit division by `divisor`.
30// Equivalent to factorizing
31// Prod(shape[i] for all i where strides[i] > 0) / divisor.
33 const std::vector<int>& shape,
34 const std::vector<size_t>& strides,
35 size_t divisor);
36
37inline NS::String* make_string(std::ostringstream& os) {
38 std::string string = os.str();
39 return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
40}
41
42inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
43#ifdef MLX_METAL_DEBUG
44 std::ostringstream label;
45 label << "Stream " << index;
46 queue->setLabel(make_string(label));
47#endif
48}
49
51 MTL::CommandBuffer* command_buffer,
52 Primitive& primitive) {
53#ifdef MLX_METAL_DEBUG
54 std::ostringstream label;
55 if (auto cbuf_label = command_buffer->label(); cbuf_label) {
56 label << cbuf_label->utf8String();
57 }
58 primitive.print(label);
59 command_buffer->setLabel(make_string(label));
60#endif
61}
62
63std::string get_primitive_string(Primitive* primitive);
64
// Appends `first` to `acc` using std::string::operator+=.
//
// NOTE: the argument is handed to operator+= unchanged, so an integral argument
// (e.g. an int) is converted to a single `char` rather than formatted as
// decimal digits.
template <typename T>
void concatenate(std::string& acc, const T& first) {
  acc += first;
}

// Appends every argument, left to right, to `acc` using
// std::string::operator+= (same integral-to-char caveat as the single-argument
// overload).
//
// Arguments are taken by const reference, so std::string arguments are never
// copied. The by-value recursive form copied each one once per recursion level.
template <typename T, typename... Args>
void concatenate(std::string& acc, const T& first, const Args&... args) {
  acc += first;
  // C++17 comma fold: append the remaining arguments in order without
  // recursive instantiation. The (void) cast guards against overloaded commas.
  ((void)(acc += args), ...);
}
75
76} // namespace mlx::core
Definition primitives.h:48
virtual void print(std::ostream &os)=0
Print the primitive.
Definition array.h:23
Definition allocator.h:7
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2=10)
void debug_set_primitive_buffer_label(MTL::CommandBuffer *command_buffer, Primitive &primitive)
Definition utils.h:50
void concatenate(std::string &acc, T first)
Definition utils.h:66
void debug_set_stream_queue_label(MTL::CommandQueue *queue, int index)
Definition utils.h:42
MTL::Size get_2d_grid_dims(const std::vector< int > &shape, const std::vector< size_t > &strides)
std::string get_primitive_string(Primitive *primitive)
NS::String * make_string(std::ostringstream &os)
Definition utils.h:37
std::string type_to_name(const Dtype &t)
Definition dtype.h:13