diff --git a/docs/build/html/_sources/dev/custom_metal_kernels.rst b/docs/build/html/_sources/dev/custom_metal_kernels.rst
index c4c1b0aff..3e92f2814 100644
--- a/docs/build/html/_sources/dev/custom_metal_kernels.rst
+++ b/docs/build/html/_sources/dev/custom_metal_kernels.rst
@@ -1,3 +1,5 @@
+.. _custom_metal_kernels:
+
 Custom Metal Kernels
 ====================
 
@@ -76,6 +78,10 @@ Putting this all together, the generated function signature for ``myexp`` is as
 template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float) custom_kernel_myexp_float;
 
+Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads `_ function.
+This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
+For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
+
 Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the
 generated code for debugging purposes.
 
 Using Shape/Strides
diff --git a/docs/build/html/_sources/usage/function_transforms.rst b/docs/build/html/_sources/usage/function_transforms.rst
index 9a15bbf1c..9769fceaa 100644
--- a/docs/build/html/_sources/usage/function_transforms.rst
+++ b/docs/build/html/_sources/usage/function_transforms.rst
@@ -161,7 +161,7 @@ A naive way to add the elements from two sets of vectors is with a loop:
   ys = mx.random.uniform(shape=(100, 4096))
 
   def naive_add(xs, ys):
-      return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
+      return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
 
 Instead you can use :func:`vmap` to automatically vectorize the addition:
 
@@ -169,7 +169,7 @@ Instead you can use :func:`vmap` to automatically vectorize the addition:
 
   # Vectorize over the second dimension of x and the
   # first dimension of y
-  vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
+  vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
 
 The ``in_axes`` parameter can be used to specify which dimensions of the
 corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
diff --git a/docs/build/html/_sources/usage/indexing.rst b/docs/build/html/_sources/usage/indexing.rst
index 62994a0fb..c74e357fa 100644
--- a/docs/build/html/_sources/usage/indexing.rst
+++ b/docs/build/html/_sources/usage/indexing.rst
@@ -77,7 +77,7 @@ from the GPU. Performing bounds checking for array indices before launching the
 kernel would be extremely inefficient.
 
 Indexing with boolean masks is something that MLX may support in the future. In
-general, MLX has limited support for operations for which outputs
+general, MLX has limited support for operations for which output
 *shapes* are dependent on input *data*. Other examples of these types of
 operations which MLX does not yet support include :func:`numpy.nonzero` and the
 single input version of :func:`numpy.where`.
diff --git a/docs/build/html/_sources/usage/lazy_evaluation.rst b/docs/build/html/_sources/usage/lazy_evaluation.rst
index 466edaaed..8fd855efa 100644
--- a/docs/build/html/_sources/usage/lazy_evaluation.rst
+++ b/docs/build/html/_sources/usage/lazy_evaluation.rst
@@ -109,7 +109,7 @@ Here is a concrete example:
 
 An important behavior to be aware of is when the graph will be implicitly
 evaluated. Anytime you ``print`` an array, convert it to an
-:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
+:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
 the graph will be evaluated.
 Saving arrays via :func:`save` (or any other MLX saving functions) will also
 evaluate the array.
diff --git a/docs/build/html/backend_2metal_2device_8h_source.html b/docs/build/html/backend_2metal_2device_8h_source.html
index 8b936bd40..bb0b3537b 100644
--- a/docs/build/html/backend_2metal_2device_8h_source.html
+++ b/docs/build/html/backend_2metal_2device_8h_source.html
@@ -149,7 +149,7 @@ $(function(){ initResizable(false); });
51 enc.concurrent_ = false;
-
52 enc.outputs_.insert(
+
52 enc.prev_outputs_.insert(
53 enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
54 enc.concurrent_outputs_.clear();
55 }
@@ -170,212 +170,215 @@ $(function(){ initResizable(false); });
66 void set_output_array(array& a, int idx, int64_t offset = 0);
67 void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
68 void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
-
69
-
- -
71 return ConcurrentContext(*this);
-
72 }
+ +
70
+
+ +
72 return ConcurrentContext(*this);
+
73 }
- -
74
-
75 // Inputs to all kernels in the encoder including temporaries
-
-
76 std::unordered_set<const void*>& inputs() {
-
77 return all_inputs_;
-
78 };
+ +
75
+
76 // Inputs to all kernels in the encoder including temporaries
+
+
77 std::unordered_set<const void*>& inputs() {
+
78 return all_inputs_;
+
79 };
-
79
-
80 // Outputs of all kernels in the encoder including temporaries
-
-
81 std::unordered_set<const void*> outputs() {
-
82 return all_outputs_;
-
83 };
+
80
+
81 // Outputs of all kernels in the encoder including temporaries
+
+
82 std::unordered_set<const void*> outputs() {
+
83 return all_outputs_;
+
84 };
-
84
-
85 private:
-
86 MTL::ComputeCommandEncoder* enc_;
-
87 bool concurrent_{false};
-
88 std::unordered_set<MTL::Resource*> outputs_;
-
89 std::unordered_set<MTL::Resource*> concurrent_outputs_;
-
90 std::unordered_set<const void*> all_inputs_;
-
91 std::unordered_set<const void*> all_outputs_;
-
92};
+
85
+
86 private:
+
87 MTL::ComputeCommandEncoder* enc_;
+
88 bool needs_barrier_{false};
+
89 bool concurrent_{false};
+
90 std::unordered_set<MTL::Resource*> prev_outputs_;
+
91 std::unordered_set<MTL::Resource*> next_outputs_;
+
92 std::unordered_set<MTL::Resource*> concurrent_outputs_;
+
93 std::unordered_set<const void*> all_inputs_;
+
94 std::unordered_set<const void*> all_outputs_;
+
95};
-
93
-
-
94struct Fence {
-
95 Fence(MTL::Fence* fence) : fence(fence) {}
-
- -
97 fence->release();
-
98 }
+
96
+
+
97struct Fence {
+
98 Fence(MTL::Fence* fence) : fence(fence) {}
+
+ +
100 fence->release();
+
101 }
-
99 MTL::Fence* fence;
-
100};
+
102 MTL::Fence* fence;
+
103};
-
101
-
- -
103 DeviceStream(MTL::CommandQueue* queue) : queue(queue) {};
-
- -
105 queue->release();
-
106 if (buffer != nullptr) {
-
107 buffer->release();
-
108 }
-
109 };
+
104
+
+ +
106 DeviceStream(MTL::CommandQueue* queue) : queue(queue) {};
+
+ +
108 queue->release();
+
109 if (buffer != nullptr) {
+
110 buffer->release();
+
111 }
+
112 };
-
110 MTL::CommandQueue* queue;
-
111 // A map of prior command encoder outputs to their corresponding fence
-
112 std::unordered_map<const void*, std::shared_ptr<Fence>> outputs;
-
113 // Used to allow thread-safe access to the outputs map
-
114 std::mutex fence_mtx;
-
115
-
116 // The buffer and buffer op count are updated
-
117 // between command buffers
-
118 MTL::CommandBuffer* buffer{nullptr};
- -
120
-
121 // The command encoder, fence, and temporaries are updated between command
-
122 // encoders
-
123 std::unique_ptr<CommandEncoder> encoder{nullptr};
-
124 std::shared_ptr<Fence> fence;
-
125 std::vector<array> temporaries;
-
126};
+
113 MTL::CommandQueue* queue;
+
114 // A map of prior command encoder outputs to their corresponding fence
+
115 std::unordered_map<const void*, std::shared_ptr<Fence>> outputs;
+
116 // Used to allow thread-safe access to the outputs map
+
117 std::mutex fence_mtx;
+
118
+
119 // The buffer and buffer op count are updated
+
120 // between command buffers
+
121 MTL::CommandBuffer* buffer{nullptr};
+ +
123
+
124 // The command encoder, fence, and temporaries are updated between command
+
125 // encoders
+
126 std::unique_ptr<CommandEncoder> encoder{nullptr};
+
127 std::shared_ptr<Fence> fence;
+
128 std::vector<array> temporaries;
+
129};
-
127
-
-
128class Device {
-
129 public:
- -
131 Device(const Device&) = delete;
-
132 Device& operator=(const Device&) = delete;
- -
134
-
-
135 MTL::Device* mtl_device() {
-
136 return device_;
-
137 };
+
130
+
+
131class Device {
+
132 public:
+ +
134 Device(const Device&) = delete;
+
135 Device& operator=(const Device&) = delete;
+ +
137
+
+
138 MTL::Device* mtl_device() {
+
139 return device_;
+
140 };
-
138
-
-
139 const std::string& get_architecture() {
-
140 return arch_;
-
141 }
+
141
+
+
142 const std::string& get_architecture() {
+
143 return arch_;
+
144 }
-
142
-
143 void new_queue(int index);
-
144 MTL::CommandBuffer* get_command_buffer(int index);
- - -
147 void commit_command_buffer(int index);
- -
149 void end_encoding(int index);
-
150
- -
152 const std::string& lib_name,
-
153 const std::string& lib_path);
-
154
-
155 // Note, this should remain in the header so that it is not dynamically
-
156 // linked
-
-
157 void register_library(const std::string& lib_name) {
-
158 if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
-
159 register_library(lib_name, get_colocated_mtllib_path(lib_name));
-
160 }
-
161 }
+
145
+
146 void new_queue(int index);
+
147 MTL::CommandBuffer* get_command_buffer(int index);
+ + +
150 void commit_command_buffer(int index);
+ +
152 void end_encoding(int index);
+
153
+ +
155 const std::string& lib_name,
+
156 const std::string& lib_path);
+
157
+
158 // Note, this should remain in the header so that it is not dynamically
+
159 // linked
+
+
160 void register_library(const std::string& lib_name) {
+
161 if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
+
162 register_library(lib_name, get_colocated_mtllib_path(lib_name));
+
163 }
+
164 }
-
162
-
163 MTL::Library* get_library(
-
164 const std::string& name,
-
165 const std::function<std::string(void)>& builder);
-
166
-
167 MTL::ComputePipelineState* get_kernel(
-
168 const std::string& base_name,
-
169 MTL::Library* mtl_lib,
-
170 const std::string& hash_name = "",
-
171 const MTLFCList& func_consts = {},
-
172 const std::vector<MTL::Function*>& linked_functions = {});
-
173
-
174 MTL::ComputePipelineState* get_kernel(
-
175 const std::string& base_name,
-
176 const std::string& lib_name = "mlx",
-
177 const std::string& hash_name = "",
-
178 const MTLFCList& func_consts = {},
-
179 const std::vector<MTL::Function*>& linked_functions = {});
-
180
-
181 MTL::ArgumentEncoder* argument_encoder(
-
182 const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
+
165
+
166 MTL::Library* get_library(
+
167 const std::string& name,
+
168 const std::function<std::string(void)>& builder);
+
169
+
170 MTL::ComputePipelineState* get_kernel(
+
171 const std::string& base_name,
+
172 MTL::Library* mtl_lib,
+
173 const std::string& hash_name = "",
+
174 const MTLFCList& func_consts = {},
+
175 const std::vector<MTL::Function*>& linked_functions = {});
+
176
+
177 MTL::ComputePipelineState* get_kernel(
+
178 const std::string& base_name,
+
179 const std::string& lib_name = "mlx",
+
180 const std::string& hash_name = "",
+
181 const MTLFCList& func_consts = {},
+
182 const std::vector<MTL::Function*>& linked_functions = {});
183
-
184 // Record temporary arrays for the given stream index
-
185 void add_temporary(array arr, int index);
-
186 void add_temporaries(std::vector<array> arrays, int index);
-
187
-
188 void set_residency_set(const MTL::ResidencySet* residency_set);
-
189
-
190 private:
-
191 DeviceStream& get_stream_(int index) {
-
192 return stream_map_.find(index)->second;
-
193 }
-
194 MTL::Library* get_library_cache_(const std::string& name);
-
195
-
196 MTL::Library* get_library_(const std::string& name);
-
197 MTL::Library* build_library_(const std::string& source_string);
+
184 MTL::ArgumentEncoder* argument_encoder(
+
185 const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
+
186
+
187 // Record temporary arrays for the given stream index
+
188 void add_temporary(array arr, int index);
+
189 void add_temporaries(std::vector<array> arrays, int index);
+
190
+
191 void set_residency_set(const MTL::ResidencySet* residency_set);
+
192
+
193 private:
+
194 DeviceStream& get_stream_(int index) {
+
195 return stream_map_.find(index)->second;
+
196 }
+
197 MTL::Library* get_library_cache_(const std::string& name);
198
-
199 MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
-
200
-
201 MTL::Function* get_function_(
-
202 const std::string& name,
-
203 const std::string& specialized_name,
-
204 const MTLFCList& func_consts,
-
205 MTL::Library* mtl_lib);
-
206
-
207 MTL::LinkedFunctions* get_linked_functions_(
-
208 const std::vector<MTL::Function*>& funcs);
+
199 MTL::Library* get_library_(const std::string& name);
+
200 MTL::Library* build_library_(const std::string& source_string);
+
201
+
202 MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
+
203
+
204 MTL::Function* get_function_(
+
205 const std::string& name,
+
206 const std::string& specialized_name,
+
207 const MTLFCList& func_consts,
+
208 MTL::Library* mtl_lib);
209
-
210 MTL::ComputePipelineState* get_kernel_(
-
211 const std::string& name,
-
212 const MTL::Function* mtl_function);
-
213
-
214 MTL::ComputePipelineState* get_kernel_(
-
215 const std::string& name,
-
216 const MTL::Function* mtl_function,
-
217 const MTL::LinkedFunctions* linked_functions);
-
218
-
219 MTL::ComputePipelineState* get_kernel_(
-
220 const std::string& base_name,
-
221 MTL::Library* mtl_lib,
-
222 const std::string& hash_name,
-
223 const MTLFCList& func_consts = {},
-
224 const std::vector<MTL::Function*>& linked_functions = {});
-
225
-
226 MTL::Device* device_;
-
227 std::unordered_map<int32_t, DeviceStream> stream_map_;
+
210 MTL::LinkedFunctions* get_linked_functions_(
+
211 const std::vector<MTL::Function*>& funcs);
+
212
+
213 MTL::ComputePipelineState* get_kernel_(
+
214 const std::string& name,
+
215 const MTL::Function* mtl_function);
+
216
+
217 MTL::ComputePipelineState* get_kernel_(
+
218 const std::string& name,
+
219 const MTL::Function* mtl_function,
+
220 const MTL::LinkedFunctions* linked_functions);
+
221
+
222 MTL::ComputePipelineState* get_kernel_(
+
223 const std::string& base_name,
+
224 MTL::Library* mtl_lib,
+
225 const std::string& hash_name,
+
226 const MTLFCList& func_consts = {},
+
227 const std::vector<MTL::Function*>& linked_functions = {});
228
-
229 std::shared_mutex kernel_mtx_;
-
230 std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
+
229 MTL::Device* device_;
+
230 std::unordered_map<int32_t, DeviceStream> stream_map_;
231
-
232 std::shared_mutex library_mtx_;
-
233 std::unordered_map<std::string, MTL::Library*> library_map_;
-
234 const MTL::ResidencySet* residency_set_{nullptr};
-
235 std::string arch_;
-
236};
+
232 std::shared_mutex kernel_mtx_;
+
233 std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
+
234
+
235 std::shared_mutex library_mtx_;
+
236 std::unordered_map<std::string, MTL::Library*> library_map_;
+
237 const MTL::ResidencySet* residency_set_{nullptr};
+
238 std::string arch_;
+
239};
-
237
- -
239
-
240} // namespace mlx::core::metal
+
240
+ +
242
+
243} // namespace mlx::core::metal
Definition array.h:20
-
Definition device.h:128
+
Definition device.h:131
void set_residency_set(const MTL::ResidencySet *residency_set)
int get_command_buffer_ops(int index)
-
MTL::Device * mtl_device()
Definition device.h:135
+
MTL::Device * mtl_device()
Definition device.h:138
void register_library(const std::string &lib_name, const std::string &lib_path)
MTL::CommandBuffer * get_command_buffer(int index)
void end_encoding(int index)
-
const std::string & get_architecture()
Definition device.h:139
+
const std::string & get_architecture()
Definition device.h:142
MTL::ComputePipelineState * get_kernel(const std::string &base_name, MTL::Library *mtl_lib, const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
MTL::ArgumentEncoder * argument_encoder(const std::vector< MTL::ArgumentDescriptor * > &arg_descs) const
void add_temporaries(std::vector< array > arrays, int index)
@@ -383,7 +386,7 @@ $(function(){ initResizable(false); });
void increment_command_buffer_ops(int index)
void new_queue(int index)
void commit_command_buffer(int index)
-
void register_library(const std::string &lib_name)
Definition device.h:157
+
void register_library(const std::string &lib_name)
Definition device.h:160
Device(const Device &)=delete
void add_temporary(array arr, int index)
Device & operator=(const Device &)=delete
@@ -402,31 +405,32 @@ $(function(){ initResizable(false); });
Definition device.h:41
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims)
CommandEncoder(MTL::CommandBuffer *cbuf)
-
std::unordered_set< const void * > & inputs()
Definition device.h:76
+
std::unordered_set< const void * > & inputs()
Definition device.h:77
CommandEncoder & operator=(const CommandEncoder &)=delete
-
ConcurrentContext start_concurrent()
Definition device.h:70
+
ConcurrentContext start_concurrent()
Definition device.h:71
void set_output_array(array &a, int idx, int64_t offset=0)
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims)
MTL::ComputeCommandEncoder * operator->()
Definition device.h:61
void set_input_array(const array &a, int idx, int64_t offset=0)
CommandEncoder(const CommandEncoder &)=delete
-
std::unordered_set< const void * > outputs()
Definition device.h:81
-
Definition device.h:102
-
~DeviceStream()
Definition device.h:104
-
std::unordered_map< const void *, std::shared_ptr< Fence > > outputs
Definition device.h:112
-
DeviceStream(MTL::CommandQueue *queue)
Definition device.h:103
-
std::unique_ptr< CommandEncoder > encoder
Definition device.h:123
-
std::mutex fence_mtx
Definition device.h:114
-
MTL::CommandQueue * queue
Definition device.h:110
-
std::shared_ptr< Fence > fence
Definition device.h:124
-
MTL::CommandBuffer * buffer
Definition device.h:118
-
int buffer_ops
Definition device.h:119
-
std::vector< array > temporaries
Definition device.h:125
-
Definition device.h:94
-
Fence(MTL::Fence *fence)
Definition device.h:95
-
~Fence()
Definition device.h:96
-
MTL::Fence * fence
Definition device.h:99
+ +
std::unordered_set< const void * > outputs()
Definition device.h:82
+
Definition device.h:105
+
~DeviceStream()
Definition device.h:107
+
std::unordered_map< const void *, std::shared_ptr< Fence > > outputs
Definition device.h:115
+
DeviceStream(MTL::CommandQueue *queue)
Definition device.h:106
+
std::unique_ptr< CommandEncoder > encoder
Definition device.h:126
+
std::mutex fence_mtx
Definition device.h:117
+
MTL::CommandQueue * queue
Definition device.h:113
+
std::shared_ptr< Fence > fence
Definition device.h:127
+
MTL::CommandBuffer * buffer
Definition device.h:121
+
int buffer_ops
Definition device.h:122
+
std::vector< array > temporaries
Definition device.h:128
+
Definition device.h:97
+
Fence(MTL::Fence *fence)
Definition device.h:98
+
~Fence()
Definition device.h:99
+
MTL::Fence * fence
Definition device.h:102
+

Note: grid and threadgroup are parameters to the Metal dispatchThreads function.
This means we will launch mx.prod(grid) threads, subdivided into threadgroup size threadgroups.
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.

Passing verbose=True to mx.fast.metal_kernel.__call__ will print the generated code for debugging purposes.
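For reference, a minimal sketch of a call that uses grid and threadgroup, adapted from the myexp example earlier in the Custom Metal Kernels guide (the input shape and the threadgroup size of 256 are illustrative choices, not requirements):

import mlx.core as mx

# Metal source for the kernel body; inp/out match input_names/output_names below.
source = """
    uint elem = thread_position_in_grid.x;
    T tmp = inp[elem];
    out[elem] = metal::exp(tmp);
"""

kernel = mx.fast.metal_kernel(
    name="myexp",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)

a = mx.random.normal(shape=(4096,))
outputs = kernel(
    inputs=[a],
    template=[("T", mx.float32)],
    grid=(a.size, 1, 1),       # mx.prod(grid) == a.size threads in total
    threadgroup=(256, 1, 1),   # subdivided into threadgroups of 256 threads
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
    verbose=True,              # print the generated Metal source for debugging
)
print(outputs[0])

Here mx.prod(grid) == 4096 threads are launched, split into sixteen threadgroups of 256 along the x dimension, and each threadgroup dimension stays within the corresponding grid dimension.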

diff --git a/docs/build/html/doxygen_crawl.html b/docs/build/html/doxygen_crawl.html index d4fc5d84c..0236914c1 100644 --- a/docs/build/html/doxygen_crawl.html +++ b/docs/build/html/doxygen_crawl.html @@ -4330,9 +4330,9 @@ + - @@ -4443,9 +4443,9 @@ + - @@ -4850,9 +4850,11 @@ + + @@ -4900,7 +4902,6 @@ - @@ -5272,7 +5273,6 @@ - @@ -5933,11 +5933,11 @@ + - @@ -5952,6 +5952,7 @@ + @@ -6016,8 +6017,10 @@ + - + + @@ -6051,7 +6054,7 @@ - + @@ -6929,6 +6932,7 @@ + diff --git a/docs/build/html/functions_func_m.html b/docs/build/html/functions_func_m.html index d4f90d50b..b812c59a5 100644 --- a/docs/build/html/functions_func_m.html +++ b/docs/build/html/functions_func_m.html @@ -93,6 +93,7 @@ $(function(){ initResizable(false); });
  • Matmul() : mlx::core::Matmul
  • max() : metal::_numeric_limits_impl< bfloat16_t >
  • Maximum() : mlx::core::Maximum
  • +
  • maybeInsertBarrier() : mlx::core::metal::CommandEncoder
  • merge_partition() : BlockMergeSort< val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp >, KernelMultiBlockMergeSort< val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp >
  • merge_step() : BlockMergeSort< val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp >
  • min() : metal::_numeric_limits_impl< bfloat16_t >
  • diff --git a/docs/build/html/functions_m.html b/docs/build/html/functions_m.html index 550ae601f..3be734fe6 100644 --- a/docs/build/html/functions_m.html +++ b/docs/build/html/functions_m.html @@ -102,6 +102,7 @@ $(function(){ initResizable(false); });
  • max_exponent : metal::_numeric_limits_impl< bfloat16_t >
  • max_exponent10 : metal::_numeric_limits_impl< bfloat16_t >
  • Maximum() : mlx::core::Maximum
  • +
  • maybeInsertBarrier() : mlx::core::metal::CommandEncoder
  • merge_partition() : BlockMergeSort< val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp >, KernelMultiBlockMergeSort< val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp >
  • merge_step() : BlockMergeSort< val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp >
  • Min : mlx::core::distributed::AllReduce, mlx::core::Reduce, mlx::core::Scan, mlx::core::Scatter
  • diff --git a/docs/build/html/globals_c.html b/docs/build/html/globals_c.html index f2a87df5d..a9eed8c2d 100644 --- a/docs/build/html/globals_c.html +++ b/docs/build/html/globals_c.html @@ -92,8 +92,10 @@ $(function(){ initResizable(false); });
  • can_convert_to_bfloat : bf16.h
  • can_convert_to_complex64 : complex.h
  • ceildiv() : utils.h
  • +
  • col_reduce_2pass() : reduce_col.h
  • +
  • col_reduce_longcolumn() : reduce_col.h
  • col_reduce_looped() : reduce_col.h
  • -
  • col_reduce_small() : reduce_col.h
  • +
  • col_reduce_small() : reduce_col.h
  • complex_binop : complex.h
  • complex_binop_helper : complex.h
  • complex_mul() : radix.h
  • diff --git a/docs/build/html/globals_func_c.html b/docs/build/html/globals_func_c.html index 11545a40d..227276f3e 100644 --- a/docs/build/html/globals_func_c.html +++ b/docs/build/html/globals_func_c.html @@ -88,8 +88,10 @@ $(function(){ initResizable(false); });

    - c -

    diff --git a/docs/build/html/globals_func_s.html b/docs/build/html/globals_func_s.html index d8ecab0f0..ce63bfa00 100644 --- a/docs/build/html/globals_func_s.html +++ b/docs/build/html/globals_func_s.html @@ -88,7 +88,7 @@ $(function(){ initResizable(false); });

    - s -

    diff --git a/docs/build/html/globals_s.html b/docs/build/html/globals_s.html index a46a742a1..00ccc52f6 100644 --- a/docs/build/html/globals_s.html +++ b/docs/build/html/globals_s.html @@ -89,7 +89,7 @@ $(function(){ initResizable(false); });

    - s -

    • scatter_impl() : scatter.h
    • scatter_kernels : indexing.h
    • -
    • sdpa_vector() : sdpa_vector.h
    • +
    • sdpa_vector() : sdpa_vector.h
    • simd_shuffle() : utils.h
    • simd_shuffle_and_fill_up() : utils.h
    • simd_shuffle_down() : utils.h
    • diff --git a/docs/build/html/kernels_8h.html b/docs/build/html/kernels_8h.html index 3f1bb1307..4a5c35d1c 100644 --- a/docs/build/html/kernels_8h.html +++ b/docs/build/html/kernels_8h.html @@ -129,8 +129,8 @@ Functions   MTL::ComputePipelineState * mlx::core::get_mb_sort_kernel (metal::Device &d, const std::string &kernel_name, const array &in, const array &idx, int bn, int tn)   -MTL::ComputePipelineState * mlx::core::get_reduce_init_kernel (metal::Device &d, const std::string &kernel_name, const array &out) -  +MTL::ComputePipelineState * mlx::core::get_reduce_init_kernel (metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const array &out) +  MTL::ComputePipelineState * mlx::core::get_reduce_kernel (metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const array &in, const array &out, int ndim=-1, int bm=-1, int bn=-1)   MTL::ComputePipelineState * mlx::core::get_steel_gemm_fused_kernel (metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn) diff --git a/docs/build/html/kernels_8h_source.html b/docs/build/html/kernels_8h_source.html index 93d7a928c..de9608168 100644 --- a/docs/build/html/kernels_8h_source.html +++ b/docs/build/html/kernels_8h_source.html @@ -169,152 +169,154 @@ $(function(){ initResizable(false); });
      76 int bn,
      77 int tn);
      78
      -
      79MTL::ComputePipelineState* get_reduce_init_kernel(
      +
      79MTL::ComputePipelineState* get_reduce_init_kernel(
      81 const std::string& kernel_name,
      -
      82 const array& out);
      -
      83
      -
      84MTL::ComputePipelineState* get_reduce_kernel(
      - -
      86 const std::string& kernel_name,
      -
      87 const std::string& func_name,
      -
      88 const std::string& op_name,
      -
      89 const array& in,
      -
      90 const array& out,
      -
      91 int ndim = -1,
      -
      92 int bm = -1,
      -
      93 int bn = -1);
      -
      94
      -
      95MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
      - -
      97 const std::string& kernel_name,
      -
      98 const std::string& hash_name,
      -
      99 const metal::MTLFCList& func_consts,
      -
      100 const array& out,
      -
      101 bool transpose_a,
      -
      102 bool transpose_b,
      -
      103 int bm,
      -
      104 int bn,
      -
      105 int bk,
      -
      106 int wm,
      -
      107 int wn);
      -
      108
      -
      109MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
      -
      110 metal::Device& d,
      -
      111 const std::string& kernel_name,
      -
      112 const array& in,
      -
      113 const array& out,
      -
      114 bool transpose_a,
      -
      115 bool transpose_b,
      -
      116 int bm,
      -
      117 int bn,
      -
      118 int bk,
      -
      119 int wm,
      -
      120 int wn,
      -
      121 bool mn_aligned,
      -
      122 bool k_aligned);
      -
      123
      -
      124MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
      -
      125 metal::Device& d,
      -
      126 const std::string& kernel_name,
      -
      127 const array& in,
      -
      128 const array& out,
      -
      129 bool axbpy);
      -
      130
      -
      131MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
      -
      132 metal::Device& d,
      -
      133 const std::string& kernel_name,
      -
      134 const array& out,
      -
      135 const std::optional<array>& mask_out,
      -
      136 const std::optional<array>& mask_op,
      -
      137 bool transpose_a,
      -
      138 bool transpose_b,
      -
      139 int bm,
      -
      140 int bn,
      -
      141 int bk,
      -
      142 int wm,
      -
      143 int wn,
      -
      144 bool mn_aligned,
      -
      145 bool k_aligned);
      -
      146
      -
      147MTL::ComputePipelineState* get_steel_conv_kernel(
      -
      148 metal::Device& d,
      -
      149 const std::string& kernel_name,
      -
      150 const array& out,
      -
      151 int bm,
      -
      152 int bn,
      -
      153 int bk,
      -
      154 int wm,
      -
      155 int wn,
      -
      156 int n_channel_specialization,
      -
      157 bool small_filter);
      -
      158
      -
      159MTL::ComputePipelineState* get_gemv_masked_kernel(
      -
      160 metal::Device& d,
      -
      161 const std::string& kernel_name,
      -
      162 const array& out,
      -
      163 const std::optional<array>& mask_out,
      -
      164 const std::optional<array>& mask_op,
      -
      165 bool transpose_mat,
      -
      166 int bm,
      -
      167 int bn,
      -
      168 int sm,
      -
      169 int sn,
      -
      170 int tm,
      -
      171 int tn,
      -
      172 bool contiguous);
      -
      173
      -
      174MTL::ComputePipelineState* get_steel_conv_general_kernel(
      -
      175 metal::Device& d,
      -
      176 const std::string& kernel_name,
      -
      177 const array& out,
      -
      178 int bm,
      -
      179 int bn,
      -
      180 int bk,
      -
      181 int wm,
      -
      182 int wn);
      -
      183
      -
      184MTL::ComputePipelineState* get_fft_kernel(
      -
      185 metal::Device& d,
      -
      186 const std::string& kernel_name,
      -
      187 const std::string& hash_name,
      -
      188 const metal::MTLFCList& func_consts,
      -
      189 const std::string& template_def);
      -
      190
      -
      191MTL::ComputePipelineState* get_quantized_kernel(
      -
      192 metal::Device& d,
      -
      193 const std::string& kernel_name,
      -
      194 const std::string& template_def);
      -
      195
      -
      196// Create a GPU kernel template definition for JIT compilation
      -
      197template <typename... Args>
      -
      198std::string
      -
      -
      199get_template_definition(std::string name, std::string func, Args... args) {
      -
      200 std::ostringstream s;
      -
      201 s << func << "<";
      -
      202 bool first = true;
      -
      203 auto add_arg = [&s, &first](const auto& arg) {
      -
      204 if (!first) {
      -
      205 s << ", ";
      -
      206 }
      -
      207 first = false;
      -
      208 s << arg;
      -
      209 };
      -
      210 (add_arg(args), ...);
      -
      211 s << ">";
      -
      212 return fmt::format(
      -
      213 "\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
      -
      214 name,
      -
      215 s.str());
      -
      216}
      +
      82 const std::string& func_name,
      +
      83 const std::string& op_name,
      +
      84 const array& out);
      +
      85
      +
      86MTL::ComputePipelineState* get_reduce_kernel(
      + +
      88 const std::string& kernel_name,
      +
      89 const std::string& func_name,
      +
      90 const std::string& op_name,
      +
      91 const array& in,
      +
      92 const array& out,
      +
      93 int ndim = -1,
      +
      94 int bm = -1,
      +
      95 int bn = -1);
      +
      96
      +
      97MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
      + +
      99 const std::string& kernel_name,
      +
      100 const std::string& hash_name,
      +
      101 const metal::MTLFCList& func_consts,
      +
      102 const array& out,
      +
      103 bool transpose_a,
      +
      104 bool transpose_b,
      +
      105 int bm,
      +
      106 int bn,
      +
      107 int bk,
      +
      108 int wm,
      +
      109 int wn);
      +
      110
      +
      111MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
      +
      112 metal::Device& d,
      +
      113 const std::string& kernel_name,
      +
      114 const array& in,
      +
      115 const array& out,
      +
      116 bool transpose_a,
      +
      117 bool transpose_b,
      +
      118 int bm,
      +
      119 int bn,
      +
      120 int bk,
      +
      121 int wm,
      +
      122 int wn,
      +
      123 bool mn_aligned,
      +
      124 bool k_aligned);
      +
      125
      +
      126MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
      +
      127 metal::Device& d,
      +
      128 const std::string& kernel_name,
      +
      129 const array& in,
      +
      130 const array& out,
      +
      131 bool axbpy);
      +
      132
      +
      133MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
      +
      134 metal::Device& d,
      +
      135 const std::string& kernel_name,
      +
      136 const array& out,
      +
      137 const std::optional<array>& mask_out,
      +
      138 const std::optional<array>& mask_op,
      +
      139 bool transpose_a,
      +
      140 bool transpose_b,
      +
      141 int bm,
      +
      142 int bn,
      +
      143 int bk,
      +
      144 int wm,
      +
      145 int wn,
      +
      146 bool mn_aligned,
      +
      147 bool k_aligned);
      +
      148
      +
      149MTL::ComputePipelineState* get_steel_conv_kernel(
      +
      150 metal::Device& d,
      +
      151 const std::string& kernel_name,
      +
      152 const array& out,
      +
      153 int bm,
      +
      154 int bn,
      +
      155 int bk,
      +
      156 int wm,
      +
      157 int wn,
      +
      158 int n_channel_specialization,
      +
      159 bool small_filter);
      +
      160
      +
      161MTL::ComputePipelineState* get_gemv_masked_kernel(
      +
      162 metal::Device& d,
      +
      163 const std::string& kernel_name,
      +
      164 const array& out,
      +
      165 const std::optional<array>& mask_out,
      +
      166 const std::optional<array>& mask_op,
      +
      167 bool transpose_mat,
      +
      168 int bm,
      +
      169 int bn,
      +
      170 int sm,
      +
      171 int sn,
      +
      172 int tm,
      +
      173 int tn,
      +
      174 bool contiguous);
      +
      175
      +
      176MTL::ComputePipelineState* get_steel_conv_general_kernel(
      +
      177 metal::Device& d,
      +
      178 const std::string& kernel_name,
      +
      179 const array& out,
      +
      180 int bm,
      +
      181 int bn,
      +
      182 int bk,
      +
      183 int wm,
      +
      184 int wn);
      +
      185
      +
      186MTL::ComputePipelineState* get_fft_kernel(
      +
      187 metal::Device& d,
      +
      188 const std::string& kernel_name,
      +
      189 const std::string& hash_name,
      +
      190 const metal::MTLFCList& func_consts,
      +
      191 const std::string& template_def);
      +
      192
      +
      193MTL::ComputePipelineState* get_quantized_kernel(
      +
      194 metal::Device& d,
      +
      195 const std::string& kernel_name,
      +
      196 const std::string& template_def);
      +
      197
      +
      198// Create a GPU kernel template definition for JIT compilation
      +
      199template <typename... Args>
      +
      200std::string
      +
      +
      201get_template_definition(std::string name, std::string func, Args... args) {
      +
      202 std::ostringstream s;
      +
      203 s << func << "<";
      +
      204 bool first = true;
      +
      205 auto add_arg = [&s, &first](const auto& arg) {
      +
      206 if (!first) {
      +
      207 s << ", ";
      +
      208 }
      +
      209 first = false;
      +
      210 s << arg;
      +
      211 };
      +
      212 (add_arg(args), ...);
      +
      213 s << ">";
      +
      214 return fmt::format(
      +
      215 "\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
      +
      216 name,
      +
      217 s.str());
      +
      218}
      -
      217
      -
      218} // namespace mlx::core
      +
      219
      +
      220} // namespace mlx::core
      Definition array.h:20
      -
      Definition device.h:128
      +
      Definition device.h:131
      Op op
      Definition binary.h:129
      std::vector< std::tuple< const void *, MTL::DataType, NS::UInteger > > MTLFCList
      Definition device.h:38
      Definition allocator.h:7
      @@ -322,9 +324,9 @@ $(function(){ initResizable(false); });
      MTL::ComputePipelineState * get_steel_gemm_splitk_accum_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool axbpy)
      MTL::ComputePipelineState * get_fft_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const std::string &template_def)
      MTL::ComputePipelineState * get_softmax_kernel(metal::Device &d, const std::string &kernel_name, bool precise, const array &out)
      +
      MTL::ComputePipelineState * get_reduce_init_kernel(metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const array &out)
      MTL::ComputePipelineState * get_binary_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
      MTL::ComputePipelineState * get_binary_two_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
      -
      MTL::ComputePipelineState * get_reduce_init_kernel(metal::Device &d, const std::string &kernel_name, const array &out)
      MTL::ComputePipelineState * get_ternary_kernel(metal::Device &d, const std::string &kernel_name, Dtype type, const std::string op)
      MTL::ComputePipelineState * get_arange_kernel(metal::Device &d, const std::string &kernel_name, const array &out)
      MTL::ComputePipelineState * get_reduce_kernel(metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const array &in, const array &out, int ndim=-1, int bm=-1, int bn=-1)
      @@ -332,7 +334,7 @@ $(function(){ initResizable(false); });
      MTL::ComputePipelineState * get_steel_gemm_fused_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)
      MTL::ComputePipelineState * get_gemv_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_mat, int bm, int bn, int sm, int sn, int tm, int tn, bool contiguous)
      MTL::ComputePipelineState * get_quantized_kernel(metal::Device &d, const std::string &kernel_name, const std::string &template_def)
      -
      std::string get_template_definition(std::string name, std::string func, Args... args)
      Definition kernels.h:199
      +
      std::string get_template_definition(std::string name, std::string func, Args... args)
      Definition kernels.h:201
      MTL::ComputePipelineState * get_steel_gemm_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
      MTL::ComputePipelineState * get_steel_conv_general_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn)
      MTL::ComputePipelineState * get_steel_conv_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn, int n_channel_specialization, bool small_filter)
      diff --git a/docs/build/html/matmul_8h_source.html b/docs/build/html/matmul_8h_source.html index 4258466e4..ce4ccfe53 100644 --- a/docs/build/html/matmul_8h_source.html +++ b/docs/build/html/matmul_8h_source.html @@ -143,7 +143,7 @@ $(function(){ initResizable(false); });
      50} // namespace mlx::core
      Definition array.h:20
      -
      Definition device.h:128
      +
      Definition device.h:131
      Definition allocator.h:7
      void steel_matmul_regular(const Stream &s, metal::Device &d, const array &a, const array &b, array &out, int M, int N, int K, int batch_size_out, int lda, int ldb, int ldd, bool transpose_a, bool transpose_b, std::vector< int > batch_shape, std::vector< size_t > batch_strides, size_t A_batch_stride, size_t B_batch_stride, size_t matrix_stride_out, std::vector< array > &copies)
      void steel_matmul(const Stream &s, metal::Device &d, const array &a, const array &b, array &out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector< array > &copies, std::vector< int > batch_shape={}, std::vector< size_t > A_batch_stride={}, std::vector< size_t > B_batch_stride={})
      diff --git a/docs/build/html/metal_2reduce_8h.html b/docs/build/html/metal_2reduce_8h.html index 43294f485..d08ed5e85 100644 --- a/docs/build/html/metal_2reduce_8h.html +++ b/docs/build/html/metal_2reduce_8h.html @@ -109,8 +109,8 @@ Namespaces - - + + diff --git a/docs/build/html/metal_2reduce_8h_source.html b/docs/build/html/metal_2reduce_8h_source.html index 9dfc6469d..79e30d6a2 100644 --- a/docs/build/html/metal_2reduce_8h_source.html +++ b/docs/build/html/metal_2reduce_8h_source.html @@ -103,44 +103,43 @@ $(function(){ initResizable(false); });
      10
      11using metal::CommandEncoder;
      12
      - +
      14 const array& in,
      15 array& out,
      16 const std::string& op_name,
      17 CommandEncoder& compute_encoder,
      -
      19 const Stream& s,
      -
      20 std::vector<array>& copies);
      -
      21
      - -
      23 const array& in,
      -
      24 array& out,
      -
      25 const std::string& op_name,
      -
      26 const ReductionPlan& plan,
      -
      27 const std::vector<int>& axes,
      -
      28 CommandEncoder& compute_encoder,
      - -
      30 const Stream& s);
      -
      31
      - -
      33 const array& in,
      -
      34 array& out,
      -
      35 const std::string& op_name,
      -
      36 const ReductionPlan& plan,
      -
      37 const std::vector<int>& axes,
      -
      38 CommandEncoder& compute_encoder,
      - -
      40 const Stream& s);
      -
      41
      -
      42} // namespace mlx::core
      +
      19 const Stream& s);
      +
      20
      + +
      22 const array& in,
      +
      23 array& out,
      +
      24 const std::string& op_name,
      +
      25 const ReductionPlan& plan,
      +
      26 const std::vector<int>& axes,
      +
      27 CommandEncoder& compute_encoder,
      + +
      29 const Stream& s);
      +
      30
      + +
      32 const array& in,
      +
      33 array& out,
      +
      34 const std::string& op_name,
      +
      35 const ReductionPlan& plan,
      +
      36 const std::vector<int>& axes,
      +
      37 CommandEncoder& compute_encoder,
      + +
      39 const Stream& s);
      +
      40
      +
      41} // namespace mlx::core
      Definition array.h:20
      -
      Definition device.h:128
      +
      Definition device.h:131
      Definition allocator.h:7
      +
      void all_reduce_dispatch(const array &in, array &out, const std::string &op_name, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s)
      void strided_reduce_general_dispatch(const array &in, array &out, const std::string &op_name, const ReductionPlan &plan, const std::vector< int > &axes, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s)
      void row_reduce_general_dispatch(const array &in, array &out, const std::string &op_name, const ReductionPlan &plan, const std::vector< int > &axes, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s)
      -
      void all_reduce_dispatch(const array &in, array &out, const std::string &op_name, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s, std::vector< array > &copies)
      Definition reduce.h:39
      Definition stream.h:9
      diff --git a/docs/build/html/namespacemembers.html b/docs/build/html/namespacemembers.html index ffff1cc90..5625555af 100644 --- a/docs/build/html/namespacemembers.html +++ b/docs/build/html/namespacemembers.html @@ -99,7 +99,7 @@ $(function(){ initResizable(false); });
    • aligned_dealloc() : pocketfft::detail
    • all() : mlx::core
    • all_gather() : mlx::core::distributed, mlx::core::distributed::detail
    • -
    • all_reduce_dispatch() : mlx::core
    • +
    • all_reduce_dispatch() : mlx::core
    • all_sum() : mlx::core::distributed, mlx::core::distributed::detail
    • allclose() : mlx::core
    • alloc_tmp() : pocketfft::detail
    • diff --git a/docs/build/html/namespacemembers_func.html b/docs/build/html/namespacemembers_func.html index 9e16a9a69..e838b45c1 100644 --- a/docs/build/html/namespacemembers_func.html +++ b/docs/build/html/namespacemembers_func.html @@ -98,7 +98,7 @@ $(function(){ initResizable(false); });
    • aligned_dealloc() : pocketfft::detail
    • all() : mlx::core
    • all_gather() : mlx::core::distributed, mlx::core::distributed::detail
    • -
    • all_reduce_dispatch() : mlx::core
    • +
    • all_reduce_dispatch() : mlx::core
    • all_sum() : mlx::core::distributed, mlx::core::distributed::detail
    • allclose() : mlx::core
    • alloc_tmp() : pocketfft::detail
    • diff --git a/docs/build/html/namespacemembers_func_g.html b/docs/build/html/namespacemembers_func_g.html index 5e2c379ff..1c777baeb 100644 --- a/docs/build/html/namespacemembers_func_g.html +++ b/docs/build/html/namespacemembers_func_g.html @@ -112,7 +112,7 @@ $(function(){ initResizable(false); });
    • get_pool() : pocketfft::detail::threading
    • get_primitive_string() : mlx::core
    • get_quantized_kernel() : mlx::core
    • -
    • get_reduce_init_kernel() : mlx::core
    • +
    • get_reduce_init_kernel() : mlx::core
    • get_reduce_kernel() : mlx::core
    • get_reduction_plan() : mlx::core
    • get_scan_kernel() : mlx::core
    • diff --git a/docs/build/html/namespacemembers_g.html b/docs/build/html/namespacemembers_g.html index 73be65412..0bbdf9db4 100644 --- a/docs/build/html/namespacemembers_g.html +++ b/docs/build/html/namespacemembers_g.html @@ -116,7 +116,7 @@ $(function(){ initResizable(false); });
    • get_pool() : pocketfft::detail::threading
    • get_primitive_string() : mlx::core
    • get_quantized_kernel() : mlx::core
    • -
    • get_reduce_init_kernel() : mlx::core
    • +
    • get_reduce_init_kernel() : mlx::core
    • get_reduce_kernel() : mlx::core
    • get_reduction_plan() : mlx::core
    • get_scan_kernel() : mlx::core
    • diff --git a/docs/build/html/namespacemlx_1_1core.html b/docs/build/html/namespacemlx_1_1core.html index cef5d8c69..09ca616f9 100644 --- a/docs/build/html/namespacemlx_1_1core.html +++ b/docs/build/html/namespacemlx_1_1core.html @@ -534,8 +534,8 @@ Functions - - + + @@ -563,8 +563,8 @@ Functions - - + + @@ -2634,8 +2634,8 @@ template<typename... T>

      Function Documentation

      - -

      ◆ all_reduce_dispatch()

      + +

      ◆ all_reduce_dispatch()

      @@ -2668,12 +2668,7 @@ template<typename... T>
      - - - - - - +

      Functions

      void mlx::core::all_reduce_dispatch (const array &in, array &out, const std::string &op_name, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s, std::vector< array > &copies)
       
      void mlx::core::all_reduce_dispatch (const array &in, array &out, const std::string &op_name, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s)
       
      void mlx::core::row_reduce_general_dispatch (const array &in, array &out, const std::string &op_name, const ReductionPlan &plan, const std::vector< int > &axes, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s)
       
      void mlx::core::strided_reduce_general_dispatch (const array &in, array &out, const std::string &op_name, const ReductionPlan &plan, const std::vector< int > &axes, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s)
       
      MTL::ComputePipelineState * get_mb_sort_kernel (metal::Device &d, const std::string &kernel_name, const array &in, const array &idx, int bn, int tn)
       
      MTL::ComputePipelineState * get_reduce_init_kernel (metal::Device &d, const std::string &kernel_name, const array &out)
       
      MTL::ComputePipelineState * get_reduce_init_kernel (metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const array &out)
       
      MTL::ComputePipelineState * get_reduce_kernel (metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const array &in, const array &out, int ndim=-1, int bm=-1, int bn=-1)
       
      MTL::ComputePipelineState * get_steel_gemm_fused_kernel (metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)
       
      void steel_matmul (const Stream &s, metal::Device &d, const array &a, const array &b, array &out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector< array > &copies, std::vector< int > batch_shape={}, std::vector< size_t > A_batch_stride={}, std::vector< size_t > B_batch_stride={})
       
      void all_reduce_dispatch (const array &in, array &out, const std::string &op_name, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s, std::vector< array > &copies)
       
      void all_reduce_dispatch (const array &in, array &out, const std::string &op_name, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s)
       
      void row_reduce_general_dispatch (const array &in, array &out, const std::string &op_name, const ReductionPlan &plan, const std::vector< int > &axes, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s)
       
      void strided_reduce_general_dispatch (const array &in, array &out, const std::string &op_name, const ReductionPlan &plan, const std::vector< int > &axes, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s)
      const Stream & s,
      std::vector< array > & copies )const Stream & s )
      @@ -4418,8 +4413,8 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
    - -

    ◆ get_reduce_init_kernel()

    + +

    ◆ get_reduce_init_kernel()

    @@ -4434,6 +4429,16 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...> const std::string & kernel_name, + + + + const std::string & func_name, + + + + + const std::string & op_name, + diff --git a/docs/build/html/objects.inv b/docs/build/html/objects.inv index c31b0569f..5bbbf441f 100644 Binary files a/docs/build/html/objects.inv and b/docs/build/html/objects.inv differ diff --git a/docs/build/html/python/_autosummary/mlx.core.fast.metal_kernel.html b/docs/build/html/python/_autosummary/mlx.core.fast.metal_kernel.html index 085ae4f4a..a69f99808 100644 --- a/docs/build/html/python/_autosummary/mlx.core.fast.metal_kernel.html +++ b/docs/build/html/python/_autosummary/mlx.core.fast.metal_kernel.html @@ -867,6 +867,7 @@
    metal_kernel(name: str, input_names: Sequence[str], output_names: Sequence[str], source: str, header: str = '', ensure_row_contiguous: bool = True, atomic_outputs: bool = False) object#

    A jit-compiled custom Metal kernel defined from a source string.

    +

    Full documentation: Custom Metal Kernels.

    Parameters:
      diff --git a/docs/build/html/quantized_8h.html b/docs/build/html/quantized_8h.html index f865582b9..c0e18903c 100644 --- a/docs/build/html/quantized_8h.html +++ b/docs/build/html/quantized_8h.html @@ -140,9 +140,9 @@ Functions template<typename T , int group_size, int bits> METAL_FUNC void qmv_impl (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)   -template<typename T , const int group_size, const int bits> -METAL_FUNC void qvm_impl (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid) -  +template<typename T , const int group_size, const int bits> +METAL_FUNC void qvm_impl (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const int in_vec_size, const int out_vec_size, uint3 tid, uint simd_gid, uint simd_lid) +  template<typename T , const int group_size, const int bits, const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32> METAL_FUNC void qmm_t_impl (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, threadgroup T *Xs, threadgroup T *Ws, const constant int &K, const constant int &N, const constant int &M, uint3 tid, uint lid, uint simd_gid, uint simd_lid)   @@ -167,6 +167,9 @@ Functions template<typename T , const int group_size, const int bits, bool batched> void qvm (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)   +template<typename T , const int group_size, const int bits, int split_k = 32> +void qvm_split_k (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &final_block_size, uint3 tid, uint simd_gid, uint simd_lid) +  template<typename T , const int group_size, const int bits, const bool aligned_N, const bool batched, const int BM = 32, const int BK = 32, const int BN = 32> void qmm_t (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)   @@ -2485,8 +2488,8 @@ template<typename T , const int group_size, const int bits, bool batched>
    - -

    ◆ qvm_impl()

    + +

    ◆ qvm_impl()

    @@ -2518,6 +2521,69 @@ template<typename T , const int group_size, const int bits>
    device T * y, + + + + const int in_vec_size, + + + + + const int out_vec_size, + + + + + uint3 tid, + + + + + uint simd_gid, + + + + + uint simd_lid ) + + +
    + +
    +
    + +

    ◆ qvm_split_k()

    + +
    +
    +
    +template<typename T , const int group_size, const int bits, int split_k = 32>
    + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -2528,6 +2594,51 @@ template<typename T , const int group_size, const int bits> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/build/html/quantized_8h_source.html b/docs/build/html/quantized_8h_source.html index 07b25d33f..0a08db757 100644 --- a/docs/build/html/quantized_8h_source.html +++ b/docs/build/html/quantized_8h_source.html @@ -766,14 +766,14 @@ $(function(){ initResizable(false); });
    645
    646template <typename T, const int group_size, const int bits>
    -
    647METAL_FUNC void qvm_impl(
    +
    647METAL_FUNC void qvm_impl(
    648 const device uint32_t* w,
    649 const device T* scales,
    650 const device T* biases,
    651 const device T* x,
    652 device T* y,
    -
    653 const constant int& in_vec_size,
    -
    654 const constant int& out_vec_size,
    +
    653 const int in_vec_size,
    +
    654 const int out_vec_size,
    655 uint3 tid [[threadgroup_position_in_grid]],
    656 uint simd_gid [[simdgroup_index_in_threadgroup]],
    657 uint simd_lid [[thread_index_in_simdgroup]]) {
    @@ -1423,7 +1423,7 @@ $(function(){ initResizable(false); });
    1285 b_strides,
    1286 tid);
    1287 }
    - +
    1289 w,
    1290 scales,
    1291 biases,
    @@ -1437,610 +1437,667 @@ $(function(){ initResizable(false); });
    1299}
    1300
    -
    1301template <
    -
    1302 typename T,
    -
    1303 const int group_size,
    -
    1304 const int bits,
    -
    1305 const bool aligned_N,
    -
    1306 const bool batched,
    -
    1307 const int BM = 32,
    -
    1308 const int BK = 32,
    -
    1309 const int BN = 32>
    -
    -
    1310[[kernel]] void qmm_t(
    -
    1311 const device uint32_t* w [[buffer(0)]],
    -
    1312 const device T* scales [[buffer(1)]],
    -
    1313 const device T* biases [[buffer(2)]],
    -
    1314 const device T* x [[buffer(3)]],
    -
    1315 device T* y [[buffer(4)]],
    -
    1316 const constant int& K [[buffer(5)]],
    -
    1317 const constant int& N [[buffer(6)]],
    -
    1318 const constant int& M [[buffer(7)]],
    -
    1319 const constant int& x_batch_ndims [[buffer(8)]],
    -
    1320 const constant int* x_shape [[buffer(9)]],
    -
    1321 const constant size_t* x_strides [[buffer(10)]],
    -
    1322 const constant int& w_batch_ndims [[buffer(11)]],
    -
    1323 const constant int* w_shape [[buffer(12)]],
    -
    1324 const constant size_t* w_strides [[buffer(13)]],
    -
    1325 const constant size_t* s_strides [[buffer(14)]],
    -
    1326 const constant size_t* b_strides [[buffer(15)]],
    -
    1327 uint3 tid [[threadgroup_position_in_grid]],
    -
    1328 uint lid [[thread_index_in_threadgroup]],
    -
    1329 uint simd_gid [[simdgroup_index_in_threadgroup]],
    -
    1330 uint simd_lid [[thread_index_in_simdgroup]]) {
    -
    1331 (void)lid;
    -
    1332
    -
    1333 constexpr int BK_padded = (BK + 16 / sizeof(T));
    -
    1334
    -
    1335 threadgroup T Xs[BM * BK_padded];
    -
    1336 threadgroup T Ws[BN * BK_padded];
    -
    1337
    -
    1338 if (batched) {
    - -
    1340 x,
    -
    1341 w,
    -
    1342 scales,
    -
    1343 biases,
    -
    1344 y,
    -
    1345 M * N,
    -
    1346 x_batch_ndims,
    -
    1347 x_shape,
    -
    1348 x_strides,
    -
    1349 w_batch_ndims,
    -
    1350 w_shape,
    -
    1351 w_strides,
    -
    1352 s_strides,
    -
    1353 b_strides,
    -
    1354 tid);
    -
    1355 }
    - -
    1357 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
    -
    1358}
    +
    1301template <typename T, const int group_size, const int bits, int split_k = 32>
    +
    +
    1302[[kernel]] void qvm_split_k(
    +
    1303 const device uint32_t* w [[buffer(0)]],
    +
    1304 const device T* scales [[buffer(1)]],
    +
    1305 const device T* biases [[buffer(2)]],
    +
    1306 const device T* x [[buffer(3)]],
    +
    1307 device T* y [[buffer(4)]],
    +
    1308 const constant int& in_vec_size [[buffer(5)]],
    +
    1309 const constant int& out_vec_size [[buffer(6)]],
    +
    1310 const constant int& x_batch_ndims [[buffer(7)]],
    +
    1311 const constant int* x_shape [[buffer(8)]],
    +
    1312 const constant size_t* x_strides [[buffer(9)]],
    +
    1313 const constant int& w_batch_ndims [[buffer(10)]],
    +
    1314 const constant int* w_shape [[buffer(11)]],
    +
    1315 const constant size_t* w_strides [[buffer(12)]],
    +
    1316 const constant size_t* s_strides [[buffer(13)]],
    +
    1317 const constant size_t* b_strides [[buffer(14)]],
    +
    1318 const constant int& final_block_size [[buffer(15)]],
    +
    1319 uint3 tid [[threadgroup_position_in_grid]],
    +
    1320 uint simd_gid [[simdgroup_index_in_threadgroup]],
    +
    1321 uint simd_lid [[thread_index_in_simdgroup]]) {
    + +
    1323 x,
    +
    1324 w,
    +
    1325 scales,
    +
    1326 biases,
    +
    1327 y,
    +
    1328 out_vec_size,
    +
    1329 x_batch_ndims,
    +
    1330 x_shape,
    +
    1331 x_strides,
    +
    1332 w_batch_ndims,
    +
    1333 w_shape,
    +
    1334 w_strides,
    +
    1335 s_strides,
    +
    1336 b_strides,
    +
    1337 tid);
    +
    1338
    +
    1339 // When (in_vec_size % split_k != 0) the final block needs to be smaller
    +
    1340 int in_vec_size_adj =
    +
    1341 tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
    +
    1342
    + +
    1344 w,
    +
    1345 scales,
    +
    1346 biases,
    +
    1347 x,
    +
    1348 y,
    +
    1349 in_vec_size_adj,
    +
    1350 out_vec_size,
    +
    1351 tid,
    +
    1352 simd_gid,
    +
    1353 simd_lid);
    +
    1354}
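
The block above is the newly added split-k variant of the quantized vector-matrix kernel: `tid.z` selects one of `split_k` partitions of the reduction dimension, and the last partition uses `final_block_size` whenever the reduction length does not divide evenly (see the comment at line 1339). The following is a minimal CPU sketch of that decomposition, illustration only and not MLX host code; the sizes `K = 70` and `split_k = 8` are made up.

```cpp
#include <cstdio>
#include <vector>

// Split-k on the CPU: a length-K dot product is cut into split_k chunks,
// each chunk produces a partial sum, and the partials are reduced afterwards.
// The last chunk is shorter whenever K is not a multiple of the chunk size,
// which is the role of the kernel's final_block_size argument.
int main() {
  const int K = 70;
  const int split_k = 8;
  std::vector<float> x(K, 1.0f), w(K, 2.0f);

  const int block = (K + split_k - 1) / split_k;      // 9 elements per chunk
  const int final_block = K - block * (split_k - 1);  // 70 - 63 = 7 for the last chunk

  std::vector<float> partials(split_k, 0.0f);
  for (int b = 0; b < split_k; ++b) {
    const int len = (b == split_k - 1) ? final_block : block;
    for (int i = 0; i < len; ++i) {
      partials[b] += x[b * block + i] * w[b * block + i];
    }
  }

  float total = 0.0f;
  for (float p : partials) {
    total += p;  // second-stage reduction over the split_k partial results
  }
  std::printf("split-k result = %.1f (expected %.1f)\n", total, 2.0f * K);
  return 0;
}
```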
    -
    1359
    -
    1360template <
    -
    1361 typename T,
    -
    1362 const int group_size,
    -
    1363 const int bits,
    -
    1364 const bool batched,
    -
    1365 const int BM = 32,
    -
    1366 const int BK = 32,
    -
    1367 const int BN = 32>
    -
    -
    1368[[kernel]] void qmm_n(
    -
    1369 const device uint32_t* w [[buffer(0)]],
    -
    1370 const device T* scales [[buffer(1)]],
    -
    1371 const device T* biases [[buffer(2)]],
    -
    1372 const device T* x [[buffer(3)]],
    -
    1373 device T* y [[buffer(4)]],
    -
    1374 const constant int& K [[buffer(5)]],
    -
    1375 const constant int& N [[buffer(6)]],
    -
    1376 const constant int& M [[buffer(7)]],
    -
    1377 const constant int& x_batch_ndims [[buffer(8)]],
    -
    1378 const constant int* x_shape [[buffer(9)]],
    -
    1379 const constant size_t* x_strides [[buffer(10)]],
    -
    1380 const constant int& w_batch_ndims [[buffer(11)]],
    -
    1381 const constant int* w_shape [[buffer(12)]],
    -
    1382 const constant size_t* w_strides [[buffer(13)]],
    -
    1383 const constant size_t* s_strides [[buffer(14)]],
    -
    1384 const constant size_t* b_strides [[buffer(15)]],
    -
    1385 uint3 tid [[threadgroup_position_in_grid]],
    -
    1386 uint lid [[thread_index_in_threadgroup]],
    -
    1387 uint simd_gid [[simdgroup_index_in_threadgroup]],
    -
    1388 uint simd_lid [[thread_index_in_simdgroup]]) {
    -
    1389 (void)lid;
    -
    1390
    -
    1391 constexpr int BK_padded = (BK + 16 / sizeof(T));
    -
    1392 constexpr int BN_padded = (BN + 16 / sizeof(T));
    -
    1393
    -
    1394 threadgroup T Xs[BM * BK_padded];
    -
    1395 threadgroup T Ws[BK * BN_padded];
    -
    1396
    -
    1397 if (batched) {
    - -
    1399 x,
    -
    1400 w,
    -
    1401 scales,
    -
    1402 biases,
    -
    1403 y,
    -
    1404 M * N,
    -
    1405 x_batch_ndims,
    -
    1406 x_shape,
    -
    1407 x_strides,
    -
    1408 w_batch_ndims,
    -
    1409 w_shape,
    -
    1410 w_strides,
    -
    1411 s_strides,
    -
    1412 b_strides,
    -
    1413 tid);
    -
    1414 }
    -
    1415
    - -
    1417 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
    -
    1418}
    +
    1355
    +
    1356template <
    +
    1357 typename T,
    +
    1358 const int group_size,
    +
    1359 const int bits,
    +
    1360 const bool aligned_N,
    +
    1361 const bool batched,
    +
    1362 const int BM = 32,
    +
    1363 const int BK = 32,
    +
    1364 const int BN = 32>
    +
    +
    1365[[kernel]] void qmm_t(
    +
    1366 const device uint32_t* w [[buffer(0)]],
    +
    1367 const device T* scales [[buffer(1)]],
    +
    1368 const device T* biases [[buffer(2)]],
    +
    1369 const device T* x [[buffer(3)]],
    +
    1370 device T* y [[buffer(4)]],
    +
    1371 const constant int& K [[buffer(5)]],
    +
    1372 const constant int& N [[buffer(6)]],
    +
    1373 const constant int& M [[buffer(7)]],
    +
    1374 const constant int& x_batch_ndims [[buffer(8)]],
    +
    1375 const constant int* x_shape [[buffer(9)]],
    +
    1376 const constant size_t* x_strides [[buffer(10)]],
    +
    1377 const constant int& w_batch_ndims [[buffer(11)]],
    +
    1378 const constant int* w_shape [[buffer(12)]],
    +
    1379 const constant size_t* w_strides [[buffer(13)]],
    +
    1380 const constant size_t* s_strides [[buffer(14)]],
    +
    1381 const constant size_t* b_strides [[buffer(15)]],
    +
    1382 uint3 tid [[threadgroup_position_in_grid]],
    +
    1383 uint lid [[thread_index_in_threadgroup]],
    +
    1384 uint simd_gid [[simdgroup_index_in_threadgroup]],
    +
    1385 uint simd_lid [[thread_index_in_simdgroup]]) {
    +
    1386 (void)lid;
    +
    1387
    +
    1388 constexpr int BK_padded = (BK + 16 / sizeof(T));
    +
    1389
    +
    1390 threadgroup T Xs[BM * BK_padded];
    +
    1391 threadgroup T Ws[BN * BK_padded];
    +
    1392
    +
    1393 if (batched) {
    + +
    1395 x,
    +
    1396 w,
    +
    1397 scales,
    +
    1398 biases,
    +
    1399 y,
    +
    1400 M * N,
    +
    1401 x_batch_ndims,
    +
    1402 x_shape,
    +
    1403 x_strides,
    +
    1404 w_batch_ndims,
    +
    1405 w_shape,
    +
    1406 w_strides,
    +
    1407 s_strides,
    +
    1408 b_strides,
    +
    1409 tid);
    +
    1410 }
    + +
    1412 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
    +
    1413}
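
In `qmm_t` above (and `qmm_n` below), the threadgroup tiles are allocated with a padded inner dimension, `BK_padded = BK + 16 / sizeof(T)`, a common trick for staggering rows of a shared tile (typically to sidestep bank conflicts, though the listing itself does not state the motivation). A quick, purely illustrative check of what the padding works out to for a 4-byte and a 2-byte element type:

```cpp
#include <cstdint>
#include <cstdio>

// Evaluate BK_padded = BK + 16 / sizeof(T) for two element widths.
// uint16_t is used only as a stand-in for a half-precision T.
template <typename T, int BK = 32>
constexpr int bk_padded() {
  return BK + 16 / static_cast<int>(sizeof(T));
}

int main() {
  std::printf("sizeof(T) = 4: BK_padded = %d\n", bk_padded<float>());     // 32 + 4 = 36
  std::printf("sizeof(T) = 2: BK_padded = %d\n", bk_padded<uint16_t>());  // 32 + 8 = 40
  return 0;
}
```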
    -
    1419
    -
    1420template <typename T, int group_size, int bits>
    -
    -
    1421[[kernel]] void bs_qmv_fast(
    -
    1422 const device uint32_t* w [[buffer(0)]],
    -
    1423 const device T* scales [[buffer(1)]],
    -
    1424 const device T* biases [[buffer(2)]],
    -
    1425 const device T* x [[buffer(3)]],
    -
    1426 device T* y [[buffer(4)]],
    -
    1427 const constant int& in_vec_size [[buffer(5)]],
    -
    1428 const constant int& out_vec_size [[buffer(6)]],
    -
    1429 const constant int& x_batch_ndims [[buffer(7)]],
    -
    1430 const constant int* x_shape [[buffer(8)]],
    -
    1431 const constant size_t* x_strides [[buffer(9)]],
    -
    1432 const constant int& w_batch_ndims [[buffer(10)]],
    -
    1433 const constant int* w_shape [[buffer(11)]],
    -
    1434 const constant size_t* w_strides [[buffer(12)]],
    -
    1435 const constant size_t* s_strides [[buffer(13)]],
    -
    1436 const constant size_t* b_strides [[buffer(14)]],
    -
    1437 const constant int& batch_ndims [[buffer(15)]],
    -
    1438 const constant int* batch_shape [[buffer(16)]],
    -
    1439 const device uint32_t* lhs_indices [[buffer(17)]],
    -
    1440 const device uint32_t* rhs_indices [[buffer(18)]],
    -
    1441 const constant size_t* lhs_strides [[buffer(19)]],
    -
    1442 const constant size_t* rhs_strides [[buffer(20)]],
    -
    1443 uint3 tid [[threadgroup_position_in_grid]],
    -
    1444 uint simd_gid [[simdgroup_index_in_threadgroup]],
    -
    1445 uint simd_lid [[thread_index_in_simdgroup]]) {
    - -
    1447 x,
    -
    1448 w,
    -
    1449 scales,
    -
    1450 biases,
    -
    1451 lhs_indices,
    -
    1452 rhs_indices,
    -
    1453 y,
    -
    1454 out_vec_size,
    -
    1455 batch_ndims,
    -
    1456 batch_shape,
    -
    1457 lhs_strides,
    -
    1458 rhs_strides,
    -
    1459 x_batch_ndims,
    -
    1460 x_shape,
    -
    1461 x_strides,
    -
    1462 w_batch_ndims,
    -
    1463 w_shape,
    -
    1464 w_strides,
    -
    1465 s_strides,
    -
    1466 b_strides,
    -
    1467 tid);
    - -
    1469 w,
    -
    1470 scales,
    -
    1471 biases,
    -
    1472 x,
    -
    1473 y,
    -
    1474 in_vec_size,
    -
    1475 out_vec_size,
    -
    1476 tid,
    -
    1477 simd_gid,
    -
    1478 simd_lid);
    -
    1479}
    +
    1414
    +
    1415template <
    +
    1416 typename T,
    +
    1417 const int group_size,
    +
    1418 const int bits,
    +
    1419 const bool batched,
    +
    1420 const int BM = 32,
    +
    1421 const int BK = 32,
    +
    1422 const int BN = 32>
    +
    +
    1423[[kernel]] void qmm_n(
    +
    1424 const device uint32_t* w [[buffer(0)]],
    +
    1425 const device T* scales [[buffer(1)]],
    +
    1426 const device T* biases [[buffer(2)]],
    +
    1427 const device T* x [[buffer(3)]],
    +
    1428 device T* y [[buffer(4)]],
    +
    1429 const constant int& K [[buffer(5)]],
    +
    1430 const constant int& N [[buffer(6)]],
    +
    1431 const constant int& M [[buffer(7)]],
    +
    1432 const constant int& x_batch_ndims [[buffer(8)]],
    +
    1433 const constant int* x_shape [[buffer(9)]],
    +
    1434 const constant size_t* x_strides [[buffer(10)]],
    +
    1435 const constant int& w_batch_ndims [[buffer(11)]],
    +
    1436 const constant int* w_shape [[buffer(12)]],
    +
    1437 const constant size_t* w_strides [[buffer(13)]],
    +
    1438 const constant size_t* s_strides [[buffer(14)]],
    +
    1439 const constant size_t* b_strides [[buffer(15)]],
    +
    1440 uint3 tid [[threadgroup_position_in_grid]],
    +
    1441 uint lid [[thread_index_in_threadgroup]],
    +
    1442 uint simd_gid [[simdgroup_index_in_threadgroup]],
    +
    1443 uint simd_lid [[thread_index_in_simdgroup]]) {
    +
    1444 (void)lid;
    +
    1445
    +
    1446 constexpr int BK_padded = (BK + 16 / sizeof(T));
    +
    1447 constexpr int BN_padded = (BN + 16 / sizeof(T));
    +
    1448
    +
    1449 threadgroup T Xs[BM * BK_padded];
    +
    1450 threadgroup T Ws[BK * BN_padded];
    +
    1451
    +
    1452 if (batched) {
    + +
    1454 x,
    +
    1455 w,
    +
    1456 scales,
    +
    1457 biases,
    +
    1458 y,
    +
    1459 M * N,
    +
    1460 x_batch_ndims,
    +
    1461 x_shape,
    +
    1462 x_strides,
    +
    1463 w_batch_ndims,
    +
    1464 w_shape,
    +
    1465 w_strides,
    +
    1466 s_strides,
    +
    1467 b_strides,
    +
    1468 tid);
    +
    1469 }
    +
    1470
    + +
    1472 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
    +
    1473}
    -
    1480
    -
    1481template <typename T, int group_size, int bits>
    -
    -
    1482[[kernel]] void bs_qmv(
    -
    1483 const device uint32_t* w [[buffer(0)]],
    -
    1484 const device T* scales [[buffer(1)]],
    -
    1485 const device T* biases [[buffer(2)]],
    -
    1486 const device T* x [[buffer(3)]],
    -
    1487 device T* y [[buffer(4)]],
    -
    1488 const constant int& in_vec_size [[buffer(5)]],
    -
    1489 const constant int& out_vec_size [[buffer(6)]],
    -
    1490 const constant int& x_batch_ndims [[buffer(7)]],
    -
    1491 const constant int* x_shape [[buffer(8)]],
    -
    1492 const constant size_t* x_strides [[buffer(9)]],
    -
    1493 const constant int& w_batch_ndims [[buffer(10)]],
    -
    1494 const constant int* w_shape [[buffer(11)]],
    -
    1495 const constant size_t* w_strides [[buffer(12)]],
    -
    1496 const constant size_t* s_strides [[buffer(13)]],
    -
    1497 const constant size_t* b_strides [[buffer(14)]],
    -
    1498 const constant int& batch_ndims [[buffer(15)]],
    -
    1499 const constant int* batch_shape [[buffer(16)]],
    -
    1500 const device uint32_t* lhs_indices [[buffer(17)]],
    -
    1501 const device uint32_t* rhs_indices [[buffer(18)]],
    -
    1502 const constant size_t* lhs_strides [[buffer(19)]],
    -
    1503 const constant size_t* rhs_strides [[buffer(20)]],
    -
    1504 uint3 tid [[threadgroup_position_in_grid]],
    -
    1505 uint simd_gid [[simdgroup_index_in_threadgroup]],
    -
    1506 uint simd_lid [[thread_index_in_simdgroup]]) {
    - -
    1508 x,
    -
    1509 w,
    -
    1510 scales,
    -
    1511 biases,
    -
    1512 lhs_indices,
    -
    1513 rhs_indices,
    -
    1514 y,
    -
    1515 out_vec_size,
    -
    1516 batch_ndims,
    -
    1517 batch_shape,
    -
    1518 lhs_strides,
    -
    1519 rhs_strides,
    -
    1520 x_batch_ndims,
    -
    1521 x_shape,
    -
    1522 x_strides,
    -
    1523 w_batch_ndims,
    -
    1524 w_shape,
    -
    1525 w_strides,
    -
    1526 s_strides,
    -
    1527 b_strides,
    -
    1528 tid);
    - -
    1530 w,
    -
    1531 scales,
    -
    1532 biases,
    -
    1533 x,
    -
    1534 y,
    -
    1535 in_vec_size,
    -
    1536 out_vec_size,
    -
    1537 tid,
    -
    1538 simd_gid,
    -
    1539 simd_lid);
    -
    1540}
    +
    1474
    +
    1475template <typename T, int group_size, int bits>
    +
    +
    1476[[kernel]] void bs_qmv_fast(
    +
    1477 const device uint32_t* w [[buffer(0)]],
    +
    1478 const device T* scales [[buffer(1)]],
    +
    1479 const device T* biases [[buffer(2)]],
    +
    1480 const device T* x [[buffer(3)]],
    +
    1481 device T* y [[buffer(4)]],
    +
    1482 const constant int& in_vec_size [[buffer(5)]],
    +
    1483 const constant int& out_vec_size [[buffer(6)]],
    +
    1484 const constant int& x_batch_ndims [[buffer(7)]],
    +
    1485 const constant int* x_shape [[buffer(8)]],
    +
    1486 const constant size_t* x_strides [[buffer(9)]],
    +
    1487 const constant int& w_batch_ndims [[buffer(10)]],
    +
    1488 const constant int* w_shape [[buffer(11)]],
    +
    1489 const constant size_t* w_strides [[buffer(12)]],
    +
    1490 const constant size_t* s_strides [[buffer(13)]],
    +
    1491 const constant size_t* b_strides [[buffer(14)]],
    +
    1492 const constant int& batch_ndims [[buffer(15)]],
    +
    1493 const constant int* batch_shape [[buffer(16)]],
    +
    1494 const device uint32_t* lhs_indices [[buffer(17)]],
    +
    1495 const device uint32_t* rhs_indices [[buffer(18)]],
    +
    1496 const constant size_t* lhs_strides [[buffer(19)]],
    +
    1497 const constant size_t* rhs_strides [[buffer(20)]],
    +
    1498 uint3 tid [[threadgroup_position_in_grid]],
    +
    1499 uint simd_gid [[simdgroup_index_in_threadgroup]],
    +
    1500 uint simd_lid [[thread_index_in_simdgroup]]) {
    + +
    1502 x,
    +
    1503 w,
    +
    1504 scales,
    +
    1505 biases,
    +
    1506 lhs_indices,
    +
    1507 rhs_indices,
    +
    1508 y,
    +
    1509 out_vec_size,
    +
    1510 batch_ndims,
    +
    1511 batch_shape,
    +
    1512 lhs_strides,
    +
    1513 rhs_strides,
    +
    1514 x_batch_ndims,
    +
    1515 x_shape,
    +
    1516 x_strides,
    +
    1517 w_batch_ndims,
    +
    1518 w_shape,
    +
    1519 w_strides,
    +
    1520 s_strides,
    +
    1521 b_strides,
    +
    1522 tid);
    + +
    1524 w,
    +
    1525 scales,
    +
    1526 biases,
    +
    1527 x,
    +
    1528 y,
    +
    1529 in_vec_size,
    +
    1530 out_vec_size,
    +
    1531 tid,
    +
    1532 simd_gid,
    +
    1533 simd_lid);
    +
    1534}
    -
    1541
    -
    1542template <typename T, int group_size, int bits>
    -
    -
    1543[[kernel]] void bs_qvm(
    -
    1544 const device uint32_t* w [[buffer(0)]],
    -
    1545 const device T* scales [[buffer(1)]],
    -
    1546 const device T* biases [[buffer(2)]],
    -
    1547 const device T* x [[buffer(3)]],
    -
    1548 device T* y [[buffer(4)]],
    -
    1549 const constant int& in_vec_size [[buffer(5)]],
    -
    1550 const constant int& out_vec_size [[buffer(6)]],
    -
    1551 const constant int& x_batch_ndims [[buffer(7)]],
    -
    1552 const constant int* x_shape [[buffer(8)]],
    -
    1553 const constant size_t* x_strides [[buffer(9)]],
    -
    1554 const constant int& w_batch_ndims [[buffer(10)]],
    -
    1555 const constant int* w_shape [[buffer(11)]],
    -
    1556 const constant size_t* w_strides [[buffer(12)]],
    -
    1557 const constant size_t* s_strides [[buffer(13)]],
    -
    1558 const constant size_t* b_strides [[buffer(14)]],
    -
    1559 const constant int& batch_ndims [[buffer(15)]],
    -
    1560 const constant int* batch_shape [[buffer(16)]],
    -
    1561 const device uint32_t* lhs_indices [[buffer(17)]],
    -
    1562 const device uint32_t* rhs_indices [[buffer(18)]],
    -
    1563 const constant size_t* lhs_strides [[buffer(19)]],
    -
    1564 const constant size_t* rhs_strides [[buffer(20)]],
    -
    1565 uint3 tid [[threadgroup_position_in_grid]],
    -
    1566 uint simd_gid [[simdgroup_index_in_threadgroup]],
    -
    1567 uint simd_lid [[thread_index_in_simdgroup]]) {
    - -
    1569 x,
    -
    1570 w,
    -
    1571 scales,
    -
    1572 biases,
    -
    1573 lhs_indices,
    -
    1574 rhs_indices,
    -
    1575 y,
    -
    1576 out_vec_size,
    -
    1577 batch_ndims,
    -
    1578 batch_shape,
    -
    1579 lhs_strides,
    -
    1580 rhs_strides,
    -
    1581 x_batch_ndims,
    -
    1582 x_shape,
    -
    1583 x_strides,
    -
    1584 w_batch_ndims,
    -
    1585 w_shape,
    -
    1586 w_strides,
    -
    1587 s_strides,
    -
    1588 b_strides,
    -
    1589 tid);
    - -
    1591 w,
    -
    1592 scales,
    -
    1593 biases,
    -
    1594 x,
    -
    1595 y,
    -
    1596 in_vec_size,
    -
    1597 out_vec_size,
    -
    1598 tid,
    -
    1599 simd_gid,
    -
    1600 simd_lid);
    -
    1601}
    +
    1535
    +
    1536template <typename T, int group_size, int bits>
    +
    +
    1537[[kernel]] void bs_qmv(
    +
    1538 const device uint32_t* w [[buffer(0)]],
    +
    1539 const device T* scales [[buffer(1)]],
    +
    1540 const device T* biases [[buffer(2)]],
    +
    1541 const device T* x [[buffer(3)]],
    +
    1542 device T* y [[buffer(4)]],
    +
    1543 const constant int& in_vec_size [[buffer(5)]],
    +
    1544 const constant int& out_vec_size [[buffer(6)]],
    +
    1545 const constant int& x_batch_ndims [[buffer(7)]],
    +
    1546 const constant int* x_shape [[buffer(8)]],
    +
    1547 const constant size_t* x_strides [[buffer(9)]],
    +
    1548 const constant int& w_batch_ndims [[buffer(10)]],
    +
    1549 const constant int* w_shape [[buffer(11)]],
    +
    1550 const constant size_t* w_strides [[buffer(12)]],
    +
    1551 const constant size_t* s_strides [[buffer(13)]],
    +
    1552 const constant size_t* b_strides [[buffer(14)]],
    +
    1553 const constant int& batch_ndims [[buffer(15)]],
    +
    1554 const constant int* batch_shape [[buffer(16)]],
    +
    1555 const device uint32_t* lhs_indices [[buffer(17)]],
    +
    1556 const device uint32_t* rhs_indices [[buffer(18)]],
    +
    1557 const constant size_t* lhs_strides [[buffer(19)]],
    +
    1558 const constant size_t* rhs_strides [[buffer(20)]],
    +
    1559 uint3 tid [[threadgroup_position_in_grid]],
    +
    1560 uint simd_gid [[simdgroup_index_in_threadgroup]],
    +
    1561 uint simd_lid [[thread_index_in_simdgroup]]) {
    + +
    1563 x,
    +
    1564 w,
    +
    1565 scales,
    +
    1566 biases,
    +
    1567 lhs_indices,
    +
    1568 rhs_indices,
    +
    1569 y,
    +
    1570 out_vec_size,
    +
    1571 batch_ndims,
    +
    1572 batch_shape,
    +
    1573 lhs_strides,
    +
    1574 rhs_strides,
    +
    1575 x_batch_ndims,
    +
    1576 x_shape,
    +
    1577 x_strides,
    +
    1578 w_batch_ndims,
    +
    1579 w_shape,
    +
    1580 w_strides,
    +
    1581 s_strides,
    +
    1582 b_strides,
    +
    1583 tid);
    + +
    1585 w,
    +
    1586 scales,
    +
    1587 biases,
    +
    1588 x,
    +
    1589 y,
    +
    1590 in_vec_size,
    +
    1591 out_vec_size,
    +
    1592 tid,
    +
    1593 simd_gid,
    +
    1594 simd_lid);
    +
    1595}
    -
    1602
    -
    1603template <
    -
    1604 typename T,
    -
    1605 const int group_size,
    -
    1606 const int bits,
    -
    1607 const bool aligned_N,
    -
    1608 const int BM = 32,
    -
    1609 const int BK = 32,
    -
    1610 const int BN = 32>
    -
    -
    1611[[kernel]] void bs_qmm_t(
    -
    1612 const device uint32_t* w [[buffer(0)]],
    -
    1613 const device T* scales [[buffer(1)]],
    -
    1614 const device T* biases [[buffer(2)]],
    -
    1615 const device T* x [[buffer(3)]],
    -
    1616 device T* y [[buffer(4)]],
    -
    1617 const constant int& K [[buffer(5)]],
    -
    1618 const constant int& N [[buffer(6)]],
    -
    1619 const constant int& M [[buffer(7)]],
    -
    1620 const constant int& x_batch_ndims [[buffer(8)]],
    -
    1621 const constant int* x_shape [[buffer(9)]],
    -
    1622 const constant size_t* x_strides [[buffer(10)]],
    -
    1623 const constant int& w_batch_ndims [[buffer(11)]],
    -
    1624 const constant int* w_shape [[buffer(12)]],
    -
    1625 const constant size_t* w_strides [[buffer(13)]],
    -
    1626 const constant size_t* s_strides [[buffer(14)]],
    -
    1627 const constant size_t* b_strides [[buffer(15)]],
    -
    1628 const constant int& batch_ndims [[buffer(16)]],
    -
    1629 const constant int* batch_shape [[buffer(17)]],
    -
    1630 const device uint32_t* lhs_indices [[buffer(18)]],
    -
    1631 const device uint32_t* rhs_indices [[buffer(19)]],
    -
    1632 const constant size_t* lhs_strides [[buffer(20)]],
    -
    1633 const constant size_t* rhs_strides [[buffer(21)]],
    -
    1634 uint3 tid [[threadgroup_position_in_grid]],
    -
    1635 uint lid [[thread_index_in_threadgroup]],
    -
    1636 uint simd_gid [[simdgroup_index_in_threadgroup]],
    -
    1637 uint simd_lid [[thread_index_in_simdgroup]]) {
    -
    1638 (void)lid;
    -
    1639
    -
    1640 constexpr int BK_padded = (BK + 16 / sizeof(T));
    -
    1641
    -
    1642 threadgroup T Xs[BM * BK_padded];
    -
    1643 threadgroup T Ws[BN * BK_padded];
    -
    1644
    - -
    1646 x,
    -
    1647 w,
    -
    1648 scales,
    -
    1649 biases,
    -
    1650 lhs_indices,
    -
    1651 rhs_indices,
    -
    1652 y,
    -
    1653 M * N,
    -
    1654 batch_ndims,
    -
    1655 batch_shape,
    -
    1656 lhs_strides,
    -
    1657 rhs_strides,
    -
    1658 x_batch_ndims,
    -
    1659 x_shape,
    -
    1660 x_strides,
    -
    1661 w_batch_ndims,
    -
    1662 w_shape,
    -
    1663 w_strides,
    -
    1664 s_strides,
    -
    1665 b_strides,
    -
    1666 tid);
    - -
    1668 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
    -
    1669}
    +
    1596
    +
    1597template <typename T, int group_size, int bits>
    +
    +
    1598[[kernel]] void bs_qvm(
    +
    1599 const device uint32_t* w [[buffer(0)]],
    +
    1600 const device T* scales [[buffer(1)]],
    +
    1601 const device T* biases [[buffer(2)]],
    +
    1602 const device T* x [[buffer(3)]],
    +
    1603 device T* y [[buffer(4)]],
    +
    1604 const constant int& in_vec_size [[buffer(5)]],
    +
    1605 const constant int& out_vec_size [[buffer(6)]],
    +
    1606 const constant int& x_batch_ndims [[buffer(7)]],
    +
    1607 const constant int* x_shape [[buffer(8)]],
    +
    1608 const constant size_t* x_strides [[buffer(9)]],
    +
    1609 const constant int& w_batch_ndims [[buffer(10)]],
    +
    1610 const constant int* w_shape [[buffer(11)]],
    +
    1611 const constant size_t* w_strides [[buffer(12)]],
    +
    1612 const constant size_t* s_strides [[buffer(13)]],
    +
    1613 const constant size_t* b_strides [[buffer(14)]],
    +
    1614 const constant int& batch_ndims [[buffer(15)]],
    +
    1615 const constant int* batch_shape [[buffer(16)]],
    +
    1616 const device uint32_t* lhs_indices [[buffer(17)]],
    +
    1617 const device uint32_t* rhs_indices [[buffer(18)]],
    +
    1618 const constant size_t* lhs_strides [[buffer(19)]],
    +
    1619 const constant size_t* rhs_strides [[buffer(20)]],
    +
    1620 uint3 tid [[threadgroup_position_in_grid]],
    +
    1621 uint simd_gid [[simdgroup_index_in_threadgroup]],
    +
    1622 uint simd_lid [[thread_index_in_simdgroup]]) {
    + +
    1624 x,
    +
    1625 w,
    +
    1626 scales,
    +
    1627 biases,
    +
    1628 lhs_indices,
    +
    1629 rhs_indices,
    +
    1630 y,
    +
    1631 out_vec_size,
    +
    1632 batch_ndims,
    +
    1633 batch_shape,
    +
    1634 lhs_strides,
    +
    1635 rhs_strides,
    +
    1636 x_batch_ndims,
    +
    1637 x_shape,
    +
    1638 x_strides,
    +
    1639 w_batch_ndims,
    +
    1640 w_shape,
    +
    1641 w_strides,
    +
    1642 s_strides,
    +
    1643 b_strides,
    +
    1644 tid);
    + +
    1646 w,
    +
    1647 scales,
    +
    1648 biases,
    +
    1649 x,
    +
    1650 y,
    +
    1651 in_vec_size,
    +
    1652 out_vec_size,
    +
    1653 tid,
    +
    1654 simd_gid,
    +
    1655 simd_lid);
    +
    1656}
    -
    1670
    -
    1671template <
    -
    1672 typename T,
    -
    1673 const int group_size,
    -
    1674 const int bits,
    -
    1675 const int BM = 32,
    -
    1676 const int BK = 32,
    -
    1677 const int BN = 32>
    -
    -
    1678[[kernel]] void bs_qmm_n(
    -
    1679 const device uint32_t* w [[buffer(0)]],
    -
    1680 const device T* scales [[buffer(1)]],
    -
    1681 const device T* biases [[buffer(2)]],
    -
    1682 const device T* x [[buffer(3)]],
    -
    1683 device T* y [[buffer(4)]],
    -
    1684 const constant int& K [[buffer(5)]],
    -
    1685 const constant int& N [[buffer(6)]],
    -
    1686 const constant int& M [[buffer(7)]],
    -
    1687 const constant int& x_batch_ndims [[buffer(8)]],
    -
    1688 const constant int* x_shape [[buffer(9)]],
    -
    1689 const constant size_t* x_strides [[buffer(10)]],
    -
    1690 const constant int& w_batch_ndims [[buffer(11)]],
    -
    1691 const constant int* w_shape [[buffer(12)]],
    -
    1692 const constant size_t* w_strides [[buffer(13)]],
    -
    1693 const constant size_t* s_strides [[buffer(14)]],
    -
    1694 const constant size_t* b_strides [[buffer(15)]],
    -
    1695 const constant int& batch_ndims [[buffer(16)]],
    -
    1696 const constant int* batch_shape [[buffer(17)]],
    -
    1697 const device uint32_t* lhs_indices [[buffer(18)]],
    -
    1698 const device uint32_t* rhs_indices [[buffer(19)]],
    -
    1699 const constant size_t* lhs_strides [[buffer(20)]],
    -
    1700 const constant size_t* rhs_strides [[buffer(21)]],
    -
    1701 uint3 tid [[threadgroup_position_in_grid]],
    -
    1702 uint lid [[thread_index_in_threadgroup]],
    -
    1703 uint simd_gid [[simdgroup_index_in_threadgroup]],
    -
    1704 uint simd_lid [[thread_index_in_simdgroup]]) {
    -
    1705 (void)lid;
    -
    1706
    -
    1707 constexpr int BK_padded = (BK + 16 / sizeof(T));
    -
    1708 constexpr int BN_padded = (BN + 16 / sizeof(T));
    -
    1709
    -
    1710 threadgroup T Xs[BM * BK_padded];
    -
    1711 threadgroup T Ws[BK * BN_padded];
    -
    1712
    - -
    1714 x,
    -
    1715 w,
    -
    1716 scales,
    -
    1717 biases,
    -
    1718 lhs_indices,
    -
    1719 rhs_indices,
    -
    1720 y,
    -
    1721 M * N,
    -
    1722 batch_ndims,
    -
    1723 batch_shape,
    -
    1724 lhs_strides,
    -
    1725 rhs_strides,
    -
    1726 x_batch_ndims,
    -
    1727 x_shape,
    -
    1728 x_strides,
    -
    1729 w_batch_ndims,
    -
    1730 w_shape,
    -
    1731 w_strides,
    -
    1732 s_strides,
    -
    1733 b_strides,
    -
    1734 tid);
    - -
    1736 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
    -
    1737}
    +
    1657
    +
    1658template <
    +
    1659 typename T,
    +
    1660 const int group_size,
    +
    1661 const int bits,
    +
    1662 const bool aligned_N,
    +
    1663 const int BM = 32,
    +
    1664 const int BK = 32,
    +
    1665 const int BN = 32>
    +
    +
    1666[[kernel]] void bs_qmm_t(
    +
    1667 const device uint32_t* w [[buffer(0)]],
    +
    1668 const device T* scales [[buffer(1)]],
    +
    1669 const device T* biases [[buffer(2)]],
    +
    1670 const device T* x [[buffer(3)]],
    +
    1671 device T* y [[buffer(4)]],
    +
    1672 const constant int& K [[buffer(5)]],
    +
    1673 const constant int& N [[buffer(6)]],
    +
    1674 const constant int& M [[buffer(7)]],
    +
    1675 const constant int& x_batch_ndims [[buffer(8)]],
    +
    1676 const constant int* x_shape [[buffer(9)]],
    +
    1677 const constant size_t* x_strides [[buffer(10)]],
    +
    1678 const constant int& w_batch_ndims [[buffer(11)]],
    +
    1679 const constant int* w_shape [[buffer(12)]],
    +
    1680 const constant size_t* w_strides [[buffer(13)]],
    +
    1681 const constant size_t* s_strides [[buffer(14)]],
    +
    1682 const constant size_t* b_strides [[buffer(15)]],
    +
    1683 const constant int& batch_ndims [[buffer(16)]],
    +
    1684 const constant int* batch_shape [[buffer(17)]],
    +
    1685 const device uint32_t* lhs_indices [[buffer(18)]],
    +
    1686 const device uint32_t* rhs_indices [[buffer(19)]],
    +
    1687 const constant size_t* lhs_strides [[buffer(20)]],
    +
    1688 const constant size_t* rhs_strides [[buffer(21)]],
    +
    1689 uint3 tid [[threadgroup_position_in_grid]],
    +
    1690 uint lid [[thread_index_in_threadgroup]],
    +
    1691 uint simd_gid [[simdgroup_index_in_threadgroup]],
    +
    1692 uint simd_lid [[thread_index_in_simdgroup]]) {
    +
    1693 (void)lid;
    +
    1694
    +
    1695 constexpr int BK_padded = (BK + 16 / sizeof(T));
    +
    1696
    +
    1697 threadgroup T Xs[BM * BK_padded];
    +
    1698 threadgroup T Ws[BN * BK_padded];
    +
    1699
    + +
    1701 x,
    +
    1702 w,
    +
    1703 scales,
    +
    1704 biases,
    +
    1705 lhs_indices,
    +
    1706 rhs_indices,
    +
    1707 y,
    +
    1708 M * N,
    +
    1709 batch_ndims,
    +
    1710 batch_shape,
    +
    1711 lhs_strides,
    +
    1712 rhs_strides,
    +
    1713 x_batch_ndims,
    +
    1714 x_shape,
    +
    1715 x_strides,
    +
    1716 w_batch_ndims,
    +
    1717 w_shape,
    +
    1718 w_strides,
    +
    1719 s_strides,
    +
    1720 b_strides,
    +
    1721 tid);
    + +
    1723 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
    +
    1724}
    -
    1738
    -
    1739template <typename T, const int group_size, const int bits>
    -
    -
    1740[[kernel]] void affine_quantize(
    -
    1741 const device T* w [[buffer(0)]],
    -
    1742 device uint8_t* out [[buffer(1)]],
    -
    1743 device T* scales [[buffer(2)]],
    -
    1744 device T* biases [[buffer(3)]],
    -
    1745 uint2 index [[thread_position_in_grid]],
    -
    1746 uint2 grid_dim [[threads_per_grid]]) {
    -
    1747 constexpr T eps = T(1e-7);
    -
    1748 constexpr int simd_size = 32;
    -
    1749 constexpr int uint8_bits = 8;
    -
    1750 constexpr T n_bins = (1 << bits) - 1;
    -
    1751 constexpr int packs_per_int = uint8_bits / bits;
    -
    1752 constexpr int values_per_reduce = group_size / simd_size;
    -
    1753 constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
    -
    1754 constexpr int writes_per_pack =
    -
    1755 writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
    -
    1756
    -
    1757 static_assert(
    -
    1758 group_size % simd_size == 0,
    -
    1759 "Group size must be divisible by simd size.");
    -
    1760
    -
    1761 size_t offset = index.x + grid_dim.x * size_t(index.y);
    -
    1762 size_t in_index = offset * values_per_reduce;
    -
    1763 size_t out_index = offset * writes_per_pack;
    +
    1725
    +
    1726template <
    +
    1727 typename T,
    +
    1728 const int group_size,
    +
    1729 const int bits,
    +
    1730 const int BM = 32,
    +
    1731 const int BK = 32,
    +
    1732 const int BN = 32>
    +
    +
    1733[[kernel]] void bs_qmm_n(
    +
    1734 const device uint32_t* w [[buffer(0)]],
    +
    1735 const device T* scales [[buffer(1)]],
    +
    1736 const device T* biases [[buffer(2)]],
    +
    1737 const device T* x [[buffer(3)]],
    +
    1738 device T* y [[buffer(4)]],
    +
    1739 const constant int& K [[buffer(5)]],
    +
    1740 const constant int& N [[buffer(6)]],
    +
    1741 const constant int& M [[buffer(7)]],
    +
    1742 const constant int& x_batch_ndims [[buffer(8)]],
    +
    1743 const constant int* x_shape [[buffer(9)]],
    +
    1744 const constant size_t* x_strides [[buffer(10)]],
    +
    1745 const constant int& w_batch_ndims [[buffer(11)]],
    +
    1746 const constant int* w_shape [[buffer(12)]],
    +
    1747 const constant size_t* w_strides [[buffer(13)]],
    +
    1748 const constant size_t* s_strides [[buffer(14)]],
    +
    1749 const constant size_t* b_strides [[buffer(15)]],
    +
    1750 const constant int& batch_ndims [[buffer(16)]],
    +
    1751 const constant int* batch_shape [[buffer(17)]],
    +
    1752 const device uint32_t* lhs_indices [[buffer(18)]],
    +
    1753 const device uint32_t* rhs_indices [[buffer(19)]],
    +
    1754 const constant size_t* lhs_strides [[buffer(20)]],
    +
    1755 const constant size_t* rhs_strides [[buffer(21)]],
    +
    1756 uint3 tid [[threadgroup_position_in_grid]],
    +
    1757 uint lid [[thread_index_in_threadgroup]],
    +
    1758 uint simd_gid [[simdgroup_index_in_threadgroup]],
    +
    1759 uint simd_lid [[thread_index_in_simdgroup]]) {
    +
    1760 (void)lid;
    +
    1761
    +
    1762 constexpr int BK_padded = (BK + 16 / sizeof(T));
    +
    1763 constexpr int BN_padded = (BN + 16 / sizeof(T));
    1764
    -
    1765 T w_thread[values_per_reduce];
    -
    1766 T w_min = Limits<T>::max;
    -
    1767 T w_max = 0;
    -
    1768
    -
    1769#pragma clang loop unroll(full)
    -
    1770 for (int i = 0; i < values_per_reduce; i++) {
    -
    1771 T val = w[in_index + i];
    -
    1772 w_thread[i] = val;
    -
    1773 w_min = min(w_min, val);
    -
    1774 w_max = max(w_max, val);
    -
    1775 }
    -
    1776
    -
    1777 w_min = simd_min(w_min);
    -
    1778 w_max = simd_max(w_max);
    -
    1779
    -
    1780 T scale = max((w_max - w_min) / n_bins, eps);
    -
    1781 bool side = abs(w_min) > abs(w_max);
    -
    1782 scale = side ? scale : -scale;
    -
    1783 T edge = side ? w_min : w_max;
    -
    1784 T q0 = round(edge / scale);
    -
    1785 bool at_zero = q0 == 0.0f;
    -
    1786 scale = at_zero ? scale : edge / q0;
    -
    1787 T bias = at_zero ? T(0) : edge;
    -
    1788
    -
    1789 // Write out the scales and biases
    -
    1790 size_t gindex = in_index / group_size;
    -
    1791 if (in_index % group_size == 0) {
    -
    1792 scales[gindex] = scale;
    -
    1793 biases[gindex] = bias;
    -
    1794 }
    -
    1795
    -
    1796 uint8_t output = 0;
    -
    1797#pragma clang loop unroll(full)
    -
    1798 for (int i = 0; i < values_per_reduce; i++) {
    -
    1799 uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
    -
    1800 if (bits == 8) {
    -
    1801 output = val;
    -
    1802 } else {
    -
    1803 output += val << (bits * (i % packs_per_int));
    -
    1804 }
    -
    1805
    -
    1806 if (packs_per_int < values_per_reduce &&
    -
    1807 i % packs_per_int == packs_per_int - 1) {
    -
    1808 out[out_index + i / packs_per_int] = output;
    -
    1809 output = 0;
    -
    1810 } else {
    -
    1811#pragma clang loop unroll(full)
    -
    1812 for (int j = 0; j < writes_per_reduce - 1; j++) {
    -
    1813 uint8_t sval = simd_shuffle_down(val, j + 1);
    -
    1814 output += sval << (bits * (values_per_reduce + j + i));
    -
    1815 }
    -
    1816 }
    -
    1817 }
    -
    1818 if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
    -
    1819 out[out_index / writes_per_reduce] = output;
    -
    1820 }
    -
    1821}
    +
    1765 threadgroup T Xs[BM * BK_padded];
    +
    1766 threadgroup T Ws[BK * BN_padded];
    +
    1767
    + +
    1769 x,
    +
    1770 w,
    +
    1771 scales,
    +
    1772 biases,
    +
    1773 lhs_indices,
    +
    1774 rhs_indices,
    +
    1775 y,
    +
    1776 M * N,
    +
    1777 batch_ndims,
    +
    1778 batch_shape,
    +
    1779 lhs_strides,
    +
    1780 rhs_strides,
    +
    1781 x_batch_ndims,
    +
    1782 x_shape,
    +
    1783 x_strides,
    +
    1784 w_batch_ndims,
    +
    1785 w_shape,
    +
    1786 w_strides,
    +
    1787 s_strides,
    +
    1788 b_strides,
    +
    1789 tid);
    + +
    1791 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
    +
    1792}
    -
    1822
    -
    1823template <typename T, const int group_size, const int bits>
    -
    - -
    1825 const device T* w [[buffer(0)]],
    -
    1826 const device T* scales [[buffer(1)]],
    -
    1827 const device T* biases [[buffer(2)]],
    -
    1828 device uint8_t* out [[buffer(3)]],
    -
    1829 uint2 index [[thread_position_in_grid]],
    -
    1830 uint2 grid_dim [[threads_per_grid]]) {
    -
    1831 constexpr int uint8_bits = 8;
    -
    1832 constexpr int packs_per_int = uint8_bits / bits;
    -
    1833 constexpr T n_bins = (1 << bits) - 1;
    +
    1793
    +
    1794template <typename T, const int group_size, const int bits>
    +
    +
    1795[[kernel]] void affine_quantize(
    +
    1796 const device T* w [[buffer(0)]],
    +
    1797 device uint8_t* out [[buffer(1)]],
    +
    1798 device T* scales [[buffer(2)]],
    +
    1799 device T* biases [[buffer(3)]],
    +
    1800 uint2 index [[thread_position_in_grid]],
    +
    1801 uint2 grid_dim [[threads_per_grid]]) {
    +
    1802 constexpr T eps = T(1e-7);
    +
    1803 constexpr int simd_size = 32;
    +
    1804 constexpr int uint8_bits = 8;
    +
    1805 constexpr T n_bins = (1 << bits) - 1;
    +
    1806 constexpr int packs_per_int = uint8_bits / bits;
    +
    1807 constexpr int values_per_reduce = group_size / simd_size;
    +
    1808 constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
    +
    1809 constexpr int writes_per_pack =
    +
    1810 writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
    +
    1811
    +
    1812 static_assert(
    +
    1813 group_size % simd_size == 0,
    +
    1814 "Group size must be divisible by simd size.");
    +
    1815
    +
    1816 size_t offset = index.x + grid_dim.x * size_t(index.y);
    +
    1817 size_t in_index = offset * values_per_reduce;
    +
    1818 size_t out_index = offset * writes_per_pack;
    +
    1819
    +
    1820 T w_thread[values_per_reduce];
    +
    1821 T w_min = Limits<T>::max;
    +
    1822 T w_max = 0;
    +
    1823
    +
    1824#pragma clang loop unroll(full)
    +
    1825 for (int i = 0; i < values_per_reduce; i++) {
    +
    1826 T val = w[in_index + i];
    +
    1827 w_thread[i] = val;
    +
    1828 w_min = min(w_min, val);
    +
    1829 w_max = max(w_max, val);
    +
    1830 }
    +
    1831
    +
    1832 w_min = simd_min(w_min);
    +
    1833 w_max = simd_max(w_max);
    1834
    -
    1835 size_t offset = index.x + grid_dim.x * size_t(index.y);
    -
    1836 size_t in_index = offset * packs_per_int;
    -
    1837 size_t gindex = in_index / group_size;
    -
    1838
    -
    1839 T scale = scales[gindex];
    -
    1840 T bias = biases[gindex];
    -
    1841
    -
    1842 uint8_t output = 0;
    -
    1843#pragma clang loop unroll(full)
    -
    1844 for (int i = 0; i < packs_per_int; i++) {
    -
    1845 uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
    -
    1846 if (bits == 8) {
    -
    1847 output = val;
    -
    1848 } else {
    -
    1849 output += val << (bits * i);
    -
    1850 }
    -
    1851 }
    -
    1852 out[offset] = output;
    -
    1853}
    +
    1835 T scale = max((w_max - w_min) / n_bins, eps);
    +
    1836 bool side = abs(w_min) > abs(w_max);
    +
    1837 scale = side ? scale : -scale;
    +
    1838 T edge = side ? w_min : w_max;
    +
    1839 T q0 = round(edge / scale);
    +
    1840 bool at_zero = q0 == 0.0f;
    +
    1841 scale = at_zero ? scale : edge / q0;
    +
    1842 T bias = at_zero ? T(0) : edge;
    +
    1843
    +
    1844 // Write out the scales and biases
    +
    1845 size_t gindex = in_index / group_size;
    +
    1846 if (in_index % group_size == 0) {
    +
    1847 scales[gindex] = scale;
    +
    1848 biases[gindex] = bias;
    +
    1849 }
    +
    1850
    +
    1851 uint8_t output = 0;
    +
    1852#pragma clang loop unroll(full)
    +
    1853 for (int i = 0; i < values_per_reduce; i++) {
    +
    1854 uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
    +
    1855 if (bits == 8) {
    +
    1856 output = val;
    +
    1857 } else {
    +
    1858 output += val << (bits * (i % packs_per_int));
    +
    1859 }
    +
    1860
    +
    1861 if (packs_per_int < values_per_reduce &&
    +
    1862 i % packs_per_int == packs_per_int - 1) {
    +
    1863 out[out_index + i / packs_per_int] = output;
    +
    1864 output = 0;
    +
    1865 } else {
    +
    1866#pragma clang loop unroll(full)
    +
    1867 for (int j = 0; j < writes_per_reduce - 1; j++) {
    +
    1868 uint8_t sval = simd_shuffle_down(val, j + 1);
    +
    1869 output += sval << (bits * (values_per_reduce + j + i));
    +
    1870 }
    +
    1871 }
    +
    1872 }
    +
    1873 if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
    +
    1874 out[out_index / writes_per_reduce] = output;
    +
    1875 }
    +
    1876}
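
The group-wise arithmetic in `affine_quantize` above (lines 1835–1842 of the listing) picks the group edge with the larger magnitude, adjusts the scale so that `edge / scale` is an integer, and encodes each value as `round((w - bias) / scale)` capped at `n_bins`. Below is a scalar transcription of that math for a single group, followed by the `scale * d + bias` reconstruction that `affine_dequantize` applies further down; the sample weights and `bits = 4` are made up for illustration.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Scalar walk-through of the per-group affine quantization above (bits = 4,
// one group of eight made-up weights), followed by the reconstruction
// out = scale * d + bias used on the dequantization side.
int main() {
  constexpr int bits = 4;
  constexpr float eps = 1e-7f;
  constexpr float n_bins = (1 << bits) - 1;  // 15
  const float w[8] = {-0.9f, -0.4f, 0.0f, 0.3f, 0.7f, 1.1f, 1.4f, 1.8f};

  float w_min = w[0], w_max = w[0];
  for (float v : w) {
    w_min = std::min(w_min, v);
    w_max = std::max(w_max, v);
  }

  float scale = std::max((w_max - w_min) / n_bins, eps);
  bool side = std::fabs(w_min) > std::fabs(w_max);
  scale = side ? scale : -scale;
  float edge = side ? w_min : w_max;
  float q0 = std::round(edge / scale);
  bool at_zero = (q0 == 0.0f);
  scale = at_zero ? scale : edge / q0;
  float bias = at_zero ? 0.0f : edge;

  for (float v : w) {
    auto q = static_cast<uint8_t>(std::min(std::round((v - bias) / scale), n_bins));
    float back = scale * q + bias;  // reconstruction used by affine_dequantize
    std::printf("w = %+.2f  ->  q = %2u  ->  %+.3f\n", v, q, back);
  }
  return 0;
}
```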
    -
    1854
    -
    1855template <typename T, const int group_size, const int bits>
    -
    -
    1856[[kernel]] void affine_dequantize(
    -
    1857 const device uint8_t* w [[buffer(0)]],
    -
    1858 const device T* scales [[buffer(1)]],
    -
    1859 const device T* biases [[buffer(2)]],
    -
    1860 device T* out [[buffer(3)]],
    -
    1861 uint2 index [[thread_position_in_grid]],
    -
    1862 uint2 grid_dim [[threads_per_grid]]) {
    -
    1863 constexpr int uint8_bits = 8;
    -
    1864 constexpr int packs_per_int = uint8_bits / bits;
    -
    1865
    -
    1866 size_t offset = index.x + grid_dim.x * size_t(index.y);
    -
    1867 size_t oindex = offset * packs_per_int;
    -
    1868 size_t gindex = oindex / group_size;
    -
    1869 T scale = scales[gindex];
    -
    1870 T bias = biases[gindex];
    -
    1871 uint val = w[offset];
    -
    1872
    -
    1873#pragma clang loop unroll(full)
    -
    1874 for (int i = 0; i < packs_per_int; i++) {
    -
    1875 uint8_t d;
    -
    1876 if (bits == 2) {
    -
    1877 d = (val >> (bits * i)) & 0x03;
    -
    1878 } else if (bits == 4) {
    -
    1879 d = (val >> (bits * i)) & 0x0f;
    -
    1880 } else if (bits == 8) {
    -
    1881 d = val;
    -
    1882 }
    -
    1883 out[oindex + i] = scale * d + bias;
    -
    1884 }
    -
    1885}
    +
    1877
    +
    1878template <typename T, const int group_size, const int bits>
    +
    + +
    1880 const device T* w [[buffer(0)]],
    +
    1881 const device T* scales [[buffer(1)]],
    +
    1882 const device T* biases [[buffer(2)]],
    +
    1883 device uint8_t* out [[buffer(3)]],
    +
    1884 uint2 index [[thread_position_in_grid]],
    +
    1885 uint2 grid_dim [[threads_per_grid]]) {
    +
    1886 constexpr int uint8_bits = 8;
    +
    1887 constexpr int packs_per_int = uint8_bits / bits;
    +
    1888 constexpr T n_bins = (1 << bits) - 1;
    +
    1889
    +
    1890 size_t offset = index.x + grid_dim.x * size_t(index.y);
    +
    1891 size_t in_index = offset * packs_per_int;
    +
    1892 size_t gindex = in_index / group_size;
    +
    1893
    +
    1894 T scale = scales[gindex];
    +
    1895 T bias = biases[gindex];
    +
    1896
    +
    1897 uint8_t output = 0;
    +
    1898#pragma clang loop unroll(full)
    +
    1899 for (int i = 0; i < packs_per_int; i++) {
    +
    1900 uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
    +
    1901 if (bits == 8) {
    +
    1902 output = val;
    +
    1903 } else {
    +
    1904 output += val << (bits * i);
    +
    1905 }
    +
    1906 }
    +
    1907 out[offset] = output;
    +
    1908}
    +
    +
    1909
    +
    1910template <typename T, const int group_size, const int bits>
    +
    +
    1911[[kernel]] void affine_dequantize(
    +
    1912 const device uint8_t* w [[buffer(0)]],
    +
    1913 const device T* scales [[buffer(1)]],
    +
    1914 const device T* biases [[buffer(2)]],
    +
    1915 device T* out [[buffer(3)]],
    +
    1916 uint2 index [[thread_position_in_grid]],
    +
    1917 uint2 grid_dim [[threads_per_grid]]) {
    +
    1918 constexpr int uint8_bits = 8;
    +
    1919 constexpr int packs_per_int = uint8_bits / bits;
    +
    1920
    +
    1921 size_t offset = index.x + grid_dim.x * size_t(index.y);
    +
    1922 size_t oindex = offset * packs_per_int;
    +
    1923 size_t gindex = oindex / group_size;
    +
    1924 T scale = scales[gindex];
    +
    1925 T bias = biases[gindex];
    +
    1926 uint val = w[offset];
    +
    1927
    +
    1928#pragma clang loop unroll(full)
    +
    1929 for (int i = 0; i < packs_per_int; i++) {
    +
    1930 uint8_t d;
    +
    1931 if (bits == 2) {
    +
    1932 d = (val >> (bits * i)) & 0x03;
    +
    1933 } else if (bits == 4) {
    +
    1934 d = (val >> (bits * i)) & 0x0f;
    +
    1935 } else if (bits == 8) {
    +
    1936 d = val;
    +
    1937 }
    +
    1938 out[oindex + i] = scale * d + bias;
    +
    1939 }
    +
    1940}
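
`affine_dequantize` above reads one packed byte per thread and extracts `packs_per_int = 8 / bits` quantized values from it with shifts and masks before applying `scale * d + bias`. The unpacking step in isolation for `bits = 4`; the byte value, scale, and bias here are arbitrary example numbers, not taken from the sources.

```cpp
#include <cstdint>
#include <cstdio>

// Unpack two 4-bit values from one byte the same way the loop above does,
// then map each back through the affine reconstruction.
int main() {
  const int bits = 4;
  const int packs_per_int = 8 / bits;  // 2 values per byte
  const uint8_t packed = 0xA3;         // low nibble 0x3, high nibble 0xA
  const float scale = 0.18f;
  const float bias = -0.9f;

  for (int i = 0; i < packs_per_int; ++i) {
    uint8_t d = (packed >> (bits * i)) & 0x0f;  // 0x3, then 0xA
    std::printf("slot %d: d = %2u, out = %+.3f\n", i, d, scale * d + bias);
  }
  return 0;
}
```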
    static constant constexpr const uint8_t simd_size
    Definition ops.h:22
    METAL_FUNC ulong2 elem_to_loc_broadcast(uint elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, int ndim)
    Definition utils.h:7
    @@ -2058,28 +2115,29 @@ $(function(){ initResizable(false); });
    #define MLX_MTL_CONST
    Definition quantized.h:8
    U qdot_safe(const device uint8_t *w, const thread U *x_thread, U scale, U bias, U sum, int N)
    Definition quantized.h:142
    METAL_FUNC void qmm_n_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, threadgroup T *Xs, threadgroup T *Ws, const constant int &K, const constant int &N, const constant int &M, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
    Definition quantized.h:879
    -
    void bs_qmm_n(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
    Definition quantized.h:1678
    -
    void qmm_n(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
    Definition quantized.h:1368
    -
    void affine_quantize(const device T *w, device uint8_t *out, device T *scales, device T *biases, uint2 index, uint2 grid_dim)
    Definition quantized.h:1740
    -
    METAL_FUNC void qvm_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
    Definition quantized.h:647
    -
    void bs_qmv_fast(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint simd_gid, uint simd_lid)
    Definition quantized.h:1421
    -
    void affine_dequantize(const device uint8_t *w, const device T *scales, const device T *biases, device T *out, uint2 index, uint2 grid_dim)
    Definition quantized.h:1856
    +
    METAL_FUNC void qvm_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const int in_vec_size, const int out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
    Definition quantized.h:647
    +
    void bs_qmm_n(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
    Definition quantized.h:1733
    +
    void qmm_n(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
    Definition quantized.h:1423
    +
    void affine_quantize(const device T *w, device uint8_t *out, device T *scales, device T *biases, uint2 index, uint2 grid_dim)
    Definition quantized.h:1795
    +
    void bs_qmv_fast(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint simd_gid, uint simd_lid)
    Definition quantized.h:1476
    +
    void affine_dequantize(const device uint8_t *w, const device T *scales, const device T *biases, device T *out, uint2 index, uint2 grid_dim)
    Definition quantized.h:1911
    static constant constexpr const int SIMD_SIZE
    Definition quantized.h:10
    void qmv(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)
    Definition quantized.h:1200
    -
    void bs_qvm(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint simd_gid, uint simd_lid)
    Definition quantized.h:1543
    -
    void affine_quantize_scales_biases(const device T *w, const device T *scales, const device T *biases, device uint8_t *out, uint2 index, uint2 grid_dim)
    Definition quantized.h:1824
    +
    void bs_qvm(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint simd_gid, uint simd_lid)
    Definition quantized.h:1598
    +
    void affine_quantize_scales_biases(const device T *w, const device T *scales, const device T *biases, device uint8_t *out, uint2 index, uint2 grid_dim)
    Definition quantized.h:1879
    void qmv_fast(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)
    Definition quantized.h:1149
    void qmv_quad(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint quad_gid, uint quad_lid)
    Definition quantized.h:1098
    static constant constexpr const int QUAD_SIZE
    Definition quantized.h:11
    U load_vector(const device T *x, thread U *x_thread)
    Definition quantized.h:14
    METAL_FUNC void qmv_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
    Definition quantized.h:498
    U load_vector_safe(const device T *x, thread U *x_thread, int N)
    Definition quantized.h:52
    void bs_qmm_t(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
    Definition quantized.h:1666
    U qdot(const device uint8_t *w, const thread U *x_thread, U scale, U bias, U sum)
    Definition quantized.h:99
    void qvm_split_k(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &final_block_size, uint3 tid, uint simd_gid, uint simd_lid)
    Definition quantized.h:1302
    METAL_FUNC void qmv_fast_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
    Definition quantized.h:434
    void qmm_t(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
    Definition quantized.h:1365
    METAL_FUNC void adjust_matrix_offsets(const device T *&x, const device uint32_t *&w, const device T *&scales, const device T *&biases, device T *&y, int output_stride, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid)
    Definition quantized.h:1005
    void bs_qmv(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint simd_gid, uint simd_lid)
    Definition quantized.h:1537
    METAL_FUNC void qmv_quad_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint quad_gid, uint quad_lid)
    Definition quantized.h:376
    void qvm(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)
    Definition quantized.h:1251
    void qouter(const thread uint8_t *w, U x, U scale, U bias, thread U *result)
    Definition quantized.h:187
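These kernels presumably back MLX's quantized matrix products on the GPU. For orientation only, a host-side sketch assuming the standard mlx.core Python API (mx.quantize / mx.quantized_matmul and their group_size, bits, and transpose parameters are assumptions here, not taken from this page):

import mlx.core as mx

w = mx.random.normal(shape=(1024, 512))
x = mx.random.normal(shape=(1, 512))

# Pack the weights into uint32 words plus per-group scales and biases.
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)

# Quantized matmul; kernels such as qmv/qvm/qmm_t above do this work on the GPU.
y = mx.quantized_matmul(
    x, w_q, scales, biases, transpose=True, group_size=64, bits=4)  # (1, 1024)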
    diff --git a/docs/build/html/reduce__col_8h.html b/docs/build/html/reduce__col_8h.html index cd6b1d3c6..b7dda2eda 100644 --- a/docs/build/html/reduce__col_8h.html +++ b/docs/build/html/reduce__col_8h.html @@ -98,15 +98,207 @@ $(function(){ initResizable(false); });

    Functions

    template<typename T , typename U , typename Op , int NDIMS>
    void col_reduce_small (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)
     
    template<typename T , typename U , typename Op , int NDIMS>
    void col_reduce_longcolumn (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, const constant size_t &out_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)
     
    template<typename T , typename U , typename Op , int NDIMS, int BM, int BN>
    void col_reduce_looped (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)
     Our approach is the following simple looped approach:
     
    template<typename T , typename U , typename Op , int NDIMS, int BM, int BN>
    void col_reduce_2pass (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, const constant size_t &out_size, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)
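For orientation, these kernels presumably back reductions over strided (column) axes in the Python API. A minimal sketch of the operation they compute, assuming the standard mlx.core API:

import mlx.core as mx

x = mx.random.uniform(shape=(4096, 512))
# Reducing over axis 0 means each output element accumulates values that are
# a full row apart in memory, i.e. a column (strided) reduction.
col_sum = mx.sum(x, axis=0)
mx.eval(col_sum)
print(col_sum.shape)  # (512,)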
     

    diff --git a/docs/build/html/reduce__col_8h_source.html b/docs/build/html/reduce__col_8h_source.html index 835b0cd1c..caf4a5f82 100644 --- a/docs/build/html/reduce__col_8h_source.html +++ b/docs/build/html/reduce__col_8h_source.html @@ -93,334 +93,392 @@ $(function(){ initResizable(false); });
// Copyright © 2023-2024 Apple Inc.

template <typename T, typename U, typename Op, int NDIMS>
[[kernel]] void col_reduce_small(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]) {
  constexpr int n_reads = 4;
  Op op;
  // (loop helper declaration elided in the extracted listing; see utils.h:197)
  const device T* row;

  U totals[n_reads];
  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

  size_t column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;
  if (column >= reduction_stride) {
    return;
  }
  bool safe = column + n_reads <= reduction_stride;

  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
  in += in_idx + column;

  size_t total_rows = non_col_reductions * reduction_size;
  loop.next(lid.y, reduce_shape, reduce_strides);
  for (size_t r = lid.y; r < total_rows; r += lsize.y) {
    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(static_cast<U>(row[i]), totals[i]);
      }
    } else {
      U vals[n_reads];
      for (int i = 0; i < n_reads; i++) {
        vals[i] =
            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
      }
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(vals[i], totals[i]);
      }
    }
    loop.next(lsize.y, reduce_shape, reduce_strides);
  }

  if (lsize.y > 1) {
    // lsize.y should be <= 8
    threadgroup U shared_vals[32 * 8 * n_reads];
    for (int i = 0; i < n_reads; i++) {
      shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (lid.y == 0) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = shared_vals[lid.x * n_reads + i];
      }
      for (uint j = 1; j < lsize.y; j++) {
        for (int i = 0; i < n_reads; i++) {
          totals[i] =
              op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i],
                 totals[i]);
        }
      }
    }
  }

  if (lid.y == 0) {
    out += out_idx * reduction_stride + column;
    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        out[i] = totals[i];
      }
    } else {
      for (int i = 0; column + i < reduction_stride; i++) {
        out[i] = totals[i];
      }
    }
  }
}

template <typename T, typename U, typename Op, int NDIMS>
[[kernel]] void col_reduce_longcolumn(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    const constant size_t& out_size [[buffer(11)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]) {
  Op op;
  // (loop helper declaration elided in the extracted listing; see utils.h:197)
  const device T* row;

  size_t out_idx = gid.x + gsize.x * size_t(gid.y);
  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
  in += in_idx + lid.x;

  U total = Op::init;
  size_t total_rows = non_col_reductions * reduction_size;
  loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
  for (size_t r = gid.z * lsize.y + lid.y; r < total_rows;
       r += lsize.y * gsize.z) {
    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
    total = op(static_cast<U>(*row), total);
    loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
  }

  threadgroup U shared_vals[32 * 32];
  shared_vals[lid.y * lsize.x + lid.x] = total;
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (lid.y == 0) {
    for (uint i = 1; i < lsize.y; i++) {
      total = op(total, shared_vals[i * lsize.x + lid.x]);
    }
    out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;
  }
}

template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
[[kernel]] void col_reduce_looped(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  constexpr int n_simdgroups = 8;
  constexpr short tgp_size = n_simdgroups * simd_size;
  constexpr short n_reads = (BM * BN) / tgp_size;
  constexpr short n_read_blocks = BN / n_reads;

  threadgroup U shared_vals[BN * BM];
  U totals[n_reads];
  // (loop helper declaration elided in the extracted listing; see utils.h:197)
  const device T* row;

  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

  short lid = simd_group_id * simd_size + simd_lane_id;
  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
  size_t column = BN * gid.x + offset.x;
  bool safe = column + n_reads <= reduction_stride;

  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
  in += in_idx + column;

  size_t total = non_col_reductions * reduction_size;
  loop.next(offset.y, reduce_shape, reduce_strides);
  for (size_t r = offset.y; r < total; r += BM) {
    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);

    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(static_cast<U>(row[i]), totals[i]);
      }
    } else {
      U vals[n_reads];
      for (int i = 0; i < n_reads; i++) {
        vals[i] =
            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
      }
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(vals[i], totals[i]);
      }
    }

    loop.next(BM, reduce_shape, reduce_strides);
  }

  // We can use a simd reduction to accumulate across BM so each thread writes
  // the partial output to SM and then each simdgroup does BN / n_simdgroups
  // accumulations.
  if (BM == 32) {
    constexpr int n_outputs = BN / n_simdgroups;
    static_assert(
        BM != 32 || n_outputs == n_reads,
        "The tile should be selected such that n_outputs == n_reads");
    for (int i = 0; i < n_reads; i++) {
      shared_vals[offset.y * BN + offset.x + i] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
    for (int i = 0; i < n_outputs; i++) {
      totals[i] =
          op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
    }

    // Write the output.
    if (simd_lane_id == 0) {
      size_t out_column = BN * gid.x + out_offset.x;
      out += out_idx * reduction_stride + out_column;
      if (out_column + n_outputs <= reduction_stride) {
        for (int i = 0; i < n_outputs; i++) {
          out[i] = totals[i];
        }
      } else {
        for (int i = 0; out_column + i < reduction_stride; i++) {
          out[i] = totals[i];
        }
      }
    }
  }

  // Each thread holds n_reads partial results. We write them all out to shared
  // memory and threads with offset.y == 0 aggregate the columns and write the
  // outputs.
  else {
    short x_block = offset.x / n_reads;
    for (int i = 0; i < n_reads; i++) {
      shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (offset.y == 0) {
      for (int i = 0; i < n_reads; i++) {
        for (int j = 1; j < BM; j++) {
          totals[i] =
              op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);
        }
      }
    }

    // Write the output.
    if (offset.y == 0) {
      out += out_idx * reduction_stride + column;
      if (safe) {
        for (int i = 0; i < n_reads; i++) {
          out[i] = totals[i];
        }
      } else {
        for (int i = 0; column + i < reduction_stride; i++) {
          out[i] = totals[i];
        }
      }
    }
  }
}

template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
[[kernel]] void col_reduce_2pass(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant size_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    const constant size_t& out_size [[buffer(11)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  constexpr int n_simdgroups = 8;
  constexpr short tgp_size = n_simdgroups * simd_size;
  constexpr short n_reads = (BM * BN) / tgp_size;
  constexpr short n_read_blocks = BN / n_reads;
  constexpr int n_outputs = BN / n_simdgroups;
  constexpr short outer_blocks = 32;
  static_assert(BM == 32, "BM should be equal to 32");

  threadgroup U shared_vals[BN * BM];
  U totals[n_reads];
  // (loop helper declaration elided in the extracted listing; see utils.h:197)
  const device T* row;

  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

  short lid = simd_group_id * simd_size + simd_lane_id;
  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
  size_t column = BN * gid.x + offset.x;
  bool safe = column + n_reads <= reduction_stride;

  size_t full_idx = gid.y + gsize.y * size_t(gid.z);
  size_t block_idx = full_idx / out_size;
  size_t out_idx = full_idx % out_size;
  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
  in += in_idx + column;

  size_t total = non_col_reductions * reduction_size;
  loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
  for (size_t r = offset.y + block_idx * BM; r < total;
       r += outer_blocks * BM) {
    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);

    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(static_cast<U>(row[i]), totals[i]);
      }
    } else {
      U vals[n_reads];
      for (int i = 0; i < n_reads; i++) {
        vals[i] =
            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
      }
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(vals[i], totals[i]);
      }
    }

    loop.next(outer_blocks * BM, reduce_shape, reduce_strides);
  }

  // We can use a simd reduction to accumulate across BM so each thread writes
  // the partial output to SM and then each simdgroup does BN / n_simdgroups
  // accumulations.
  for (int i = 0; i < n_reads; i++) {
    shared_vals[offset.y * BN + offset.x + i] = totals[i];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
  for (int i = 0; i < n_outputs; i++) {
    totals[i] =
        op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
  }

  // Write the output.
  if (simd_lane_id == 0) {
    size_t out_column = BN * gid.x + out_offset.x;
    out += full_idx * reduction_stride + out_column;
    if (out_column + n_outputs <= reduction_stride) {
      for (int i = 0; i < n_outputs; i++) {
        out[i] = totals[i];
      }
    } else {
      for (int i = 0; out_column + i < reduction_stride; i++) {
        out[i] = totals[i];
      }
    }
  }
}
static constant constexpr const uint8_t simd_size
Definition ops.h:22
METAL_FUNC stride_t elem_to_loc(uint elem, constant const int *shape, constant const stride_t *strides, int ndim)
Definition utils.h:87
Op op
Definition binary.h:129
void col_reduce_2pass(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, const constant size_t &out_size, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)
Definition reduce_col.h:287
void col_reduce_looped(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)
Our approach is the following simple looped approach:
Definition reduce_col.h:155
void col_reduce_longcolumn(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, const constant size_t &out_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)
Definition reduce_col.h:97
void col_reduce_small(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)
Definition reduce_col.h:4
    Definition utils.h:197
    void next(const constant int *shape, const constant size_t *strides)
    Definition utils.h:202
    offset_t location(offset_t, const constant int *, const constant size_t *, int)
    Definition utils.h:229
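The elem_to_loc helper used throughout the listing above maps a flat, row-major element index to a strided memory offset. An equivalent Python sketch, for illustration only:

def elem_to_loc(elem, shape, strides):
    # Peel off the coordinate along each axis, innermost first, and
    # accumulate the strided offset.
    loc = 0
    for dim in reversed(range(len(shape))):
        loc += (elem % shape[dim]) * strides[dim]
        elem //= shape[dim]
    return loc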
    diff --git a/docs/build/html/sdpa__vector_8h.html b/docs/build/html/sdpa__vector_8h.html index df9d5fd42..daee8ad0a 100644 --- a/docs/build/html/sdpa__vector_8h.html +++ b/docs/build/html/sdpa__vector_8h.html @@ -99,13 +99,13 @@ $(function(){ initResizable(false); }); - - - + + +

    Functions

    template<typename T , int D>
    void sdpa_vector (const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)
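For orientation, per query this kernel evaluates scaled dot-product attention over N key/value rows, keeping a running max and sum of exponentials instead of materializing the scores. A host-side sketch of the same math, for illustration only (assuming the mlx.core API, with q of shape (D,) and k, v of shape (N, D)):

import mlx.core as mx

def sdpa_vector_reference(q, k, v, scale):
    scores = (scale * q) @ k.T    # (N,) attention logits
    weights = mx.softmax(scores)  # the kernel computes this with an online max/sum
    return weights @ v            # (D,) attention output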
     

Function Documentation

◆ sdpa_vector()

template<typename T , int D>
void sdpa_vector (const device T * queries, const device T * keys, const device T * values, device T * out, const constant int & gqa_factor, const constant int & N, const constant size_t & k_stride, const constant size_t & v_stride, const constant float & scale, uint3 tid, uint simd_gid, uint simd_lid)

diff --git a/docs/build/html/sdpa__vector_8h_source.html index 2d695b09b..390193135 100644 --- a/docs/build/html/sdpa__vector_8h_source.html +++ b/docs/build/html/sdpa__vector_8h_source.html @@ -99,7 +99,7 @@ $(function(){ initResizable(false); });
    7template <typename T, int D>
    -
    8[[kernel]] void sdpa_vector(
    +
    8[[kernel]] void sdpa_vector(
    9 const device T* queries [[buffer(0)]],
    10 const device T* keys [[buffer(1)]],
    11 const device T* values [[buffer(2)]],
    @@ -107,113 +107,114 @@ $(function(){ initResizable(false); });
    13 const constant int& gqa_factor,
    14 const constant int& N,
    15 const constant size_t& k_stride,
    -
    16 const constant float& scale,
    -
    17 uint3 tid [[threadgroup_position_in_grid]],
    -
    18 uint simd_gid [[simdgroup_index_in_threadgroup]],
    -
    19 uint simd_lid [[thread_index_in_simdgroup]]) {
    -
    20 constexpr int BN = 32;
    -
    21 constexpr int BD = 32;
    -
    22 constexpr int elem_per_thread = D / BD;
    -
    23
    -
    24 const int stride = BN * D;
    -
    25
    -
    26 typedef float U;
    -
    27
    -
    28 thread U q[elem_per_thread];
    -
    29 thread U k[elem_per_thread];
    -
    30 thread U o[elem_per_thread];
    -
    31
    -
    32 threadgroup U outputs[BN * BD];
    -
    33 threadgroup U max_scores[BN];
    -
    34 threadgroup U sum_exp_scores[BN];
    -
    35
    -
    36 // Adjust positions
    -
    37 const int head_idx = tid.y;
    -
    38 const int kv_head_idx = head_idx / gqa_factor;
    -
    39 queries += head_idx * D + simd_lid * elem_per_thread;
    -
    40 keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
    -
    41 values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
    -
    42 out += head_idx * D + simd_gid * elem_per_thread;
    -
    43
    -
    44 // Read the query and 0 the output accumulator
    -
    45 for (int i = 0; i < elem_per_thread; i++) {
    -
    46 q[i] = static_cast<U>(scale) * queries[i];
    -
    47 }
    -
    48 for (int i = 0; i < elem_per_thread; i++) {
    -
    49 o[i] = 0;
    -
    50 }
    -
    51
    -
    52 U max_score = -INFINITY;
    -
    53 U sum_exp_score = 0;
    -
    54
    -
    55 // For each key
    -
    56 for (int i = simd_gid; i < N; i += BN) {
    -
    57 // Read the key
    -
    58 for (int i = 0; i < elem_per_thread; i++) {
    -
    59 k[i] = keys[i];
    -
    60 }
    -
    61
    -
    62 // Compute the i-th score
    -
    63 U score = 0;
    -
    64 for (int i = 0; i < elem_per_thread; i++) {
    -
    65 score += q[i] * k[i];
    -
    66 }
    -
    67 score = simd_sum(score);
    -
    68
    -
    69 // Update the accumulators
    -
    70 U new_max = max(max_score, score);
    -
    71 U factor = fast::exp(max_score - new_max);
    -
    72 U exp_score = fast::exp(score - new_max);
    -
    73
    -
    74 max_score = new_max;
    -
    75 sum_exp_score = sum_exp_score * factor + exp_score;
    -
    76
    -
    77 // Update the output accumulator
    -
    78 for (int i = 0; i < elem_per_thread; i++) {
    -
    79 o[i] = o[i] * factor + exp_score * values[i];
    -
    80 }
    -
    81
    -
    82 // Move the pointers to the next kv
    -
    83 keys += stride;
    -
    84 values += stride;
    -
    85 }
    -
    86 threadgroup_barrier(mem_flags::mem_threadgroup);
    -
    87
    -
    88 // Each thread has a partial part of the output so we need to combine them.
    -
    89
    -
    90 // First let's communicate the max and sum_exp
    -
    91 if (simd_lid == 0) {
    -
    92 max_scores[simd_gid] = max_score;
    -
    93 sum_exp_scores[simd_gid] = sum_exp_score;
    -
    94 }
    -
    95 threadgroup_barrier(mem_flags::mem_threadgroup);
    -
    96 max_score = max_scores[simd_lid];
    -
    97 U new_max = simd_max(max_score);
    -
    98 U factor = fast::exp(max_score - new_max);
    -
    99 sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
    -
    100
    -
    101 // Now we need to aggregate all the outputs
    -
    102 for (int i = 0; i < elem_per_thread; i++) {
    -
    103 outputs[simd_lid * BD + simd_gid] = o[i];
    -
    104 threadgroup_barrier(mem_flags::mem_threadgroup);
    -
    105 o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
    -
    106 threadgroup_barrier(mem_flags::mem_threadgroup);
    -
    107 }
    -
    108
    -
    109 // And write the output
    -
    110 if (simd_lid == 0) {
    -
    111 for (int i = 0; i < elem_per_thread; i++) {
    -
    112 out[i] = static_cast<T>(o[i]);
    -
    113 }
    -
    114 }
    -
    115}
    +
    16 const constant size_t& v_stride,
    +
    17 const constant float& scale,
    +
    18 uint3 tid [[threadgroup_position_in_grid]],
    +
    19 uint simd_gid [[simdgroup_index_in_threadgroup]],
    +
    20 uint simd_lid [[thread_index_in_simdgroup]]) {
    +
    21 constexpr int BN = 32;
    +
  constexpr int BD = 32;
  constexpr int elem_per_thread = D / BD;

  const int stride = BN * D;

  typedef float U;

  thread U q[elem_per_thread];
  thread U k[elem_per_thread];
  thread U o[elem_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int head_idx = tid.y;
  const int kv_head_idx = head_idx / gqa_factor;
  queries += head_idx * D + simd_lid * elem_per_thread;
  keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
  values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
  out += head_idx * D + simd_gid * elem_per_thread;

  // Read the query and 0 the output accumulator
  for (int i = 0; i < elem_per_thread; i++) {
    q[i] = static_cast<U>(scale) * queries[i];
  }
  for (int i = 0; i < elem_per_thread; i++) {
    o[i] = 0;
  }

  U max_score = -INFINITY;
  U sum_exp_score = 0;

  // For each key
  for (int i = simd_gid; i < N; i += BN) {
    // Read the key
    for (int i = 0; i < elem_per_thread; i++) {
      k[i] = keys[i];
    }

    // Compute the i-th score
    U score = 0;
    for (int i = 0; i < elem_per_thread; i++) {
      score += q[i] * k[i];
    }
    score = simd_sum(score);

    // Update the accumulators
    U new_max = max(max_score, score);
    U factor = fast::exp(max_score - new_max);
    U exp_score = fast::exp(score - new_max);

    max_score = new_max;
    sum_exp_score = sum_exp_score * factor + exp_score;

    // Update the output accumulator
    for (int i = 0; i < elem_per_thread; i++) {
      o[i] = o[i] * factor + exp_score * values[i];
    }

    // Move the pointers to the next kv
    keys += stride;
    values += stride;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Each thread has a partial part of the output so we need to combine them.

  // First let's communicate the max and sum_exp
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  max_score = max_scores[simd_lid];
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);

  // Now we need to aggregate all the outputs
  for (int i = 0; i < elem_per_thread; i++) {
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // And write the output
  if (simd_lid == 0) {
    for (int i = 0; i < elem_per_thread; i++) {
      out[i] = static_cast<T>(o[i]);
    }
  }
}
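For reference, the per-key update in the listing above is the standard one-pass ("online") softmax: whenever a larger score appears, the running sum and the output accumulator are rescaled by exp(old_max - new_max) before the new term is added. A didactic NumPy sketch of that recurrence (single query, no scaling, no threadgroup parallelism) might look like:

    import numpy as np

    def online_softmax_attention(q, keys, values):
        # q: (D,), keys: (N, D), values: (N, D)
        max_score = -np.inf
        sum_exp = 0.0
        out = np.zeros(values.shape[1], dtype=np.float32)
        for k, v in zip(keys, values):
            score = np.dot(q, k)
            new_max = max(max_score, score)
            factor = np.exp(max_score - new_max)   # rescale the previous state
            exp_score = np.exp(score - new_max)
            max_score = new_max
            sum_exp = sum_exp * factor + exp_score
            out = out * factor + exp_score * v
        return out / sum_exp

The kernel additionally pre-scales q by scale, splits each dot product across the BD lanes of a simdgroup, and distributes keys across the BN simdgroups of the threadgroup, which is why the final reduction over max_scores, sum_exp_scores, and outputs is needed.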
-void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)
+void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)
 Definition sdpa_vector.h:8

    ◆ maybeInsertBarrier()

    void mlx::core::metal::CommandEncoder::maybeInsertBarrier ()
diff --git a/docs/build/html/usage/function_transforms.html b/docs/build/html/usage/function_transforms.html
index 1b08af1eb..e903d77a8 100644
--- a/docs/build/html/usage/function_transforms.html
+++ b/docs/build/html/usage/function_transforms.html
@@ -986,13 +986,13 @@ We will prioritize including it.

 ys = mx.random.uniform(shape=(100, 4096))

 def naive_add(xs, ys):
-    return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
+    return [xs[i] + ys[:, i] for i in range(xs.shape[0])]

    Instead you can use vmap() to automatically vectorize the addition:

    # Vectorize over the second dimension of x and the
     # first dimension of y
    -vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
    +vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
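For context, here is a small runnable sketch of the corrected call. The shape of xs is not shown in this hunk, so the (4096, 100) shape below is inferred from the surrounding example and should be treated as an assumption:

    import mlx.core as mx

    # Assumed shapes: xs is (4096, 100) and ys is (100, 4096), so that
    # axis 0 of xs and axis 1 of ys are the mapped (vectorized) axes.
    xs = mx.random.uniform(shape=(4096, 100))
    ys = mx.random.uniform(shape=(100, 4096))

    vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
    out = vmap_add(xs, ys)
    print(out.shape)  # (4096, 100) with the default out_axes=0

This produces the same values as the naive_add loop above, stacked along the first axis of the result.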
     

The in_axes parameter can be used to specify which dimensions of the

diff --git a/docs/build/html/usage/indexing.html b/docs/build/html/usage/indexing.html
index 37d46eff8..55e3de841 100644
--- a/docs/build/html/usage/indexing.html
+++ b/docs/build/html/usage/indexing.html
@@ -922,7 +922,7 @@ undefined behavior.

    from the GPU. Performing bounds checking for array indices before launching the kernel would be extremely inefficient.

 Indexing with boolean masks is something that MLX may support in the future. In
-general, MLX has limited support for operations for which outputs
+general, MLX has limited support for operations for which output
 shapes are dependent on input data. Other examples of these types of operations which MLX does not yet support include numpy.nonzero() and the single input version of numpy.where().
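To make the data-dependence concrete, a short NumPy illustration (not MLX code) of why such output shapes cannot be known ahead of time:

    import numpy as np

    a = np.array([0, 3, 0, 5])
    print(np.nonzero(a))      # (array([1, 3]),)  -> two indices
    print(np.nonzero(0 * a))  # (array([], dtype=int64),)  -> zero indices

    # Two inputs with identical shapes produce outputs of different shapes,
    # so the output shape depends on the values, not just the input shape.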

diff --git a/docs/build/html/usage/lazy_evaluation.html b/docs/build/html/usage/lazy_evaluation.html
index cac88c1c1..59143b41a 100644
--- a/docs/build/html/usage/lazy_evaluation.html
+++ b/docs/build/html/usage/lazy_evaluation.html
@@ -952,7 +952,7 @@ stochastic gradient descent). A natural and usually efficient place to use

 An important behavior to be aware of is when the graph will be implicitly
 evaluated. Anytime you print an array, convert it to an
-numpy.ndarray, or otherwise access it’s memory via memoryview,
+numpy.ndarray, or otherwise access its memory via memoryview,
 the graph will be evaluated. Saving arrays via save() (or any other MLX saving functions) will also evaluate the array.
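A tiny sketch of these triggers (each of the last three lines forces the lazy graph to be evaluated):

    import mlx.core as mx
    import numpy as np

    a = mx.ones((2, 2)) * 3   # only builds the graph; nothing is computed yet
    print(a)                  # printing evaluates the graph
    b = np.array(a)           # converting to a numpy.ndarray evaluates it
    v = memoryview(a)         # accessing its memory directly also evaluates it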

    Calling array.item() on a scalar array will also evaluate it. In the