mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Masked gemv (#1211)
This commit is contained in:
@@ -3634,16 +3634,19 @@ std::vector<array> BlockMaskedMM::vjp(
|
||||
};
|
||||
|
||||
// Prepare for padding if needed
|
||||
int M = cotan.shape(-2);
|
||||
int N = cotan.shape(-1);
|
||||
int K = primals[0].shape(-1);
|
||||
int align_M = (M % block_size_);
|
||||
int align_N = (N % block_size_);
|
||||
int align_K = (K % block_size_);
|
||||
const int M = cotan.shape(-2);
|
||||
const int N = cotan.shape(-1);
|
||||
const int K = primals[0].shape(-1);
|
||||
const int tm = (M + block_size_ - 1) / block_size_;
|
||||
const int tn = (N + block_size_ - 1) / block_size_;
|
||||
const int tk = (K + block_size_ - 1) / block_size_;
|
||||
const int align_M = tm * block_size_ - M;
|
||||
const int align_N = tn * block_size_ - N;
|
||||
const int align_K = tk * block_size_ - K;
|
||||
|
||||
// Potential intermediates
|
||||
auto unmasked_lhs_grad = primals[0];
|
||||
auto unmasked_rhs_grad = primals[1];
|
||||
array unmasked_lhs_grad = primals[0];
|
||||
array unmasked_rhs_grad = primals[1];
|
||||
|
||||
bool unmasked_lhs_grad_calculated = false;
|
||||
bool unmasked_rhs_grad_calculated = false;
|
||||
|
||||
Reference in New Issue
Block a user