MLX
Loading...
Searching...
No Matches
utils.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
#pragma once

#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
#include "mlx/primitives.h"
8
9namespace mlx::core {
10
11using metal::CommandEncoder;
12
13template <typename T>
14inline void set_vector_bytes(
15 CommandEncoder& enc,
16 const std::vector<T>& vec,
17 size_t nelems,
18 int idx) {
19 enc->setBytes(vec.data(), nelems * sizeof(T), idx);
20}
21
22template <typename T>
23inline void
24set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
25 return set_vector_bytes(enc, vec, vec.size(), idx);
26}
27
28std::string type_to_name(const array& a);
29
30// Compute the thread block dimensions which fit the given
31// input dimensions.
32// - The thread block dimensions will be powers of two
33// - The thread block size will be less than 1024
34MTL::Size get_block_dims(int dim0, int dim1, int dim2);
35
36// Computes a 2D grid where each element is < UINT_MAX
37// Assumes:
38// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
39// - shape and strides correspond to a contiguous (no holes) but
40// possibly broadcasted array
42 const std::vector<int>& shape,
43 const std::vector<size_t>& strides);
44
45inline NS::String* make_string(std::ostringstream& os) {
46 std::string string = os.str();
47 return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
48}
49
50inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
51#ifdef MLX_METAL_DEBUG
52 std::ostringstream label;
53 label << "Stream " << index;
54 queue->setLabel(make_string(label));
55#endif
56}
57
59 MTL::CommandBuffer* command_buffer,
60 Primitive& primitive) {
61#ifdef MLX_METAL_DEBUG
62 std::ostringstream label;
63 if (auto cbuf_label = command_buffer->label(); cbuf_label) {
64 label << cbuf_label->utf8String();
65 }
66 primitive.print(label);
67 command_buffer->setLabel(make_string(label));
68#endif
69}
70
71std::string get_primitive_string(Primitive* primitive);
72
73} // namespace mlx::core
Definition primitives.h:48
virtual void print(std::ostream &os)=0
Print the primitive.
Definition array.h:20
Definition allocator.h:7
void debug_set_primitive_buffer_label(MTL::CommandBuffer *command_buffer, Primitive &primitive)
Definition utils.h:58
void set_vector_bytes(CommandEncoder &enc, const std::vector< T > &vec, size_t nelems, int idx)
Definition utils.h:14
void debug_set_stream_queue_label(MTL::CommandQueue *queue, int index)
Definition utils.h:50
MTL::Size get_block_dims(int dim0, int dim1, int dim2)
MTL::Size get_2d_grid_dims(const std::vector< int > &shape, const std::vector< size_t > &strides)
std::string get_primitive_string(Primitive *primitive)
NS::String * make_string(std::ostringstream &os)
Definition utils.h:45
std::string type_to_name(const array &a)
Definition device.h:40