MLX
Loading...
Searching...
No Matches
utils.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include "mlx/array.h"
7#include "mlx/primitives.h"
8
9namespace mlx::core {
10
11using metal::CommandEncoder;
12
13template <typename T>
14inline void set_vector_bytes(
15 CommandEncoder& enc,
16 const std::vector<T>& vec,
17 size_t nelems,
18 int idx) {
19 enc->setBytes(vec.data(), nelems * sizeof(T), idx);
20}
21
22template <typename T>
23inline void
24set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
25 return set_vector_bytes(enc, vec, vec.size(), idx);
26}
27
28std::string type_to_name(const array& a);
29
30// Compute the thread block dimensions which fit the given
31// input dimensions.
32// - The thread block dimensions will be powers of two
33// - The thread block size will be less than 2^pow2
34MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
35
36// Computes a 2D grid where each element is < UINT_MAX
37// Assumes:
38// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
39// - shape and strides correspond to a contiguous (no holes) but
40// possibly broadcasted array
42 const std::vector<int>& shape,
43 const std::vector<size_t>& strides);
44
45// Same as above but we do an implicit division with divisor.
46// Basically, equivalent to factorizing
47// Prod(s \forall s in shape if strides[s] > 0) / divisor.
49 const std::vector<int>& shape,
50 const std::vector<size_t>& strides,
51 size_t divisor);
52
53inline NS::String* make_string(std::ostringstream& os) {
54 std::string string = os.str();
55 return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
56}
57
58inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
59#ifdef MLX_METAL_DEBUG
60 std::ostringstream label;
61 label << "Stream " << index;
62 queue->setLabel(make_string(label));
63#endif
64}
65
67 MTL::CommandBuffer* command_buffer,
68 Primitive& primitive) {
69#ifdef MLX_METAL_DEBUG
70 std::ostringstream label;
71 if (auto cbuf_label = command_buffer->label(); cbuf_label) {
72 label << cbuf_label->utf8String();
73 }
74 primitive.print(label);
75 command_buffer->setLabel(make_string(label));
76#endif
77}
78
79std::string get_primitive_string(Primitive* primitive);
80
81} // namespace mlx::core
Definition primitives.h:48
virtual void print(std::ostream &os)=0
Print the primitive.
Definition array.h:20
Definition allocator.h:7
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2=10)
void debug_set_primitive_buffer_label(MTL::CommandBuffer *command_buffer, Primitive &primitive)
Definition utils.h:66
void set_vector_bytes(CommandEncoder &enc, const std::vector< T > &vec, size_t nelems, int idx)
Definition utils.h:14
void debug_set_stream_queue_label(MTL::CommandQueue *queue, int index)
Definition utils.h:58
MTL::Size get_2d_grid_dims(const std::vector< int > &shape, const std::vector< size_t > &strides)
std::string get_primitive_string(Primitive *primitive)
NS::String * make_string(std::ostringstream &os)
Definition utils.h:53
std::string type_to_name(const array &a)
Definition device.h:41