5#include <Metal/Metal.hpp>
7#define _MPS_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol)
8#define _MPS_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor)
22 matrixDescriptorWithRows_columns_rowBytes_dataType,
23 "matrixDescriptorWithRows:columns:rowBytes:dataType:");
25 matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType,
26 "matrixDescriptorWithRows:columns:matrices:rowBytes:matrixBytes:dataType:");
31 "initWithDevice:transposeLeft:transposeRight:"
32 "resultRows:resultColumns:interiorColumns:alpha:beta:");
34 encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix,
35 "encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:");
42 vectorDescriptorWithLength_dataType,
43 "vectorDescriptorWithLength:dataType:");
45 vectorDescriptorWithLength_vectors_vectorBytes_dataType,
46 "vectorDescriptorWithLength:vectors:vectorBytes:dataType:");
48 initWithDevice_transpose_rows_columns_alpha_beta,
49 "initWithDevice:transpose:rows:columns:alpha:beta:");
51 encodeToCommandBuffer_inputMatrix_inputVector_resultVector,
52 "encodeToCommandBuffer:inputMatrix:inputVector:resultVector:");
70 NS::UInteger rowBytes,
71 NS::UInteger dataType);
75 NS::UInteger matrices,
76 NS::UInteger rowBytes,
77 NS::UInteger matrixBytes,
78 NS::UInteger dataType);
79 NS::UInteger
rows()
const;
82class Matrix :
public NS::Referencing<Matrix> {
89class Kernel :
public NS::Referencing<Kernel> {
91 NS::String*
label()
const;
92 MTL::Device*
device()
const;
96 :
public NS::Referencing<MatrixMultiplication, Kernel> {
104 NS::UInteger resultRows,
105 NS::UInteger resultColumns,
106 NS::UInteger interiorColumns,
111 MTL::CommandBuffer* commandBuffer,
127 NS::UInteger dataType);
130 NS::UInteger vectors,
131 NS::UInteger vectorBytes,
132 NS::UInteger dataType);
135class Vector :
public NS::Referencing<Vector> {
143 :
public NS::Referencing<MatrixVectorMultiplication, Kernel> {
151 NS::UInteger columns,
156 MTL::CommandBuffer* commandBuffer,
164 NS::UInteger columns,
165 NS::UInteger rowBytes,
166 NS::UInteger dataType) {
167 return Object::sendMessage<MatrixDescriptor*>(
178 NS::UInteger columns,
179 NS::UInteger matrices,
180 NS::UInteger rowBytes,
181 NS::UInteger matrixBytes,
182 NS::UInteger dataType) {
183 return Object::sendMessage<MatrixDescriptor*>(
186 matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType),
206 return Object::sendMessage<Matrix*>(
211 const MTL::Buffer* buffer,
213 return init(
const_cast<MTL::Buffer*
>(buffer), descriptor);
225 return NS::Object::alloc<MatrixMultiplication>(
233 NS::UInteger resultRows,
234 NS::UInteger resultColumns,
235 NS::UInteger interiorColumns,
238 return Object::sendMessage<MatrixMultiplication*>(
252 MTL::CommandBuffer* commandBuffer,
256 return Object::sendMessage<void>(
259 encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix),
267 Object::sendMessage<void>(
272 MTL::Origin origin) {
273 Object::sendMessage<void>(
278 MTL::Origin origin) {
279 Object::sendMessage<void>(
284 Object::sendMessage<void>(
this,
_MPS_PRIVATE_SEL(setBatchStart_), batchStart);
288 Object::sendMessage<void>(
this,
_MPS_PRIVATE_SEL(setBatchSize_), batchSize);
293 NS::UInteger dataType) {
294 return Object::sendMessage<VectorDescriptor*>(
303 NS::UInteger vectors,
304 NS::UInteger vectorBytes,
305 NS::UInteger dataType) {
306 return Object::sendMessage<VectorDescriptor*>(
322 return Object::sendMessage<Vector*>(
327 const MTL::Buffer* buffer,
329 return init(
const_cast<MTL::Buffer*
>(buffer), descriptor);
333 return NS::Object::alloc<MatrixVectorMultiplication>(
341 NS::UInteger columns,
344 return Object::sendMessage<MatrixVectorMultiplication*>(
356 MTL::CommandBuffer* commandBuffer,
360 return Object::sendMessage<void>(
363 encodeToCommandBuffer_inputMatrix_inputVector_resultVector),
MTL::Device * device() const
Definition gemm.h:220
NS::String * label() const
Definition gemm.h:216
NS::UInteger rows() const
Definition gemm.h:195
static class MatrixDescriptor * matrixDescriptor(NS::UInteger rows, NS::UInteger columns, NS::UInteger rowBytes, NS::UInteger dataType)
Definition gemm.h:162
Matrix * init(MTL::Buffer *buffer, MatrixDescriptor *descriptor)
Definition gemm.h:203
static class Matrix * alloc()
Definition gemm.h:199
MatrixMultiplication * init(MTL::Device *device, bool transposeLeft, bool transposeRight, NS::UInteger resultRows, NS::UInteger resultColumns, NS::UInteger interiorColumns, double alpha, double beta)
Definition gemm.h:229
void setBatchStart(NS::UInteger batchStart)
Definition gemm.h:283
void setResultMatrixOrigin(MTL::Origin origin)
Definition gemm.h:277
void setLeftMatrixOrigin(MTL::Origin origin)
Definition gemm.h:266
static class MatrixMultiplication * alloc()
Definition gemm.h:224
void setBatchSize(NS::UInteger batchSize)
Definition gemm.h:287
void encodeToCommandBuffer(MTL::CommandBuffer *commandBuffer, Matrix *leftMatrix, Matrix *rightMatrix, Matrix *resultMatrix)
Definition gemm.h:251
void setRightMatrixOrigin(MTL::Origin origin)
Definition gemm.h:271
MatrixVectorMultiplication * init(MTL::Device *device, bool transpose, NS::UInteger rows, NS::UInteger columns, double alpha, double beta)
Definition gemm.h:337
void encodeToCommandBuffer(MTL::CommandBuffer *commandBuffer, Matrix *inputMatrix, Vector *inputVector, Vector *resultVector)
Definition gemm.h:355
static class MatrixVectorMultiplication * alloc()
Definition gemm.h:332
static class VectorDescriptor * vectorDescriptor(NS::UInteger length, NS::UInteger dataType)
Definition gemm.h:291
Vector * init(MTL::Buffer *buffer, VectorDescriptor *descriptor)
Definition gemm.h:319
static class Vector * alloc()
Definition gemm.h:315
#define _MPS_PRIVATE_SEL(accessor)
Definition gemm.h:8
#define _MPS_PRIVATE_CLS(symbol)
Definition gemm.h:7
DataType
Definition gemm.h:57
@ DataTypeFloat32
Definition gemm.h:61
@ DataTypeFloatBit
Definition gemm.h:58
@ DataTypeBFloat16
Definition gemm.h:62
@ DataTypeFloat16
Definition gemm.h:60
@ DataTypeAlternateEncodingBit
Definition gemm.h:59
_MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor)
_MTL_PRIVATE_DEF_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType, "matrixDescriptorWithRows:columns:rowBytes:dataType:")