22 std::vector<array>& copies);
38 std::vector<array>& copies,
39 std::vector<int> batch_shape = {},
40 std::vector<size_t> A_batch_stride = {},
41 std::vector<size_t> B_batch_stride = {});
void steel_matmul(const Stream &s, metal::Device &d, const array &a, const array &b, array &out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector< array > &copies, std::vector< int > batch_shape={}, std::vector< size_t > A_batch_stride={}, std::vector< size_t > B_batch_stride={})
void steel_matmul_conv_groups(const Stream &s, metal::Device &d, const array &a, const array &b, array &out, int M, int N, int K, int lda, int ldb, int ldd, bool transpose_a, bool transpose_b, int groups, std::vector< array > &copies)