MLX
Loading...
Searching...
No Matches
device.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <Metal/Metal.hpp>
6#include <functional>
7#include <mutex>
8#include <string>
9#include <unordered_map>
10#include <unordered_set>
11
12#include "mlx/array.h"
13#include "mlx/device.h"
14
15namespace mlx::core::metal {
16
// Function-constant list accepted by the `func_consts` parameters of
// Device::get_function / Device::get_kernel. Each tuple holds a
// `const void*`, an MTL::DataType and an NS::UInteger; presumably
// (pointer to the constant's value, its Metal data type, its
// function-constant index) -- TODO confirm against the implementation.
17using MTLFCList =
18 std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
19
21 CommandEncoder(MTL::CommandBuffer* cbuf);
24
27 enc.concurrent = true;
28 }
30 enc.concurrent = false;
31 enc.outputs.insert(
32 enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
33 enc.concurrent_outputs.clear();
34 }
35
36 private:
37 CommandEncoder& enc;
38 };
39
// Arrow operator: returns the wrapped MTL::ComputeCommandEncoder (`enc`),
// so Metal compute-encoder methods can be invoked directly through a
// CommandEncoder (e.g. `encoder->someMetalMethod(...)`).
40 MTL::ComputeCommandEncoder* operator->() {
41 return enc;
42 }
43
// Bind array `a` as a kernel input (read-only, const&) or output (mutable)
// at index `idx` -- presumably the Metal buffer binding index -- with an
// optional `offset` (default 0; units not visible here, presumably bytes --
// confirm in the implementation).
44 void set_input_array(const array& a, int idx, int64_t offset = 0);
45 void set_output_array(array& a, int idx, int64_t offset = 0);
// Dispatch helpers taking grid and threadgroup sizes; presumably thin
// wrappers over the Metal encoder's dispatchThreadgroups / dispatchThreads
// (implementation not shown here -- confirm).
46 void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
47 void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
48
52
54
55 private:
// Implementation not shown; from its name it presumably splits work across
// command buffers/encoders after enough dispatches -- confirm in device.cpp.
56 void maybe_split();
57
// Dispatch counter, initialized to 0 (presumably read by maybe_split()).
58 int num_dispatches{0};
// Metal command buffer (presumably the one passed to the constructor) and
// the compute encoder that operator-> returns.
59 MTL::CommandBuffer* cbuf;
60 MTL::ComputeCommandEncoder* enc;
// True while a concurrent section is active. The ConcurrentContext code
// above sets it, and on reset merges `concurrent_outputs` into `outputs`
// and clears `concurrent_outputs`. NOTE(review): which set
// set_output_array writes to is not visible here -- presumably it depends
// on `concurrent`; confirm.
61 bool concurrent{false};
62 std::unordered_set<MTL::Resource*> outputs;
63 std::unordered_set<MTL::Resource*> concurrent_outputs;
64};
65
// Wrapper around a Metal device: holds the MTL::Device handle, per-index
// command queues / command buffers / CommandEncoders, and name-keyed caches
// of compiled libraries and compute pipeline states. Non-copyable.
// NOTE(review): this listing was extracted from a rendered documentation
// page and several source lines are missing. The page's cross-reference
// index also lists `increment_command_buffer_ops(int index)`,
// `get_command_encoder(int index)` and
// `register_library(const std::string &lib_name, const std::string &lib_path)`.
// The dangling lib_name/lib_path parameter lines below presumably belong to
// that last declaration -- confirm against the real header.
66class Device {
67 public:
// Copying is disabled; a Device is not meant to be duplicated.
69 Device(const Device&) = delete;
70 Device& operator=(const Device&) = delete;
72
// Returns the raw MTL::Device pointer stored in device_.
73 MTL::Device* mtl_device() {
74 return device_;
75 };
76
// Per-index command queue / command buffer management. `index` is the key
// of queue_map_, buffer_map_ and encoder_map_ (all int32_t-keyed);
// presumably it identifies a stream -- confirm in device.cpp.
77 void new_queue(int index);
78 MTL::CommandBuffer* get_command_buffer(int index);
// Presumably returns the int half of buffer_map_'s (int, CommandBuffer*)
// pair, i.e. an op count for the current buffer -- confirm.
79 int get_command_buffer_ops(int index);
81 void commit_command_buffer(int index);
83 void end_encoding(int index);
84
86 const std::string& lib_name,
87 const std::string& lib_path);
88
// Single-argument overload; how the library location is resolved is not
// visible here.
89 void register_library(const std::string& lib_name);
90
// Library lookup by name.
91 MTL::Library* get_library(const std::string& name);
92
// Library built from a source string, registered under `name`. `cache`
// (default true) presumably controls storage in library_map_ -- confirm.
93 MTL::Library* get_library(
94 const std::string& name,
95 const std::string& source_string,
96 bool cache = true);
97
// Library built from a stitched-library descriptor; `cache` as above.
98 MTL::Library* get_library(
99 const std::string& name,
100 const MTL::StitchedLibraryDescriptor* desc,
101 bool cache = true);
102
// Function lookup in an explicit library, optionally specialized with
// function constants (`func_consts`) under `specialized_name`.
103 MTL::Function* get_function(
104 const std::string& base_name,
105 MTL::Library* mtl_lib,
106 const std::string& specialized_name = "",
107 const MTLFCList& func_consts = {});
108
// Same as above, but the library is given by name (default "mlx").
109 MTL::Function* get_function(
110 const std::string& base_name,
111 const std::string& lib_name = "mlx",
112 const std::string& specialized_name = "",
113 const MTLFCList& func_consts = {});
114
// Compute pipeline lookup/creation from an explicit library. `hash_name`
// presumably keys kernel_map_ for specialized variants; `linked_functions`
// are optional extra functions to link -- confirm in device.cpp.
115 MTL::ComputePipelineState* get_kernel(
116 const std::string& base_name,
117 MTL::Library* mtl_lib,
118 const std::string& hash_name = "",
119 const MTLFCList& func_consts = {},
120 const std::vector<MTL::Function*>& linked_functions = {});
121
// Same as above, but the library is given by name (default "mlx").
122 MTL::ComputePipelineState* get_kernel(
123 const std::string& base_name,
124 const std::string& lib_name = "mlx",
125 const std::string& hash_name = "",
126 const MTLFCList& func_consts = {},
127 const std::vector<MTL::Function*>& linked_functions = {});
128
// Builds an MTL::ArgumentEncoder from the given argument descriptors;
// a const member, so it does not modify this Device's state.
129 MTL::ArgumentEncoder* argument_encoder(
130 const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
131
132 private:
// Internal helpers behind the public lookups (implementations not shown).
// get_library_cache_ presumably consults library_map_ -- confirm.
133 MTL::Library* get_library_cache_(const std::string& name);
134
135 MTL::Library* get_library_(const std::string& source_string);
136 MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
137
138 MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
139
140 MTL::Function* get_function_(
141 const std::string& name,
142 const std::string& specialized_name,
143 const MTLFCList& func_consts,
144 MTL::Library* mtl_lib);
145
146 MTL::LinkedFunctions* get_linked_functions_(
147 const std::vector<MTL::Function*>& funcs);
148
149 MTL::ComputePipelineState* get_kernel_(
150 const std::string& name,
151 const MTL::Function* mtl_function);
152
153 MTL::ComputePipelineState* get_kernel_(
154 const std::string& name,
155 const MTL::Function* mtl_function,
156 const MTL::LinkedFunctions* linked_functions);
157
// Underlying Metal device (raw pointer; release/lifetime handling is not
// visible in this header).
158 MTL::Device* device_;
// Per-index state. buffer_map_ pairs an int (presumably the op count) with
// the current command buffer; encoder_map_ owns its CommandEncoders via
// unique_ptr.
159 std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
160 std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
161 std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
// Name-keyed caches of compiled pipeline states and libraries.
162 std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
163 std::unordered_map<std::string, MTL::Library*> library_map_;
// Mutex; presumably guards the caches above -- which members it protects is
// not visible in this header.
164 std::mutex mtx_;
165};
166
168
169} // namespace mlx::core::metal
Definition array.h:20
Definition device.h:66
int get_command_buffer_ops(int index)
MTL::Device * mtl_device()
Definition device.h:73
void register_library(const std::string &lib_name, const std::string &lib_path)
MTL::CommandBuffer * get_command_buffer(int index)
void end_encoding(int index)
MTL::ComputePipelineState * get_kernel(const std::string &base_name, MTL::Library *mtl_lib, const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
MTL::ArgumentEncoder * argument_encoder(const std::vector< MTL::ArgumentDescriptor * > &arg_descs) const
void increment_command_buffer_ops(int index)
void new_queue(int index)
MTL::Library * get_library(const std::string &name)
MTL::Library * get_library(const std::string &name, const MTL::StitchedLibraryDescriptor *desc, bool cache=true)
void commit_command_buffer(int index)
MTL::Library * get_library(const std::string &name, const std::string &source_string, bool cache=true)
void register_library(const std::string &lib_name)
MTL::Function * get_function(const std::string &base_name, MTL::Library *mtl_lib, const std::string &specialized_name="", const MTLFCList &func_consts={})
Device(const Device &)=delete
MTL::Function * get_function(const std::string &base_name, const std::string &lib_name="mlx", const std::string &specialized_name="", const MTLFCList &func_consts={})
Device & operator=(const Device &)=delete
MTL::ComputePipelineState * get_kernel(const std::string &base_name, const std::string &lib_name="mlx", const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
CommandEncoder & get_command_encoder(int index)
Definition allocator.h:12
std::vector< std::tuple< const void *, MTL::DataType, NS::UInteger > > MTLFCList
Definition device.h:17
Device & device(mlx::core::Device)
Definition device.h:7
ConcurrentContext(CommandEncoder &enc)
Definition device.h:26
Definition device.h:20
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims)
CommandEncoder(MTL::CommandBuffer *cbuf)
CommandEncoder & operator=(const CommandEncoder &)=delete
ConcurrentContext start_concurrent()
Definition device.h:49
void set_output_array(array &a, int idx, int64_t offset=0)
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims)
MTL::ComputeCommandEncoder * operator->()
Definition device.h:40
void set_input_array(const array &a, int idx, int64_t offset=0)
CommandEncoder(const CommandEncoder &)=delete