mlx/mlx/backend/metal/matmul.h
Rifur13 c4a471c99d
Add groups to Conv1d (#948)
* Add conv1d grouped convs on CPU

* Add GPU support

* Parallelize inside metal kernel

* cleanup

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* New unfold kernel + remove unused code

* Remove copy and refactor

* Update vjp and reuse steel gemm

* Fixed groups on cpu

* Fix metal validation

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-04-27 06:24:57 -07:00

51 lines
1019 B
C++

// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"
namespace mlx::core {
// Declaration only; the definition is not part of this header.
//
// Judging by the name, this is the grouped-convolution variant of the steel
// GEMM path -- TODO confirm against the definition.
//
// Parameters (meanings are inferred from the names and the usual GEMM
// conventions; verify against the definition before relying on them):
//   s, d          - stream and Metal device handed to the implementation.
//   a, b          - input arrays, taken by const reference.
//   out           - taken by non-const reference; presumably the result array.
//   M, N, K       - presumably the GEMM problem sizes.
//   lda, ldb, ldd - presumably the leading dimensions of a, b and the output.
//   transpose_a,
//   transpose_b   - presumably whether a / b are read as transposed.
//   groups        - presumably the number of convolution groups.
//   copies        - mutable vector of arrays; presumably collects temporaries
//                   created by the implementation -- confirm with callers.
void steel_matmul_conv_groups(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int lda,
int ldb,
int ldd,
bool transpose_a,
bool transpose_b,
int groups,
std::vector<array>& copies);
// Declaration only; the definition is not part of this header.
//
// Judging by the name, this is the general (optionally batched) steel GEMM
// entry point -- TODO confirm against the definition.
//
// Parameters (meanings are inferred from the names and the usual GEMM
// conventions; verify against the definition before relying on them):
//   s, d           - stream and Metal device handed to the implementation.
//   a, b           - input arrays, taken by const reference.
//   out            - taken by non-const reference; presumably the result array.
//   M, N, K        - presumably the GEMM problem sizes.
//   batch_size_out - presumably the number of matrices in the output batch.
//   lda, ldb       - presumably the leading dimensions of a and b.
//   transpose_a,
//   transpose_b    - presumably whether a / b are read as transposed.
//   copies         - mutable vector of arrays; presumably collects temporaries
//                    created by the implementation -- confirm with callers.
//   batch_shape,
//   A_batch_stride,
//   B_batch_stride - optional batch description; all three default to empty
//                    vectors. How an empty value is interpreted is decided by
//                    the definition -- TODO confirm.
//
// NOTE(review): the three trailing vectors are taken by value, so a caller
// passing an lvalue pays for a copy on each call.
void steel_matmul(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
std::vector<int> batch_shape = {},
std::vector<size_t> A_batch_stride = {},
std::vector<size_t> B_batch_stride = {});
} // namespace mlx::core