MLX
 
Loading...
Searching...
No Matches
utils.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include "mlx/array.h"
7#include "mlx/primitives.h"
8
9namespace mlx::core {
10
11std::string type_to_name(const Dtype& t);
12std::string type_to_name(const array& a);
13
14// Compute the thread block dimensions which fit the given
15// input dimensions.
16// - The thread block dimensions will be powers of two
17// - The thread block size will be less than 2^pow2
18MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
19
20// Computes a 2D grid where each element is < UINT_MAX
21// Assumes:
22// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
23// - shape and strides correspond to a contiguous (no holes) but
24// possibly broadcasted array
25MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides);
26
27// Same as above but we do an implicit division with divisor.
28// Basically, equivalent to factorizing
29// Prod(s \forall s in shape if strides[s] > 0) / divisor.
30MTL::Size
31get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor);
32
33inline NS::String* make_string(std::ostringstream& os) {
34 std::string string = os.str();
35 return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
36}
37
38inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
39#ifdef MLX_METAL_DEBUG
40 std::ostringstream label;
41 label << "Stream " << index;
42 queue->setLabel(make_string(label));
43#endif
44}
45
47 MTL::CommandBuffer* command_buffer,
48 Primitive& primitive) {
49#ifdef MLX_METAL_DEBUG
50 std::ostringstream label;
51 if (auto cbuf_label = command_buffer->label(); cbuf_label) {
52 label << cbuf_label->utf8String();
53 }
54 primitive.print(label);
55 command_buffer->setLabel(make_string(label));
56#endif
57}
58
59std::string get_primitive_string(Primitive* primitive);
60
// Single-value overload of concatenate: append `first` to `acc` using
// std::string::operator+=.
// Taking the value by const reference avoids copying string arguments.
template <typename T>
void concatenate(std::string& acc, const T& first) {
  acc += first;
}
65
66template <typename T, typename... Args>
67void concatenate(std::string& acc, T first, Args... args) {
68 acc += first;
69 concatenate(acc, args...);
70}
71
76inline array unsafe_weak_copy(const array& x) {
77 return array(
78 x.buffer(),
79 x.shape(),
80 x.dtype(),
81 x.strides(),
82 x.data_size(),
83 x.flags(),
84 [](auto b) {});
85}
86
87} // namespace mlx::core
Definition primitives.h:48
virtual void print(std::ostream &os)=0
Print the primitive.
Definition array.h:24
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:318
const Shape & shape() const
The shape of the array as a vector of integers.
Definition array.h:103
const Strides & strides() const
The strides of the array.
Definition array.h:117
allocator::Buffer & buffer()
Definition array.h:336
Dtype dtype() const
Get the array's data type.
Definition array.h:131
size_t data_size() const
The size (in elements) of the underlying buffer the array points to.
Definition array.h:332
Definition allocator.h:7
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2=10)
array unsafe_weak_copy(const array &x)
Get a new array that refers to the same data as the input but holds a non-owning pointer to it.
Definition utils.h:76
void debug_set_primitive_buffer_label(MTL::CommandBuffer *command_buffer, Primitive &primitive)
Definition utils.h:46
std::vector< ShapeElem > Shape
Definition array.h:21
void concatenate(std::string &acc, T first)
Definition utils.h:62
void debug_set_stream_queue_label(MTL::CommandQueue *queue, int index)
Definition utils.h:38
std::vector< int64_t > Strides
Definition array.h:22
MTL::Size get_2d_grid_dims(const Shape &shape, const Strides &strides)
std::vector< array > Args
Definition export.h:11
std::string get_primitive_string(Primitive *primitive)
NS::String * make_string(std::ostringstream &os)
Definition utils.h:33
std::string type_to_name(const Dtype &t)
Definition dtype.h:13