MLX
Loading...
Searching...
No Matches
gemm.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <Metal/Metal.hpp>
6
// Token-pasting shorthands for the private Objective-C runtime symbols:
//   _MPS_PRIVATE_CLS(X) -> MTL::Private::Class::s_kX     (class handle)
//   _MPS_PRIVATE_SEL(x) -> MTL::Private::Selector::s_kx  (selector)
// The s_k* variables are presumably emitted by the _MTL_PRIVATE_DEF_CLS /
// _MTL_PRIVATE_DEF_SEL declarations in this header (or by metal-cpp itself
// for common selectors such as `label` / `device`) — TODO confirm.
#define _MPS_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol)
#define _MPS_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor)
9
11_MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor);
13_MTL_PRIVATE_DEF_CLS(MPSVectorDescriptor);
16_MTL_PRIVATE_DEF_CLS(MPSMatrixMultiplication);
17_MTL_PRIVATE_DEF_CLS(MPSMatrixVectorMultiplication);
18} // namespace MTL::Private::Class
19
22 matrixDescriptorWithRows_columns_rowBytes_dataType,
23 "matrixDescriptorWithRows:columns:rowBytes:dataType:");
25 matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType,
26 "matrixDescriptorWithRows:columns:matrices:rowBytes:matrixBytes:dataType:");
28_MTL_PRIVATE_DEF_SEL(initWithBuffer_descriptor, "initWithBuffer:descriptor:");
30 initWithDevice_,
31 "initWithDevice:transposeLeft:transposeRight:"
32 "resultRows:resultColumns:interiorColumns:alpha:beta:");
34 encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix,
35 "encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:");
36_MTL_PRIVATE_DEF_SEL(setLeftMatrixOrigin_, "setLeftMatrixOrigin:");
37_MTL_PRIVATE_DEF_SEL(setRightMatrixOrigin_, "setRightMatrixOrigin:");
38_MTL_PRIVATE_DEF_SEL(setResultMatrixOrigin_, "setResultMatrixOrigin:");
39_MTL_PRIVATE_DEF_SEL(setBatchStart_, "setBatchStart:");
40_MTL_PRIVATE_DEF_SEL(setBatchSize_, "setBatchSize:");
42 vectorDescriptorWithLength_dataType,
43 "vectorDescriptorWithLength:dataType:");
45 vectorDescriptorWithLength_vectors_vectorBytes_dataType,
46 "vectorDescriptorWithLength:vectors:vectorBytes:dataType:");
48 initWithDevice_transpose_rows_columns_alpha_beta,
49 "initWithDevice:transpose:rows:columns:alpha:beta:");
51 encodeToCommandBuffer_inputMatrix_inputVector_resultVector,
52 "encodeToCommandBuffer:inputMatrix:inputVector:resultVector:");
53} // namespace MTL::Private::Selector
54
55namespace MPS {
56
64
65class MatrixDescriptor : public NS::Copying<MatrixDescriptor> {
66 public:
68 NS::UInteger rows,
69 NS::UInteger columns,
70 NS::UInteger rowBytes,
71 NS::UInteger dataType);
73 NS::UInteger rows,
74 NS::UInteger columns,
75 NS::UInteger matrices,
76 NS::UInteger rowBytes,
77 NS::UInteger matrixBytes,
78 NS::UInteger dataType);
79 NS::UInteger rows() const;
80};
81
82class Matrix : public NS::Referencing<Matrix> {
83 public:
84 static class Matrix* alloc();
85 Matrix* init(MTL::Buffer* buffer, MatrixDescriptor* descriptor);
86 Matrix* init(const MTL::Buffer* buffer, MatrixDescriptor* descriptor);
87};
88
89class Kernel : public NS::Referencing<Kernel> {
90 public:
91 NS::String* label() const;
92 MTL::Device* device() const;
93};
94
96 : public NS::Referencing<MatrixMultiplication, Kernel> {
97 public:
98 static class MatrixMultiplication* alloc();
99
101 MTL::Device* device,
102 bool transposeLeft,
103 bool transposeRight,
104 NS::UInteger resultRows,
105 NS::UInteger resultColumns,
106 NS::UInteger interiorColumns,
107 double alpha,
108 double beta);
109
111 MTL::CommandBuffer* commandBuffer,
112 Matrix* leftMatrix,
113 Matrix* rightMatrix,
114 Matrix* resultMatrix);
115
116 void setLeftMatrixOrigin(MTL::Origin origin);
117 void setRightMatrixOrigin(MTL::Origin origin);
118 void setResultMatrixOrigin(MTL::Origin origin);
119 void setBatchStart(NS::UInteger batchStart);
120 void setBatchSize(NS::UInteger batchSize);
121};
122
123class VectorDescriptor : public NS::Copying<VectorDescriptor> {
124 public:
126 NS::UInteger length,
127 NS::UInteger dataType);
129 NS::UInteger length,
130 NS::UInteger vectors,
131 NS::UInteger vectorBytes,
132 NS::UInteger dataType);
133};
134
135class Vector : public NS::Referencing<Vector> {
136 public:
137 static class Vector* alloc();
138 Vector* init(MTL::Buffer* buffer, VectorDescriptor* descriptor);
139 Vector* init(const MTL::Buffer* buffer, VectorDescriptor* descriptor);
140};
141
143 : public NS::Referencing<MatrixVectorMultiplication, Kernel> {
144 public:
145 static class MatrixVectorMultiplication* alloc();
146
148 MTL::Device* device,
149 bool transpose,
150 NS::UInteger rows,
151 NS::UInteger columns,
152 double alpha,
153 double beta);
154
156 MTL::CommandBuffer* commandBuffer,
157 Matrix* inputMatrix,
158 Vector* inputVector,
159 Vector* resultVector);
160};
161
163 NS::UInteger rows,
164 NS::UInteger columns,
165 NS::UInteger rowBytes,
166 NS::UInteger dataType) {
167 return Object::sendMessage<MatrixDescriptor*>(
168 _MPS_PRIVATE_CLS(MPSMatrixDescriptor),
169 _MPS_PRIVATE_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType),
170 rows,
171 columns,
172 rowBytes,
173 dataType);
174}
175
177 NS::UInteger rows,
178 NS::UInteger columns,
179 NS::UInteger matrices,
180 NS::UInteger rowBytes,
181 NS::UInteger matrixBytes,
182 NS::UInteger dataType) {
183 return Object::sendMessage<MatrixDescriptor*>(
184 _MPS_PRIVATE_CLS(MPSMatrixDescriptor),
186 matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType),
187 rows,
188 columns,
189 matrices,
190 rowBytes,
191 matrixBytes,
192 dataType);
193}
194
195_MTL_INLINE NS::UInteger MatrixDescriptor::rows() const {
196 return Object::sendMessage<NS::UInteger>(this, _MPS_PRIVATE_SEL(rows));
197}
198
199_MTL_INLINE Matrix* Matrix::alloc() {
200 return NS::Object::alloc<Matrix>(_MPS_PRIVATE_CLS(MPSMatrix));
201}
202
203_MTL_INLINE Matrix* Matrix::init(
204 MTL::Buffer* buffer,
205 MatrixDescriptor* descriptor) {
206 return Object::sendMessage<Matrix*>(
207 this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
208}
209
210_MTL_INLINE Matrix* Matrix::init(
211 const MTL::Buffer* buffer,
212 MatrixDescriptor* descriptor) {
213 return init(const_cast<MTL::Buffer*>(buffer), descriptor);
214}
215
216_MTL_INLINE NS::String* Kernel::label() const {
217 return Object::sendMessage<NS::String*>(this, _MPS_PRIVATE_SEL(label));
218}
219
220_MTL_INLINE MTL::Device* Kernel::device() const {
221 return Object::sendMessage<MTL::Device*>(this, _MPS_PRIVATE_SEL(device));
222}
223
225 return NS::Object::alloc<MatrixMultiplication>(
226 _MPS_PRIVATE_CLS(MPSMatrixMultiplication));
227}
228
230 MTL::Device* device,
231 bool transposeLeft,
232 bool transposeRight,
233 NS::UInteger resultRows,
234 NS::UInteger resultColumns,
235 NS::UInteger interiorColumns,
236 double alpha,
237 double beta) {
238 return Object::sendMessage<MatrixMultiplication*>(
239 this,
240 _MPS_PRIVATE_SEL(initWithDevice_),
241 device,
242 transposeLeft,
243 transposeRight,
244 resultRows,
245 resultColumns,
246 interiorColumns,
247 alpha,
248 beta);
249}
250
252 MTL::CommandBuffer* commandBuffer,
253 Matrix* leftMatrix,
254 Matrix* rightMatrix,
255 Matrix* resultMatrix) {
256 return Object::sendMessage<void>(
257 this,
259 encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix),
260 commandBuffer,
261 leftMatrix,
262 rightMatrix,
263 resultMatrix);
264}
265
266_MTL_INLINE void MatrixMultiplication::setLeftMatrixOrigin(MTL::Origin origin) {
267 Object::sendMessage<void>(
268 this, _MPS_PRIVATE_SEL(setLeftMatrixOrigin_), origin);
269}
270
272 MTL::Origin origin) {
273 Object::sendMessage<void>(
274 this, _MPS_PRIVATE_SEL(setRightMatrixOrigin_), origin);
275}
276
278 MTL::Origin origin) {
279 Object::sendMessage<void>(
280 this, _MPS_PRIVATE_SEL(setResultMatrixOrigin_), origin);
281}
282
283_MTL_INLINE void MatrixMultiplication::setBatchStart(NS::UInteger batchStart) {
284 Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchStart_), batchStart);
285}
286
287_MTL_INLINE void MatrixMultiplication::setBatchSize(NS::UInteger batchSize) {
288 Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchSize_), batchSize);
289}
290
292 NS::UInteger length,
293 NS::UInteger dataType) {
294 return Object::sendMessage<VectorDescriptor*>(
295 _MPS_PRIVATE_CLS(MPSVectorDescriptor),
296 _MPS_PRIVATE_SEL(vectorDescriptorWithLength_dataType),
297 length,
298 dataType);
299}
300
302 NS::UInteger length,
303 NS::UInteger vectors,
304 NS::UInteger vectorBytes,
305 NS::UInteger dataType) {
306 return Object::sendMessage<VectorDescriptor*>(
307 _MPS_PRIVATE_CLS(MPSVectorDescriptor),
308 _MPS_PRIVATE_SEL(vectorDescriptorWithLength_vectors_vectorBytes_dataType),
309 length,
310 vectors,
311 vectorBytes,
312 dataType);
313}
314
315_MTL_INLINE Vector* Vector::alloc() {
316 return NS::Object::alloc<Vector>(_MPS_PRIVATE_CLS(MPSVector));
317}
318
319_MTL_INLINE Vector* Vector::init(
320 MTL::Buffer* buffer,
321 VectorDescriptor* descriptor) {
322 return Object::sendMessage<Vector*>(
323 this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
324}
325
326_MTL_INLINE Vector* Vector::init(
327 const MTL::Buffer* buffer,
328 VectorDescriptor* descriptor) {
329 return init(const_cast<MTL::Buffer*>(buffer), descriptor);
330}
331
333 return NS::Object::alloc<MatrixVectorMultiplication>(
334 _MPS_PRIVATE_CLS(MPSMatrixVectorMultiplication));
335}
336
338 MTL::Device* device,
339 bool transpose,
340 NS::UInteger rows,
341 NS::UInteger columns,
342 double alpha,
343 double beta) {
344 return Object::sendMessage<MatrixVectorMultiplication*>(
345 this,
346 _MPS_PRIVATE_SEL(initWithDevice_transpose_rows_columns_alpha_beta),
347 device,
348 transpose,
349 rows,
350 columns,
351 alpha,
352 beta);
353}
354
356 MTL::CommandBuffer* commandBuffer,
357 Matrix* inputMatrix,
358 Vector* inputVector,
359 Vector* resultVector) {
360 return Object::sendMessage<void>(
361 this,
363 encodeToCommandBuffer_inputMatrix_inputVector_resultVector),
364 commandBuffer,
365 inputMatrix,
366 inputVector,
367 resultVector);
368}
369
370} // namespace MPS
Definition gemm.h:89
MTL::Device * device() const
Definition gemm.h:220
NS::String * label() const
Definition gemm.h:216
Definition gemm.h:65
NS::UInteger rows() const
Definition gemm.h:195
static class MatrixDescriptor * matrixDescriptor(NS::UInteger rows, NS::UInteger columns, NS::UInteger rowBytes, NS::UInteger dataType)
Definition gemm.h:162
Definition gemm.h:82
Matrix * init(MTL::Buffer *buffer, MatrixDescriptor *descriptor)
Definition gemm.h:203
static class Matrix * alloc()
Definition gemm.h:199
Definition gemm.h:96
MatrixMultiplication * init(MTL::Device *device, bool transposeLeft, bool transposeRight, NS::UInteger resultRows, NS::UInteger resultColumns, NS::UInteger interiorColumns, double alpha, double beta)
Definition gemm.h:229
void setBatchStart(NS::UInteger batchStart)
Definition gemm.h:283
void setResultMatrixOrigin(MTL::Origin origin)
Definition gemm.h:277
void setLeftMatrixOrigin(MTL::Origin origin)
Definition gemm.h:266
static class MatrixMultiplication * alloc()
Definition gemm.h:224
void setBatchSize(NS::UInteger batchSize)
Definition gemm.h:287
void encodeToCommandBuffer(MTL::CommandBuffer *commandBuffer, Matrix *leftMatrix, Matrix *rightMatrix, Matrix *resultMatrix)
Definition gemm.h:251
void setRightMatrixOrigin(MTL::Origin origin)
Definition gemm.h:271
Definition gemm.h:143
MatrixVectorMultiplication * init(MTL::Device *device, bool transpose, NS::UInteger rows, NS::UInteger columns, double alpha, double beta)
Definition gemm.h:337
void encodeToCommandBuffer(MTL::CommandBuffer *commandBuffer, Matrix *inputMatrix, Vector *inputVector, Vector *resultVector)
Definition gemm.h:355
static class MatrixVectorMultiplication * alloc()
Definition gemm.h:332
Definition gemm.h:123
static class VectorDescriptor * vectorDescriptor(NS::UInteger length, NS::UInteger dataType)
Definition gemm.h:291
Definition gemm.h:135
Vector * init(MTL::Buffer *buffer, VectorDescriptor *descriptor)
Definition gemm.h:319
static class Vector * alloc()
Definition gemm.h:315
#define _MPS_PRIVATE_SEL(accessor)
Definition gemm.h:8
#define _MPS_PRIVATE_CLS(symbol)
Definition gemm.h:7
Definition gemm.h:55
DataType
Definition gemm.h:57
@ DataTypeFloat32
Definition gemm.h:61
@ DataTypeFloatBit
Definition gemm.h:58
@ DataTypeBFloat16
Definition gemm.h:62
@ DataTypeFloat16
Definition gemm.h:60
@ DataTypeAlternateEncodingBit
Definition gemm.h:59
Definition gemm.h:10
_MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor)
Definition gemm.h:20
_MTL_PRIVATE_DEF_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType, "matrixDescriptorWithRows:columns:rowBytes:dataType:")