MLX
Loading...
Searching...
No Matches
matmul.h
Go to the documentation of this file.
// Copyright © 2023 Apple Inc.

#pragma once

#include <vector>

#include "mlx/backend/metal/device.h"

5namespace mlx::core {
6
8 const Stream& s,
10 const array& a,
11 const array& b,
12 array& out,
13 int M,
14 int N,
15 int K,
16 int batch_size_out,
17 int lda,
18 int ldb,
19 int ldd,
20 bool transpose_a,
21 bool transpose_b,
22 std::vector<int> batch_shape,
23 std::vector<size_t> batch_strides,
24 size_t A_batch_stride,
25 size_t B_batch_stride,
26 size_t matrix_stride_out,
27 std::vector<array>& copies);
28
30 const Stream& s,
32 const array& a,
33 const array& b,
34 array& out,
35 int M,
36 int N,
37 int K,
38 int batch_size_out,
39 int lda,
40 int ldb,
41 bool transpose_a,
42 bool transpose_b,
43 std::vector<array>& copies,
44 std::vector<int> batch_shape = {},
45 std::vector<size_t> A_batch_stride = {},
46 std::vector<size_t> B_batch_stride = {});
47
48} // namespace mlx::core
Definition array.h:20
Definition device.h:87
Definition allocator.h:7
void steel_matmul_regular(const Stream &s, metal::Device &d, const array &a, const array &b, array &out, int M, int N, int K, int batch_size_out, int lda, int ldb, int ldd, bool transpose_a, bool transpose_b, std::vector< int > batch_shape, std::vector< size_t > batch_strides, size_t A_batch_stride, size_t B_batch_stride, size_t matrix_stride_out, std::vector< array > &copies)
void steel_matmul(const Stream &s, metal::Device &d, const array &a, const array &b, array &out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector< array > &copies, std::vector< int > batch_shape={}, std::vector< size_t > A_batch_stride={}, std::vector< size_t > B_batch_stride={})
Definition stream.h:9