Masked gemv (#1211)

This commit is contained in:
Jagrit Digani
2024-06-14 09:52:26 -07:00
committed by GitHub
parent fe3167d7ea
commit 2d6cd47713
6 changed files with 1729 additions and 421 deletions

View File

@@ -3634,16 +3634,19 @@ std::vector<array> BlockMaskedMM::vjp(
 };
 // Prepare for padding if needed
-int M = cotan.shape(-2);
-int N = cotan.shape(-1);
-int K = primals[0].shape(-1);
-int align_M = (M % block_size_);
-int align_N = (N % block_size_);
-int align_K = (K % block_size_);
+const int M = cotan.shape(-2);
+const int N = cotan.shape(-1);
+const int K = primals[0].shape(-1);
+const int tm = (M + block_size_ - 1) / block_size_;
+const int tn = (N + block_size_ - 1) / block_size_;
+const int tk = (K + block_size_ - 1) / block_size_;
+const int align_M = tm * block_size_ - M;
+const int align_N = tn * block_size_ - N;
+const int align_K = tk * block_size_ - K;
 // Potential intermediates
-auto unmasked_lhs_grad = primals[0];
-auto unmasked_rhs_grad = primals[1];
+array unmasked_lhs_grad = primals[0];
+array unmasked_rhs_grad = primals[1];
 bool unmasked_lhs_grad_calculated = false;
 bool unmasked_rhs_grad_calculated = false;