MLX
Loading...
Searching...
No Matches
matmul.h
Go to the documentation of this file.
// Copyright © 2023 Apple Inc.

#pragma once

#include <algorithm>
#include <cassert>
#include <sstream>
#include <vector>

#include "mlx/backend/metal/device.h"
#include "mlx/utils.h"

13namespace mlx::core {
14
16 const Stream& s,
18 const array& a,
19 const array& b,
20 array& out,
21 int M,
22 int N,
23 int K,
24 int lda,
25 int ldb,
26 int ldd,
27 bool transpose_a,
28 bool transpose_b,
29 int groups,
30 std::vector<array>& copies);
31
33 const Stream& s,
35 const array& a,
36 const array& b,
37 array& out,
38 int M,
39 int N,
40 int K,
41 int batch_size_out,
42 int lda,
43 int ldb,
44 bool transpose_a,
45 bool transpose_b,
46 std::vector<array>& copies,
47 std::vector<int> batch_shape = {},
48 std::vector<size_t> A_batch_stride = {},
49 std::vector<size_t> B_batch_stride = {});
50
51} // namespace mlx::core
Definition array.h:20
Definition device.h:117
Definition allocator.h:7
void steel_matmul(const Stream &s, metal::Device &d, const array &a, const array &b, array &out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector< array > &copies, std::vector< int > batch_shape={}, std::vector< size_t > A_batch_stride={}, std::vector< size_t > B_batch_stride={})
void steel_matmul_conv_groups(const Stream &s, metal::Device &d, const array &a, const array &b, array &out, int M, int N, int K, int lda, int ldb, int ldd, bool transpose_a, bool transpose_b, int groups, std::vector< array > &copies)
Definition stream.h:9