MLX
Loading...
Searching...
No Matches
device.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <Metal/Metal.hpp>
6#include <dlfcn.h>
7#include <filesystem>
8#include <functional>
9#include <mutex>
10#include <string>
11#include <unordered_map>
12#include <unordered_set>
13
14#include "mlx/array.h"
15#include "mlx/device.h"
16
17namespace fs = std::filesystem;
18
19namespace mlx::core::metal {
20
// Returns the path of "<lib_name>.metallib" inside the directory of the binary
// image (executable or shared library) that dladdr() reports as containing
// this function. Returns an empty string when dladdr() cannot resolve the
// image or does not report its file name.
//
// Note, this function must be left inline in a header so that it is not
// dynamically linked: the intent is that the address handed to dladdr()
// belongs to the binary that includes this header, not to a separately
// built library.
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
  Dl_info info{};

  // dladdr() takes an object pointer. Converting a function pointer to an
  // object pointer is conditionally-supported in C++, but it is supported on
  // the POSIX platforms that provide dladdr(). Use a named cast instead of a
  // C-style one.
  const void* self = reinterpret_cast<const void*>(&get_colocated_mtllib_path);

  // dladdr() returns 0 on failure. Also guard against a null file name:
  // building a path from a null char* would be undefined behavior.
  if (dladdr(self, &info) == 0 || info.dli_fname == nullptr) {
    return {};
  }

  const std::filesystem::path mtllib =
      std::filesystem::path(info.dli_fname).remove_filename() /
      (lib_name + ".metallib");
  return mtllib.string();
}
36
37using MTLFCList =
38 std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
39
41 CommandEncoder(MTL::CommandBuffer* cbuf);
44
47 enc.concurrent = true;
48 }
50 enc.concurrent = false;
51 enc.outputs.insert(
52 enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
53 enc.concurrent_outputs.clear();
54 }
55
56 private:
57 CommandEncoder& enc;
58 };
59
60 MTL::ComputeCommandEncoder* operator->() {
61 return enc;
62 }
63
64 void set_input_array(const array& a, int idx, int64_t offset = 0);
65 void set_output_array(array& a, int idx, int64_t offset = 0);
66 void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
67 void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
68
72
74
75 private:
76 void maybe_split();
77
78 int num_dispatches{0};
79 MTL::CommandBuffer* cbuf;
80 MTL::ComputeCommandEncoder* enc;
81 bool concurrent{false};
82 std::unordered_set<MTL::Resource*> outputs;
83 std::unordered_set<MTL::Resource*> concurrent_outputs;
84};
85
86class Device {
87 public:
89 Device(const Device&) = delete;
90 Device& operator=(const Device&) = delete;
92
93 MTL::Device* mtl_device() {
94 return device_;
95 };
96
97 void new_queue(int index);
98 MTL::CommandBuffer* get_command_buffer(int index);
99 int get_command_buffer_ops(int index);
101 void commit_command_buffer(int index);
103 void end_encoding(int index);
104
106 const std::string& lib_name,
107 const std::string& lib_path);
108
109 // Note, this should remain in the header so that it is not dynamically
110 // linked
111 void register_library(const std::string& lib_name) {
112 if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
113 register_library(lib_name, get_colocated_mtllib_path(lib_name));
114 }
115 }
116
117 MTL::Library* get_library(const std::string& name);
118
119 MTL::Library* get_library(
120 const std::string& name,
121 const std::string& source_string,
122 bool cache = true);
123
124 MTL::Library* get_library(
125 const std::string& name,
126 const MTL::StitchedLibraryDescriptor* desc,
127 bool cache = true);
128
129 MTL::Function* get_function(
130 const std::string& base_name,
131 MTL::Library* mtl_lib,
132 const std::string& specialized_name = "",
133 const MTLFCList& func_consts = {});
134
135 MTL::Function* get_function(
136 const std::string& base_name,
137 const std::string& lib_name = "mlx",
138 const std::string& specialized_name = "",
139 const MTLFCList& func_consts = {});
140
141 MTL::ComputePipelineState* get_kernel(
142 const std::string& base_name,
143 MTL::Library* mtl_lib,
144 const std::string& hash_name = "",
145 const MTLFCList& func_consts = {},
146 const std::vector<MTL::Function*>& linked_functions = {});
147
148 MTL::ComputePipelineState* get_kernel(
149 const std::string& base_name,
150 const std::string& lib_name = "mlx",
151 const std::string& hash_name = "",
152 const MTLFCList& func_consts = {},
153 const std::vector<MTL::Function*>& linked_functions = {});
154
155 MTL::ArgumentEncoder* argument_encoder(
156 const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
157
158 private:
159 MTL::Library* get_library_cache_(const std::string& name);
160
161 MTL::Library* get_library_(const std::string& source_string);
162 MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
163
164 MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
165
166 MTL::Function* get_function_(
167 const std::string& name,
168 const std::string& specialized_name,
169 const MTLFCList& func_consts,
170 MTL::Library* mtl_lib);
171
172 MTL::LinkedFunctions* get_linked_functions_(
173 const std::vector<MTL::Function*>& funcs);
174
175 MTL::ComputePipelineState* get_kernel_(
176 const std::string& name,
177 const MTL::Function* mtl_function);
178
179 MTL::ComputePipelineState* get_kernel_(
180 const std::string& name,
181 const MTL::Function* mtl_function,
182 const MTL::LinkedFunctions* linked_functions);
183
184 MTL::Device* device_;
185 std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
186 std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
187 std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
188 std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
189 std::unordered_map<std::string, MTL::Library*> library_map_;
190 std::mutex mtx_;
191};
192
194
195} // namespace mlx::core::metal
Definition array.h:20
Definition device.h:86
int get_command_buffer_ops(int index)
MTL::Device * mtl_device()
Definition device.h:93
void register_library(const std::string &lib_name, const std::string &lib_path)
MTL::CommandBuffer * get_command_buffer(int index)
void end_encoding(int index)
MTL::ComputePipelineState * get_kernel(const std::string &base_name, MTL::Library *mtl_lib, const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
MTL::ArgumentEncoder * argument_encoder(const std::vector< MTL::ArgumentDescriptor * > &arg_descs) const
void increment_command_buffer_ops(int index)
void new_queue(int index)
MTL::Library * get_library(const std::string &name)
MTL::Library * get_library(const std::string &name, const MTL::StitchedLibraryDescriptor *desc, bool cache=true)
void commit_command_buffer(int index)
MTL::Library * get_library(const std::string &name, const std::string &source_string, bool cache=true)
void register_library(const std::string &lib_name)
Definition device.h:111
MTL::Function * get_function(const std::string &base_name, MTL::Library *mtl_lib, const std::string &specialized_name="", const MTLFCList &func_consts={})
Device(const Device &)=delete
MTL::Function * get_function(const std::string &base_name, const std::string &lib_name="mlx", const std::string &specialized_name="", const MTLFCList &func_consts={})
Device & operator=(const Device &)=delete
MTL::ComputePipelineState * get_kernel(const std::string &base_name, const std::string &lib_name="mlx", const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
CommandEncoder & get_command_encoder(int index)
Definition allocator.h:12
std::string get_colocated_mtllib_path(const std::string &lib_name)
Definition device.h:23
std::vector< std::tuple< const void *, MTL::DataType, NS::UInteger > > MTLFCList
Definition device.h:37
Device & device(mlx::core::Device)
Definition device.h:7
ConcurrentContext(CommandEncoder &enc)
Definition device.h:46
Definition device.h:40
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims)
CommandEncoder(MTL::CommandBuffer *cbuf)
CommandEncoder & operator=(const CommandEncoder &)=delete
ConcurrentContext start_concurrent()
Definition device.h:69
void set_output_array(array &a, int idx, int64_t offset=0)
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims)
MTL::ComputeCommandEncoder * operator->()
Definition device.h:60
void set_input_array(const array &a, int idx, int64_t offset=0)
CommandEncoder(const CommandEncoder &)=delete