MLX
 
Loading...
Searching...
No Matches
matmul.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
6
7namespace mlx::core {
8
9void steel_matmul_regular(
10 const Stream& s,
11 metal::Device& d,
12 const array& a,
13 const array& b,
14 array& out,
15 int M,
16 int N,
17 int K,
18 int batch_size_out,
19 int lda,
20 int ldb,
21 int ldd,
22 bool transpose_a,
23 bool transpose_b,
24 Shape batch_shape,
25 Strides batch_strides,
26 int64_t A_batch_stride,
27 int64_t B_batch_stride,
28 int64_t matrix_stride_out,
29 std::vector<array>& copies);
30
31void steel_matmul(
32 const Stream& s,
33 metal::Device& d,
34 const array& a,
35 const array& b,
36 array& out,
37 int M,
38 int N,
39 int K,
40 int batch_size_out,
41 int lda,
42 int ldb,
43 bool transpose_a,
44 bool transpose_b,
45 std::vector<array>& copies,
46 Shape batch_shape = {},
47 Strides A_batch_stride = {},
48 Strides B_batch_stride = {});
49
50} // namespace mlx::core
Definition array.h:24
Definition device.h:158
Definition allocator.h:7
void steel_matmul_regular(const Stream &s, metal::Device &d, const array &a, const array &b, array &out, int M, int N, int K, int batch_size_out, int lda, int ldb, int ldd, bool transpose_a, bool transpose_b, Shape batch_shape, Strides batch_strides, int64_t A_batch_stride, int64_t B_batch_stride, int64_t matrix_stride_out, std::vector< array > &copies)
std::vector< ShapeElem > Shape
Definition array.h:21
std::vector< int64_t > Strides
Definition array.h:22
void steel_matmul(const Stream &s, metal::Device &d, const array &a, const array &b, array &out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector< array > &copies, Shape batch_shape={}, Strides A_batch_stride={}, Strides B_batch_stride={})
Definition stream.h:9