mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-25 01:41:17 +08:00)
parent 50dfb664db
commit eab2685c67
@@ -17,24 +17,25 @@ namespace mlx::core {
 
 namespace {
 
-template <typename T>
+template <typename T, typename mask_t>
 inline void mask_matrix(
     T* data,
-    const bool* mask,
+    const mask_t* mask,
     int block_size,
     const int X,
     const int Y,
     const size_t X_data_str,
     const size_t Y_data_str,
     const size_t X_mask_str,
-    const size_t Y_mask_str) {
+    const size_t Y_mask_str,
+    const size_t mask_offset) {
   int tX = (X + block_size - 1) / block_size;
   int tY = (Y + block_size - 1) / block_size;
 
   for (int i = 0; i < tX; i++) {
     for (int j = 0; j < tY; j++) {
-      bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
-      if (!do_mask) {
+      mask_t do_mask = mask[mask_offset + i * X_mask_str + j * Y_mask_str];
+      if (do_mask != 1) {
         int loc_x = i * block_size;
         int loc_y = j * block_size;
         T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
@@ -43,7 +44,11 @@ inline void mask_matrix(
       int size_y = std::min(block_size, Y - loc_y);
       for (int ii = 0; ii < size_x; ii++) {
         for (int jj = 0; jj < size_y; jj++) {
-          data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
+          if constexpr (std::is_same_v<mask_t, bool>) {
+            data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
+          } else {
+            data_block[ii * X_data_str + jj * Y_data_str] *= do_mask;
+          }
         }
       }
     }
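The blockwise semantics above are easy to state outside the kernel. A sketch in NumPy (illustration only; the function name is ours): a boolean mask zeroes whole block_size x block_size tiles, while with this change a non-boolean mask scales them instead.

import numpy as np

def mask_matrix_ref(data, mask, block_size):
    # Zero (boolean mask) or scale (numeric mask) each block_size x block_size
    # tile of `data` according to the matching entry of the block-level mask.
    out = data.copy()
    X, Y = data.shape
    for i in range(0, X, block_size):
        for j in range(0, Y, block_size):
            m = mask[i // block_size, j // block_size]
            if mask.dtype == np.bool_:
                if not m:
                    out[i:i + block_size, j:j + block_size] = 0.0
            elif m != 1:
                out[i:i + block_size, j:j + block_size] *= m
    return out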
@@ -62,36 +67,39 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
 
   auto& a_pre = inputs[0];
   auto& b_pre = inputs[1];
-  auto& out_mask = inputs[2];
 
-  auto check_transpose = [](const array& arr, bool do_copy) {
-    auto stx = arr.strides()[arr.ndim() - 2];
-    auto sty = arr.strides()[arr.ndim() - 1];
-    if (stx == arr.shape(-1) && sty == 1) {
-      if (do_copy) {
-        array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-        copy(arr, arr_copy, CopyType::Vector);
-        return std::make_tuple(false, stx, arr_copy);
-      }
-      return std::make_tuple(false, stx, arr);
-    } else if (stx == 1 && sty == arr.shape(-2)) {
-      if (do_copy) {
-        array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-        copy(arr, arr_copy, CopyType::Vector);
-        return std::make_tuple(true, sty, arr_copy);
-      }
-      return std::make_tuple(true, sty, arr);
-    } else {
-      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-      copy(arr, arr_copy, CopyType::General);
-      size_t stx = arr.shape(-1);
-      return std::make_tuple(false, stx, arr_copy);
-    }
-  };
+  auto check_transpose =
+      [](const array& arr, bool do_copy, bool expand_all = false) {
+        auto stx = arr.strides()[arr.ndim() - 2];
+        auto sty = arr.strides()[arr.ndim() - 1];
+        if (!expand_all && stx == arr.shape(-1) && sty == 1) {
+          if (do_copy) {
+            array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+            copy(arr, arr_copy, CopyType::Vector);
+            return std::make_tuple(false, stx, arr_copy);
+          }
+          return std::make_tuple(false, stx, arr);
+        } else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
+          if (do_copy) {
+            array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+            copy(arr, arr_copy, CopyType::Vector);
+            return std::make_tuple(true, sty, arr_copy);
+          }
+          return std::make_tuple(true, sty, arr);
+        } else {
+          array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+          copy(arr, arr_copy, CopyType::General);
+          size_t stx = arr.shape(-1);
+          return std::make_tuple(false, stx, arr_copy);
+        }
+      };
 
   bool has_op_mask = inputs.size() > 3;
-  auto [a_transposed, lda, a] = check_transpose(a_pre, has_op_mask);
-  auto [b_transposed, ldb, b] = check_transpose(b_pre, has_op_mask);
+  bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
+  auto [a_transposed, lda, a] =
+      check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_);
+  auto [b_transposed, ldb, b] =
+      check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
 
   size_t M = a.shape(-2);
   size_t N = b.shape(-1);
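The stride test in check_transpose has a simple reading: a matrix whose last two strides are (ncols, 1) is row-major and usable as-is, one with strides (1, nrows) is a transposed view, and anything else is copied. The new expand_all flag forces the copy branch when the operand mask is not boolean, presumably so the in-place block scaling below works on a fresh, densely laid-out copy. A NumPy sketch of the same classification (names are ours):

import numpy as np

def classify(arr):
    # Element (not byte) strides of the trailing two axes.
    sx = arr.strides[-2] // arr.itemsize
    sy = arr.strides[-1] // arr.itemsize
    if sx == arr.shape[-1] and sy == 1:
        return "row-major", sx   # usable directly, leading dim sx
    if sx == 1 and sy == arr.shape[-2]:
        return "transposed", sy  # usable as a transposed view, leading dim sy
    return "general", None       # needs a contiguous copy first

print(classify(np.zeros((4, 8))))    # ('row-major', 8)
print(classify(np.zeros((8, 4)).T))  # ('transposed', 4)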
@@ -114,27 +122,42 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
       int Y,
       size_t X_data_str,
       size_t Y_data_str) {
-    const bool* mask_ptr = mask.data<bool>() +
-        elem_to_loc(mask.shape(-1) * mask.shape(-2) * batch_idx,
-                    mask.shape(),
-                    mask.strides());
+    size_t mask_offset = elem_to_loc(
+        mask.shape(-1) * mask.shape(-2) * batch_idx,
+        mask.shape(),
+        mask.strides());
 
     size_t X_mask_str = mask.strides()[mask.ndim() - 2];
     size_t Y_mask_str = mask.strides()[mask.ndim() - 1];
 
-    return mask_matrix(
-        data,
-        mask_ptr,
-        block_size,
-        X,
-        Y,
-        X_data_str,
-        Y_data_str,
-        X_mask_str,
-        Y_mask_str);
+    if (mask.dtype() == bool_) {
+      return mask_matrix(
+          data,
+          mask.data<bool>(),
+          block_size,
+          X,
+          Y,
+          X_data_str,
+          Y_data_str,
+          X_mask_str,
+          Y_mask_str,
+          mask_offset);
+    } else {
+      return mask_matrix(
+          data,
+          mask.data<float>(),
+          block_size,
+          X,
+          Y,
+          X_data_str,
+          Y_data_str,
+          X_mask_str,
+          Y_mask_str,
+          mask_offset);
+    }
   };
 
-  for (int i = 0; i < (a.size() / (M * K)); ++i) {
+  for (int i = 0; i < (out.size() / (M * size_t(N))); ++i) {
     // Adjust pointer
     float* ai =
         a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
@@ -144,7 +167,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
 
     // Zero out blocks in a and b if needed
     if (has_op_mask) {
-      auto& a_mask = inputs[3];
+      auto& a_mask = inputs[inputs.size() - 2];
       mask_array(
           a_mask,
           ai,
@@ -155,7 +178,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
           a_transposed ? 1 : lda,
           a_transposed ? lda : 1);
 
-      auto& b_mask = inputs[4];
+      auto& b_mask = inputs[inputs.size() - 1];
       mask_array(
           b_mask,
           bi,
@@ -186,7 +209,9 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
     );
 
     // Zero out blocks in out
-    mask_array(out_mask, ci, block_size_, i, M, N, N, 1);
+    if (has_out_mask) {
+      mask_array(inputs[2], ci, block_size_, i, M, N, N, 1);
+    }
   }
 }
 
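For context, the operation this CPU path implements is exposed as mx.block_masked_mm in Python. A minimal usage sketch (shapes chosen to pair with the 32-wide blocks used in the tests at the bottom of this diff):

import mlx.core as mx

a = mx.random.normal((64, 64))
b = mx.random.normal((64, 64))

# One mask entry per 32x32 block of the 64x64 output: a 2x2 boolean grid.
out_mask = mx.array([[True, False], [False, True]])

c = mx.block_masked_mm(a, b, 32, out_mask)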
@@ -11,8 +11,38 @@ using namespace mlx::steel;
 // GEMM kernels
 ///////////////////////////////////////////////////////////////////////////////
 
+struct _NoMask {
+  char x;
+
+  constexpr METAL_FUNC operator bool() {
+    return true;
+  }
+  constexpr METAL_FUNC operator bool() const threadgroup {
+    return true;
+  }
+  constexpr METAL_FUNC operator bool() const device {
+    return true;
+  }
+  constexpr METAL_FUNC operator bool() const constant {
+    return true;
+  }
+};
+
+template <typename OutT, typename InT = OutT>
+struct ScaleOp {
+  OutT scale;
+
+  METAL_FUNC OutT apply(InT x) const {
+    return static_cast<OutT>(x) * scale;
+  }
+};
+
+typedef struct _NoMask nomask_t;
+
 template <
     typename T,
+    typename out_mask_t,
+    typename op_mask_t,
     int BM,
     int BN,
     int BK,
@@ -21,8 +51,7 @@ template <
     bool transpose_a,
     bool transpose_b,
     bool MN_aligned,
-    bool K_aligned,
-    bool has_operand_mask = false>
+    bool K_aligned>
 [[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
 block_masked_gemm(
     const device T* A [[buffer(0)]],
@@ -31,9 +60,9 @@ block_masked_gemm(
     const constant GEMMParams* params [[buffer(4)]],
     const constant int* batch_shape [[buffer(6)]],
     const constant size_t* batch_strides [[buffer(7)]],
-    const device bool* out_mask [[buffer(10)]],
-    const device bool* lhs_mask [[buffer(11)]],
-    const device bool* rhs_mask [[buffer(12)]],
+    const device out_mask_t* out_mask [[buffer(10)]],
+    const device op_mask_t* lhs_mask [[buffer(11)]],
+    const device op_mask_t* rhs_mask [[buffer(12)]],
     const constant int* mask_strides [[buffer(13)]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
@@ -42,6 +71,21 @@ block_masked_gemm(
   // Appease the compiler
   (void)lid;
 
+  static_assert(
+      BM == BN,
+      "block_masked_gemm must have the same block M and block N size");
+  static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0");
+
+  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
+  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
+
+  constexpr bool has_mul_operand_mask =
+      has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
+  constexpr bool has_mul_output_mask =
+      has_output_mask && !metal::is_same_v<out_mask_t, bool>;
+
+  constexpr short k_mask_factor = short(BM / BK);
+
   using gemm_kernel = GEMMKernel<
       T,
       T,
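A note on the new static_asserts and k_mask_factor: operand-mask entries cover BM x BM blocks, while the K loop advances BK columns per iteration, so each mask entry must be reused for exactly BM / BK consecutive iterations; that only works out evenly when BM % BK == 0. A sketch of the indexing with illustrative values:

BM, BK, K = 32, 16, 128
k_mask_factor = BM // BK  # requires BM % BK == 0

mask_idx = [(k * BK) // BM for k in range(K // BK)]
print(mask_idx)  # [0, 0, 1, 1, 2, 2, 3, 3] -- advances once every k_mask_factor steps

This is what the k_factor_cnt countdown in the main loops below implements without a division per iteration.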
@@ -63,15 +107,19 @@ block_masked_gemm(
     return;
   }
 
+  const constant size_t* mask_batch_strides =
+      batch_strides + 2 * params->batch_ndim;
+
   if (params->batch_ndim > 1) {
-    const constant size_t* mask_batch_strides =
-        batch_strides + 2 * params->batch_ndim;
-    out_mask +=
-        elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
+    if (has_output_mask) {
+      out_mask += elem_to_loc(
+          tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
+
+      mask_batch_strides += params->batch_ndim;
+    }
 
     if (has_operand_mask) {
-      const constant size_t* mask_strides_lhs =
-          mask_batch_strides + params->batch_ndim;
+      const constant size_t* mask_strides_lhs = mask_batch_strides;
       const constant size_t* mask_strides_rhs =
           mask_strides_lhs + params->batch_ndim;
 
@@ -86,10 +134,14 @@ block_masked_gemm(
       rhs_mask += batch_offsets.y;
     }
   } else {
-    out_mask += tid.z * batch_strides[2 * params->batch_ndim];
+    if (has_output_mask) {
+      out_mask += tid.z * mask_batch_strides[0];
+      mask_batch_strides += params->batch_ndim;
+    }
 
     if (has_operand_mask) {
-      lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
-      rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
+      lhs_mask += tid.z * mask_batch_strides[0];
+      rhs_mask += tid.z * mask_batch_strides[params->batch_ndim];
     }
   }
 
@@ -121,44 +173,69 @@ block_masked_gemm(
   B += transpose_b ? c_col_long * params->ldb : c_col_long;
   D += c_row_long * params->ldd + c_col_long;
 
-  bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
+  const constant int* out_mask_strides = mask_strides;
+  const constant int* lhs_mask_strides =
+      mask_strides + (has_output_mask ? 2 : 0);
+  const constant int* rhs_mask_strides =
+      lhs_mask_strides + (has_operand_mask ? 2 : 0);
 
-  // Write zeros and return
-  if (!mask_out) {
-    constexpr short tgp_size = WM * WN * 32;
-    constexpr short vec_size = 4;
+  const int out_mask_offset = !has_output_mask
+      ? 0
+      : tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0];
+  int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1];
+  int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0];
+  const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0];
+  const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1];
+  short k_factor_cnt = k_mask_factor;
 
-    // Tile threads in threadgroup
-    constexpr short TN = BN / vec_size;
-    constexpr short TM = tgp_size / TN;
+  ScaleOp<float> out_mask_op;
+  ScaleOp<T> lhs_mask_op;
+  ScaleOp<T> rhs_mask_op;
 
-    const short thread_idx = simd_group_id * 32 + simd_lane_id;
-    const short bi = thread_idx / TN;
-    const short bj = vec_size * (thread_idx % TN);
+  if (has_output_mask) {
+    auto mask_out = out_mask[out_mask_offset];
 
-    D += bi * params->ldd + bj;
+    if (has_mul_output_mask) {
+      out_mask_op.scale = float(mask_out);
+    }
 
-    short tgp_bm = min(BM, params->M - c_row);
-    short tgp_bn = min(BN, params->N - c_col);
-
-    if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
-      for (short ti = 0; ti < BM; ti += TM) {
-        STEEL_PRAGMA_UNROLL
-        for (short j = 0; j < vec_size; j++) {
-          D[ti * params->ldd + j] = T(0.);
-        }
-      }
-    } else {
-      short jmax = tgp_bn - bj;
-      jmax = jmax < vec_size ? jmax : vec_size;
-      for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
-        for (short j = 0; j < jmax; j++) {
-          D[ti * params->ldd + j] = T(0.);
-        }
-      }
-    }
-
-    return;
+    // Write zeros and return
+    if (!mask_out) {
+      constexpr short tgp_size = WM * WN * 32;
+      constexpr short vec_size = 4;
+
+      // Tile threads in threadgroup
+      constexpr short TN = BN / vec_size;
+      constexpr short TM = tgp_size / TN;
+
+      const short thread_idx = simd_group_id * 32 + simd_lane_id;
+      const short bi = thread_idx / TN;
+      const short bj = vec_size * (thread_idx % TN);
+
+      D += bi * params->ldd + bj;
+
+      short tgp_bm = min(BM, params->M - c_row);
+      short tgp_bn = min(BN, params->N - c_col);
+
+      if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
+        for (short ti = 0; ti < BM; ti += TM) {
+          STEEL_PRAGMA_UNROLL
+          for (short j = 0; j < vec_size; j++) {
+            D[ti * params->ldd + j] = T(0.);
+          }
+        }
+      } else {
+        short jmax = tgp_bn - bj;
+        jmax = jmax < vec_size ? jmax : vec_size;
+        for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
+          for (short j = 0; j < jmax; j++) {
+            D[ti * params->ldd + j] = T(0.);
+          }
+        }
+      }
+
+      return;
+    }
   }
 
   threadgroup_barrier(mem_flags::mem_none);
@@ -166,8 +243,6 @@ block_masked_gemm(
   // Prepare threadgroup mma operation
   thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
 
-  int gemm_k_iterations = params->gemm_k_iterations_aligned;
-
   threadgroup T As[gemm_kernel::tgp_mem_size_a];
   threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
 
@@ -177,21 +252,88 @@ block_masked_gemm(
   thread typename gemm_kernel::loader_b_t loader_b(
       B, params->ldb, Bs, simd_group_id, simd_lane_id);
 
+  // Prepare threadgroup bounds
+  const short tgp_bm =
+      MN_aligned ? short(BM) : short(min(BM, params->M - c_row));
+  const short tgp_bn =
+      MN_aligned ? short(BN) : short(min(BN, params->N - c_col));
+
+  int gemm_k_iterations = params->gemm_k_iterations_aligned;
+
+  ///////////////////////////////////////////////////////////////////////////////
+  // Do unaligned K iterations first
+  if (!K_aligned) {
+    const int k_last = params->gemm_k_iterations_aligned * BK;
+    const int mask_idx_last = k_last / BM;
+
+    if (!has_operand_mask ||
+        (bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) &&
+         bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) {
+      if (has_mul_operand_mask) {
+        lhs_mask_op.scale =
+            lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step];
+        rhs_mask_op.scale =
+            rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step];
+      }
+
+      // Move loader source ahead to end
+      const int k_remain = params->K - k_last;
+      const size_t k_jump_a =
+          transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
+      const size_t k_jump_b =
+          transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
+
+      loader_a.src += k_jump_a;
+      loader_b.src += k_jump_b;
+
+      // Load tile
+      const short2 tile_dims_A =
+          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
+      const short2 tile_dims_B =
+          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+
+      loader_a.load_safe(tile_dims_A);
+      loader_b.load_safe(tile_dims_B);
+
+      if (has_mul_operand_mask) {
+        loader_a.apply_inplace_op(lhs_mask_op);
+        loader_b.apply_inplace_op(rhs_mask_op);
+      }
+
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+
+      // Do matmul
+      mma_op.mma(As, Bs);
+
+      // Reset source back to start
+      loader_a.src -= k_jump_a;
+      loader_b.src -= k_jump_b;
+    }
+  }
+
   ///////////////////////////////////////////////////////////////////////////////
   // MNK aligned loop
   if (MN_aligned) {
-    for (int k = 0; k < gemm_k_iterations; k++) {
+    for (; gemm_k_iterations > 0; gemm_k_iterations--) {
       threadgroup_barrier(mem_flags::mem_threadgroup);
 
       if (!has_operand_mask ||
-          (lhs_mask
-               [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
-           rhs_mask
-               [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
+          (bool(lhs_mask[lhs_mask_offset]) &&
+           bool(rhs_mask[rhs_mask_offset]))) {
+        if (has_mul_operand_mask) {
+          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
+          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
+        }
+
         // Load elements into threadgroup
         loader_a.load_unsafe();
         loader_b.load_unsafe();
 
+        if (has_mul_operand_mask) {
+          loader_a.apply_inplace_op(lhs_mask_op);
+          loader_b.apply_inplace_op(rhs_mask_op);
+        }
+
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         // Multiply and accumulate threadgroup elements
@@ -201,29 +343,15 @@ block_masked_gemm(
       // Prepare for next iteration
       loader_a.next();
       loader_b.next();
+
+      k_factor_cnt--;
+      lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
+      rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
+      k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
     }
 
-    threadgroup_barrier(mem_flags::mem_none);
-
-    // Loop tail
-    if (!K_aligned) {
-      if (!has_operand_mask ||
-          (lhs_mask
-               [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
-           rhs_mask
-               [(params->K / BM) * mask_strides[5] +
-                tid_x * mask_strides[4]])) {
-        int lbk = params->K - params->gemm_k_iterations_aligned * BK;
-        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
-        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
-
-        loader_a.load_safe(tile_dims_A);
-        loader_b.load_safe(tile_dims_B);
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        mma_op.mma(As, Bs);
-      }
-    }
+    if (has_mul_output_mask) {
+      mma_op.apply_epilogue(out_mask_op);
+    }
 
     // Store results to device memory
@@ -233,24 +361,25 @@ block_masked_gemm(
   }
   ///////////////////////////////////////////////////////////////////////////////
   // MN unaligned loop
-  else { // Loop over K - unaligned case
-    short tgp_bm = min(BM, params->M - c_row);
-    short tgp_bn = min(BN, params->N - c_col);
-    short lbk = params->K - params->gemm_k_iterations_aligned * BK;
-
-    bool M_aligned = (tgp_bm == BM);
-    bool N_aligned = (tgp_bn == BN);
-
-    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
-    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
-
-    for (int k = 0; k < gemm_k_iterations; k++) {
+  else {
+    const bool M_aligned = (tgp_bm == BM);
+    const bool N_aligned = (tgp_bn == BN);
+
+    const short2 tile_dims_A =
+        transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
+    const short2 tile_dims_B =
+        transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
+
+    for (; gemm_k_iterations > 0; gemm_k_iterations--) {
       threadgroup_barrier(mem_flags::mem_threadgroup);
       if (!has_operand_mask ||
-          (lhs_mask
-               [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
-           rhs_mask
-               [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
+          (bool(lhs_mask[lhs_mask_offset]) &&
+           bool(rhs_mask[rhs_mask_offset]))) {
+        if (has_mul_operand_mask) {
+          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
+          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
+        }
+
        // Load elements into threadgroup
        if (M_aligned) {
          loader_a.load_unsafe();
@@ -264,6 +393,11 @@ block_masked_gemm(
           loader_b.load_safe(tile_dims_B);
         }
 
+        if (has_mul_operand_mask) {
+          loader_a.apply_inplace_op(lhs_mask_op);
+          loader_b.apply_inplace_op(rhs_mask_op);
+        }
+
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         // Multiply and accumulate threadgroup elements
@@ -273,29 +407,15 @@ block_masked_gemm(
         // Prepare for next iteration
         loader_a.next();
         loader_b.next();
+
+        k_factor_cnt--;
+        lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
+        rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
+        k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
       }
 
-      if (!K_aligned) {
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        if (!has_operand_mask ||
-            (lhs_mask
-                 [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
-             rhs_mask
-                 [(params->K / BM) * mask_strides[5] +
-                  tid_x * mask_strides[4]])) {
-          short2 tile_dims_A_last =
-              transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
-          short2 tile_dims_B_last =
-              transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
-
-          loader_a.load_safe(tile_dims_A_last);
-          loader_b.load_safe(tile_dims_B_last);
-
-          threadgroup_barrier(mem_flags::mem_threadgroup);
-
-          mma_op.mma(As, Bs);
-        }
-      }
+      if (has_mul_output_mask) {
+        mma_op.apply_epilogue(out_mask_op);
+      }
 
       if (M_aligned && N_aligned) {
@@ -311,6 +431,10 @@ block_masked_gemm(
 ///////////////////////////////////////////////////////////////////////////////
 
 #define instantiate_gemm( \
+    outmaskname, \
+    outmasktype, \
+    opmaskname, \
+    opmasktype, \
     tname, \
     trans_a, \
     trans_b, \
@@ -326,15 +450,15 @@ block_masked_gemm(
     aname, \
     mn_aligned, \
     kname, \
-    k_aligned, \
-    omname, \
-    op_mask) \
-  template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname \
+    k_aligned) \
+  template [[host_name("steel_gemm_block_outmask_" #outmaskname \
+                       "_opmask_" #opmaskname "_" #tname "_" #iname "_" #oname \
                        "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
-                       "_MN_" #aname "_K_" #kname \
-                       "_op_mask_" #omname)]] [[kernel]] void \
+                       "_MN_" #aname "_K_" #kname)]] [[kernel]] void \
   block_masked_gemm< \
       itype, \
+      outmasktype, \
+      opmasktype, \
       bm, \
       bn, \
       bk, \
@@ -343,17 +467,16 @@ block_masked_gemm(
       trans_a, \
      trans_b, \
       mn_aligned, \
-      k_aligned, \
-      op_mask>( \
+      k_aligned>( \
       const device itype* A [[buffer(0)]], \
       const device itype* B [[buffer(1)]], \
       device itype* D [[buffer(3)]], \
       const constant GEMMParams* params [[buffer(4)]], \
       const constant int* batch_shape [[buffer(6)]], \
       const constant size_t* batch_strides [[buffer(7)]], \
-      const device bool* out_mask [[buffer(10)]], \
-      const device bool* lhs_mask [[buffer(11)]], \
-      const device bool* rhs_mask [[buffer(12)]], \
+      const device outmasktype* out_mask [[buffer(10)]], \
+      const device opmasktype* lhs_mask [[buffer(11)]], \
+      const device opmasktype* rhs_mask [[buffer(12)]], \
       const constant int* mask_strides [[buffer(13)]], \
       uint simd_lane_id [[thread_index_in_simdgroup]], \
       uint simd_group_id [[simdgroup_index_in_threadgroup]], \
@@ -361,9 +484,15 @@ block_masked_gemm(
       uint3 lid [[thread_position_in_threadgroup]]);
 
 // clang-format off
 #define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
-  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \
-  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true) // clang-format on
+  instantiate_gemm(bool_, bool, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
+  instantiate_gemm(iname, itype, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
+  instantiate_gemm(bool_, bool, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
+  instantiate_gemm(iname, itype, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
+  instantiate_gemm(nomask, nomask_t, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
+  instantiate_gemm(nomask, nomask_t, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
+  instantiate_gemm(bool_, bool, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
+  instantiate_gemm(iname, itype, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) // clang-format on
 
 // clang-format off
 #define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
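The rewritten helper instantiates one kernel per (output-mask type, operand-mask type) pair instead of a single boolean op_mask switch. A sketch of the combinations it emits (mask_types stands in for the macro arguments; "itype" is the input dtype):

from itertools import product

mask_types = ["bool_", "itype", "nomask"]
for out_m, op_m in product(mask_types, repeat=2):
    if out_m == "nomask" and op_m == "nomask":
        continue  # fully unmasked GEMM is handled by the regular kernels
    print(f"instantiate_gemm(outmask={out_m}, opmask={op_m}, ...)")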
@@ -58,6 +58,18 @@ struct BlockLoader {
         dst(dst_ + bi * dst_ld + bj),
         src(src_ + bi * src_ld + bj) {}
 
+  /* Apply operation to threadgroup without bound checking */
+  template <typename UnaryOp>
+  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < BROWS; i += TROWS) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
+      }
+    }
+  }
+
   /* Load from device memory into threadgroup memory - without bound checking */
   METAL_FUNC void load_unsafe() const {
     STEEL_PRAGMA_UNROLL
@@ -198,6 +198,24 @@ struct BlockMMA {
     }
   }
 
+  /* Apply epilogue */
+  template <typename UnaryEpilogue>
+  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in C
+        thread auto& accum = results[i * TN + j].thread_elements();
+
+        // Apply epilogue
+        accum[0] = epilogue_op.apply(accum[0]);
+        accum[1] = epilogue_op.apply(accum[1]);
+      }
+    }
+  }
+
   /* Apply epilogue */
   template <typename BinaryEpilogue>
   METAL_FUNC void apply_epilogue(
@@ -1307,7 +1307,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Check and collapse batch dimensions
 
   bool has_op_mask = inputs.size() > 3;
-  auto& out_mask = inputs[2];
+  bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
 
   std::vector<int> batch_shape{1};
   size_t A_batch_str = 0;
@@ -1350,14 +1350,17 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   int wm = 2, wn = 2;
 
   // Prepare kernel name
+  std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask";
+  std::string op_mask_nm = has_op_mask ? type_to_name(inputs.back()) : "nomask";
+
   std::ostringstream kname;
-  kname << "steel_block_masked_gemm_" << (transpose_a ? 't' : 'n')
+  kname << "steel_gemm_block_outmask_" << out_mask_nm << "_opmask_"
+        << op_mask_nm << "_" << (transpose_a ? 't' : 'n')
         << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
         << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
         << "_wm" << wm << "_wn" << wn << "_MN_"
         << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
-        << ((K % bk == 0) ? "t" : "n") << "aligned" << "_op_mask_"
-        << (has_op_mask ? "T" : "N");
+        << ((K % bk == 0) ? "t" : "n") << "aligned";
 
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
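The kernel name assembled here follows a fixed scheme that must match the host_name strings produced by the instantiation macros above. A Python restatement of the format (helper name is ours):

def masked_gemm_kernel_name(out_mask_nm, op_mask_nm, ta, tb, in_nm, out_nm,
                            bm, bn, bk, wm, wn, mn_aligned, k_aligned):
    return (
        f"steel_gemm_block_outmask_{out_mask_nm}_opmask_{op_mask_nm}_"
        f"{'t' if ta else 'n'}{'t' if tb else 'n'}_{in_nm}_{out_nm}_"
        f"bm{bm}_bn{bn}_bk{bk}_wm{wm}_wn{wn}_"
        f"MN_{'t' if mn_aligned else 'n'}aligned_"
        f"K_{'t' if k_aligned else 'n'}aligned"
    )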
@@ -1397,17 +1400,23 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
 
   std::vector<int> mask_strides;
-  mask_strides.push_back(*(out_mask.strides().end() - 1));
-  mask_strides.push_back(*(out_mask.strides().end() - 2));
+
+  if (has_out_mask) {
+    auto& out_mask = inputs[2];
+    mask_strides.push_back(*(out_mask.strides().end() - 1));
+    mask_strides.push_back(*(out_mask.strides().end() - 2));
+
+    compute_encoder.set_input_array(out_mask, 10);
+  }
 
   if (has_op_mask) {
-    auto& lhs_mask = inputs[3];
+    auto& lhs_mask = inputs[2 + has_out_mask];
     mask_strides.push_back(*(lhs_mask.strides().end() - 1));
     mask_strides.push_back(*(lhs_mask.strides().end() - 2));
 
     compute_encoder.set_input_array(lhs_mask, 11);
 
-    auto& rhs_mask = inputs[4];
+    auto& rhs_mask = inputs[3 + has_out_mask];
     mask_strides.push_back(*(rhs_mask.strides().end() - 1));
     mask_strides.push_back(*(rhs_mask.strides().end() - 2));
 
@@ -1424,7 +1433,6 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   set_vector_bytes(compute_encoder, batch_shape, 6);
   set_vector_bytes(compute_encoder, batch_strides, 7);
 
-  compute_encoder.set_input_array(out_mask, 10);
   set_vector_bytes(compute_encoder, mask_strides, 13);
 
   compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
mlx/ops.cpp (46 changed lines)
@@ -3870,48 +3870,60 @@ array block_masked_mm(
   int tn = (N + block_size - 1) / block_size;
   int tk = (K + block_size - 1) / block_size;
 
+  std::vector<array> inputs = {a, b};
+
   // Broadcast and astype mask
   auto broadcast_mask = [](array mask,
                            std::vector<int>& bs_shape,
                            int y,
                            int x,
+                           Dtype mask_dtype,
                            StreamOrDevice s) {
     int nd_bsx = bs_shape.size();
     bs_shape[nd_bsx - 2] = y;
     bs_shape[nd_bsx - 1] = x;
-    mask = astype(mask, bool_, s);
+    mask = astype(mask, mask_dtype, s);
     return broadcast_to(mask, bs_shape, s);
   };
 
   // Out mask
-  array mask_out_p = mask_out.value_or(array({true}));
-  if (in_a_ndim == 1 || in_b_ndim == 1) {
-    std::vector<int> ex_dims;
-    if (in_a_ndim == 1)
-      ex_dims.push_back(-2);
-    if (in_b_ndim == 1)
-      ex_dims.push_back(-1);
-    mask_out_p = expand_dims(mask_out_p, ex_dims, s);
-  }
-  mask_out_p = broadcast_mask(mask_out_p, bsx_shape, tm, tn, s);
-
-  std::vector<array> inputs = {a, b, mask_out_p};
+  if (mask_out.has_value()) {
+    array mask_out_p = mask_out.value_or(array({true}));
+    if (in_a_ndim == 1 || in_b_ndim == 1) {
+      std::vector<int> ex_dims;
+      if (in_a_ndim == 1)
+        ex_dims.push_back(-2);
+      if (in_b_ndim == 1)
+        ex_dims.push_back(-1);
+      mask_out_p = expand_dims(mask_out_p, ex_dims, s);
+    }
+    auto maskout_dtype = mask_out_p.dtype() == bool_ ? bool_ : out_type;
+    mask_out_p =
+        broadcast_mask(mask_out_p, bsx_shape, tm, tn, maskout_dtype, s);
+
+    inputs.push_back(mask_out_p);
+  }
 
   // Operand masks
   if (has_operand_mask) {
-    // LHS mask
+    // Pull masks
     array mask_lhs_p = mask_lhs.value_or(array({true}));
+    array mask_rhs_p = mask_rhs.value_or(array({true}));
+    auto mask_dtype =
+        (mask_lhs_p.dtype() == bool_ && mask_rhs_p.dtype() == bool_) ? bool_
+                                                                     : out_type;
+
+    // LHS mask
     if (in_a_ndim == 1) {
       mask_lhs_p = expand_dims(mask_lhs_p, -2, s);
     }
-    mask_lhs_p = broadcast_mask(mask_lhs_p, bsx_shape, tm, tk, s);
+    mask_lhs_p = broadcast_mask(mask_lhs_p, bsx_shape, tm, tk, mask_dtype, s);
 
     // RHS mask
-    array mask_rhs_p = mask_rhs.value_or(array({true}));
     if (in_b_ndim == 1) {
-      mask_rhs_p = expand_dims(mask_lhs_p, -1, s);
+      mask_rhs_p = expand_dims(mask_rhs_p, -1, s);
     }
-    mask_rhs_p = broadcast_mask(mask_rhs_p, bsx_shape, tk, tn, s);
+    mask_rhs_p = broadcast_mask(mask_rhs_p, bsx_shape, tk, tn, mask_dtype, s);
 
     inputs.push_back(mask_lhs_p);
     inputs.push_back(mask_rhs_p);
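With this change the masks passed to block_masked_mm may be floating point as well as boolean: non-boolean masks are cast to the output dtype and scale blocks instead of zeroing them. A small usage sketch:

import mlx.core as mx

a = mx.ones((64, 64))
b = mx.ones((64, 64))

# One scale factor per 32x32 block of A (a 2x2 grid), and a no-op mask for B.
lhs_mask = mx.array([[1.0, 0.5], [0.0, 2.0]])
rhs_mask = mx.ones((2, 2))

c = mx.block_masked_mm(a, b, 32, None, lhs_mask, rhs_mask)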
@@ -3487,42 +3487,251 @@ std::vector<array> BlockMaskedMM::vjp(
     const std::vector<array>& cotangents,
     const std::vector<int>& argnums,
     const std::vector<array>&) {
+  /////////////////////////////////////////////////////////////////////////////
+  // The operation that is done w/o intermediates by the primitive is
+  // - tm = (M + block_size - 1) // block_size; MP = tm * block_size;
+  // - tn = (N + block_size - 1) // block_size; NP = tn * block_size;
+  // - tk = (K + block_size - 1) // block_size; KP = tk * block_size;
+  // - mask_b <- mask broadcasted to block sizes
+  // - A_m = A [..., M, K] * mask_b_lhs [..., MP, KP]
+  // - B_m = B [..., K, N] * mask_b_rhs [..., KP, NP]
+  // - C = A_m [..., M, K] @ B_m [..., K, N]
+  // - C_m = C [..., M, N] * mask_b_out [..., MP, NP]
+  //
+  // The grads are therefore
+  // - dC_m = cotan [..., M, N]
+  // - dmask_b_out = cotan [..., M, N] * C [..., M, N]
+  // - dC = cotan [..., M, N] * mask_b_out [..., MP, NP]
+  // - dA_m = dC [..., M, N] @ B_m.T [..., N, K]
+  // - dB_m = A_m.T [..., K, M] @ dC [..., M, N]
+  // - dA = dA_m * mask_b_lhs [..., MP, KP]
+  // - dB = dB_m * mask_b_rhs [..., KP, NP]
+  // - dmask_b_lhs = dA_m [..., M, K] * A [..., M, K] // need [..., MP, KP]
+  // - dmask_b_rhs = dB_m [..., K, N] * B [..., K, N] // need [..., KP, NP]
+  //
+  // Observations:
+  // * If dmask_b_lhs is not needed, then dA can be calculated in one go as
+  //   a block_masked_mm with mask_b_lhs as the out_mask, without needing to
+  //   materialize the intermediate dA_m. Similar for dB.
+  // * If dmask_b_lhs is needed, we need to materialize dA_m directly and then
+  //   point-wise multiply with A. But the output needs to be padded.
+
   std::vector<array> vjps;
   auto& cotan = cotangents[0];
   std::vector<int> reorder(cotan.ndim());
   std::iota(reorder.begin(), reorder.end(), 0);
   std::iter_swap(reorder.end() - 1, reorder.end() - 2);
 
   bool has_op_mask = primals.size() > 3;
+  bool has_out_mask = primals.size() == 3 || primals.size() == 5;
+
+  const int op_mask_idx = has_out_mask ? 3 : 2;
+  bool needs_lhs_mask_vjp = has_op_mask;
+  bool needs_rhs_mask_vjp = has_op_mask;
+  bool needs_lhs_vjp = false;
+  bool needs_rhs_vjp = false;
+
+  for (auto arg : argnums) {
+    needs_lhs_vjp = arg == 0;
+    needs_rhs_vjp = arg == 1;
+    needs_lhs_mask_vjp = arg == op_mask_idx;
+    needs_rhs_mask_vjp = arg == op_mask_idx + 1;
+  }
+
+  if ((needs_lhs_mask_vjp && primals[op_mask_idx].dtype() == bool_) ||
+      (needs_rhs_mask_vjp && primals[op_mask_idx + 1].dtype() == bool_)) {
+    throw std::invalid_argument(
+        "[BlockMaskedMM] Cannot calculate VJP with respect to boolean masks.");
+  }
+
+  auto expand_mask = [&](array mask, int Y, int X) {
+    // Expand mask
+    auto mask_reshape = mask.shape();
+    mask = expand_dims(mask, {-3, -1}, stream());
+    auto mask_shape = mask.shape();
+    int mask_ndim = mask_shape.size();
+
+    // Broadcast mask
+    mask_shape[mask_ndim - 1] = block_size_;
+    mask_shape[mask_ndim - 3] = block_size_;
+    mask = broadcast_to(mask, mask_shape, stream());
+
+    // Reshape mask to squeeze in broadcasted dims
+    mask_ndim = mask_reshape.size();
+    mask_reshape[mask_ndim - 2] *= block_size_;
+    mask_reshape[mask_ndim - 1] *= block_size_;
+    mask = reshape(mask, mask_reshape, stream());
+
+    // Slice mask
+    mask_reshape[mask_ndim - 2] = Y;
+    mask_reshape[mask_ndim - 1] = X;
+    mask = slice(mask, std::vector<int>(mask_ndim, 0), mask_reshape, stream());
+
+    return mask;
+  };
+
+  array zero = array(0, cotan.dtype());
+
+  auto multiply_pad_reduce = [&](array p, array q, int align_Y, int align_X) {
+    // Multiply with cotan
+    auto r = multiply(p, q, stream());
+
+    // Pad if needed
+    if ((align_Y != 0) || (align_X != 0)) {
+      r = pad(r, {-2, -1}, {0, 0}, {align_Y, align_X}, zero, stream());
+    }
+
+    // Reshape
+    std::vector<int> r_reshape(r.shape().begin(), r.shape().end() - 2);
+    r_reshape.push_back(r.shape(-2) / block_size_);
+    r_reshape.push_back(block_size_);
+    r_reshape.push_back(r.shape(-1) / block_size_);
+    r_reshape.push_back(block_size_);
+    r = reshape(r, r_reshape, stream());
+
+    // Reduce
+    return sum(r, {-3, -1}, false, stream());
+  };
+
+  // Prepare for padding if needed
+  int M = cotan.shape(-2);
+  int N = cotan.shape(-1);
+  int K = primals[0].shape(-1);
+  int align_M = (M % block_size_);
+  int align_N = (N % block_size_);
+  int align_K = (K % block_size_);
+
+  // Potential intermediates
+  auto unmasked_lhs_grad = primals[0];
+  auto unmasked_rhs_grad = primals[1];
+
+  bool unmasked_lhs_grad_calculated = false;
+  bool unmasked_rhs_grad_calculated = false;
+
   for (auto arg : argnums) {
     if (arg == 0) {
       // M X N * (K X N).T -> M X K
       auto b_t = transpose(primals[1], reorder, stream());
-      auto out_mask = primals[2];
-      auto lhs_mask =
-          has_op_mask ? std::make_optional<array>(primals[3]) : std::nullopt;
+      auto out_mask =
+          has_out_mask ? std::make_optional<array>(primals[2]) : std::nullopt;
+      auto lhs_mask = has_op_mask && !needs_lhs_mask_vjp
+          ? std::make_optional<array>(primals[op_mask_idx])
+          : std::nullopt;
       auto rhs_mask_t = has_op_mask
-          ? std::make_optional<array>(transpose(primals[4], reorder, stream()))
+          ? std::make_optional<array>(
+                transpose(primals[op_mask_idx + 1], reorder, stream()))
           : std::nullopt;
+
       auto grad = block_masked_mm(
           cotan, b_t, block_size_, lhs_mask, out_mask, rhs_mask_t, stream());
+
+      if (needs_lhs_mask_vjp) {
+        unmasked_lhs_grad = grad;
+        unmasked_lhs_grad_calculated = true;
+        auto exp_mask = expand_mask(primals[op_mask_idx], M, K);
+        grad = multiply(grad, exp_mask, stream());
+      }
+
       vjps.push_back(grad);
+
     } else if (arg == 1) {
       // (M X K).T * M X N -> K X N
       auto a_t = transpose(primals[0], reorder, stream());
-      auto out_mask = primals[2];
+      auto out_mask =
+          has_out_mask ? std::make_optional<array>(primals[2]) : std::nullopt;
       auto lhs_mask_t = has_op_mask
-          ? std::make_optional<array>(transpose(primals[3], reorder, stream()))
+          ? std::make_optional<array>(
+                transpose(primals[op_mask_idx], reorder, stream()))
+          : std::nullopt;
+      auto rhs_mask = has_op_mask && !needs_rhs_mask_vjp
+          ? std::make_optional<array>(primals[op_mask_idx + 1])
           : std::nullopt;
-      auto rhs_mask =
-          has_op_mask ? std::make_optional<array>(primals[4]) : std::nullopt;
 
       auto grad = block_masked_mm(
           a_t, cotan, block_size_, rhs_mask, lhs_mask_t, out_mask, stream());
+
+      if (needs_rhs_mask_vjp) {
+        unmasked_rhs_grad = grad;
+        unmasked_rhs_grad_calculated = true;
+        auto exp_mask = expand_mask(primals[op_mask_idx + 1], K, N);
+        grad = multiply(grad, exp_mask, stream());
+      }
+
       vjps.push_back(grad);
+
+    } else if (arg == 2 && has_out_mask) {
+      // Produce the forward result
+      auto lhs_mask = has_op_mask
+          ? std::make_optional<array>(primals[op_mask_idx])
+          : std::nullopt;
+      auto rhs_mask = has_op_mask
+          ? std::make_optional<array>(primals[op_mask_idx + 1])
+          : std::nullopt;
+
+      auto C = block_masked_mm(
+          primals[0],
+          primals[1],
+          block_size_,
+          primals[2],
+          lhs_mask,
+          rhs_mask,
+          stream());
+
+      // Multiply, Pad and Reduce if needed
+      auto grad = multiply_pad_reduce(cotan, C, align_M, align_N);
+      vjps.push_back(grad);
+
+    } else if (arg == op_mask_idx && has_op_mask) {
+      if (!unmasked_lhs_grad_calculated) {
+        // M X N * (K X N).T -> M X K
+        auto b_t = transpose(primals[1], reorder, stream());
+        auto out_mask =
+            has_out_mask ? std::make_optional<array>(primals[2]) : std::nullopt;
+        auto rhs_mask_t =
+            transpose(primals[op_mask_idx + 1], reorder, stream());
+
+        unmasked_lhs_grad = block_masked_mm(
+            cotan,
+            b_t,
+            block_size_,
+            std::nullopt,
+            out_mask,
+            rhs_mask_t,
+            stream());
+
+        unmasked_lhs_grad_calculated = true;
+      }
+
+      // Multiply, Pad and Reduce if needed
+      auto grad =
+          multiply_pad_reduce(primals[0], unmasked_lhs_grad, align_M, align_K);
+      vjps.push_back(grad);
+
+    } else if (arg == op_mask_idx + 1 && has_op_mask) {
+      if (!unmasked_rhs_grad_calculated) {
+        // (M X K).T * M X N -> K X N
+        auto a_t = transpose(primals[0], reorder, stream());
+        auto out_mask =
+            has_out_mask ? std::make_optional<array>(primals[2]) : std::nullopt;
+        auto lhs_mask_t = transpose(primals[op_mask_idx], reorder, stream());
+
+        unmasked_rhs_grad = block_masked_mm(
+            a_t,
+            cotan,
+            block_size_,
+            std::nullopt,
+            lhs_mask_t,
+            out_mask,
+            stream());
+
+        unmasked_rhs_grad_calculated = true;
+      }
+
+      // Multiply, Pad and Reduce if needed
+      auto grad =
+          multiply_pad_reduce(primals[1], unmasked_rhs_grad, align_K, align_N);
+      vjps.push_back(grad);
+
     } else {
       throw std::invalid_argument(
           "[BlockMaskedMM] Cannot calculate VJP with respect to masks.");
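The derivation in the comment block at the top of vjp can be restated compactly in LaTeX (our notation: \odot is the elementwise product after broadcasting the block masks to full size, and blocksum is the pad-reshape-sum that multiply_pad_reduce implements):

\begin{aligned}
A_m &= A \odot m_{\mathrm{lhs}}, \qquad B_m = B \odot m_{\mathrm{rhs}}, \qquad C_m = (A_m B_m) \odot m_{\mathrm{out}} \\
dC &= \bar{C} \odot m_{\mathrm{out}}, \qquad dA = (dC\, B_m^{\top}) \odot m_{\mathrm{lhs}}, \qquad dB = (A_m^{\top} dC) \odot m_{\mathrm{rhs}} \\
dm_{\mathrm{out}} &= \mathrm{blocksum}(\bar{C} \odot A_m B_m), \qquad dm_{\mathrm{lhs}} = \mathrm{blocksum}(dA_m \odot A), \qquad dm_{\mathrm{rhs}} = \mathrm{blocksum}(dB_m \odot B)
\end{aligned}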
@@ -682,7 +682,7 @@ class TestBlas(mlx_tests.MLXTestCase):
         self.assertEqual(c.shape, (0, 0))
 
     def test_block_masked_matmul(self):
-        def np_block_masked_mm(
+        def ref_block_masked_mm(
             a, b, block_size, out_mask=None, lhs_mask=None, rhs_mask=None
         ):
             # Get mask adjusted shapes
@@ -690,33 +690,81 @@ class TestBlas(mlx_tests.MLXTestCase):
             N = b.shape[-1]
             K = a.shape[-1]
 
+            bsx_shape = np.broadcast_shapes(a.shape[:-2], b.shape[:-2])
+
             # Expand mask dims
             def expand_mask(mask, block_size, Y, X):
-                mask = np.expand_dims(mask, (-3, -1))
-                mask_shape = list(mask.shape)
+                mask = mx.expand_dims(mask, (-3, -1))
+                mask_shape = list(bsx_shape) + list(mask.shape[-4:])
                 mask_shape[-1] = block_size
                 x = mask_shape[-2] * block_size
                 mask_shape[-3] = block_size
                 y = mask_shape[-4] * block_size
-                mask = np.broadcast_to(mask, mask_shape)
+                mask = mx.broadcast_to(mask, mask_shape)
                 mask_shape = mask_shape[:-4] + [y, x]
                 return mask.reshape(mask_shape)[..., :Y, :X]
 
+            a_masked = a
+            b_masked = b
+
             if lhs_mask is not None:
-                lhs_mask = expand_mask(lhs_mask, block_size, M, K)
-                a = lhs_mask * a
+                lhs_mask = expand_mask(lhs_mask, block_size, M, K).astype(mx.float32)
+                a_masked = lhs_mask * a_masked
+
             if rhs_mask is not None:
-                rhs_mask = expand_mask(rhs_mask, block_size, K, N)
-                b = rhs_mask * b
+                rhs_mask = expand_mask(rhs_mask, block_size, K, N).astype(mx.float32)
+                b_masked = rhs_mask * b_masked
 
-            out = a @ b
+            out = a_masked @ b_masked
+
             if out_mask is not None:
-                out_mask = expand_mask(out_mask, block_size, M, N)
+                out_mask = expand_mask(out_mask, block_size, M, N).astype(mx.float32)
                 out = out * out_mask
             return out
 
+        def run_test(a, b, block_size, out_mask, a_mask, b_mask, cotan):
+            def f_ref(a_, b_):
+                return ref_block_masked_mm(a_, b_, block_size, out_mask, a_mask, b_mask)
+
+            def f_test(a_, b_):
+                return mx.block_masked_mm(a_, b_, block_size, out_mask, a_mask, b_mask)
+
+            out_ref, dout_ref = mx.vjp(f_ref, [a, b], [cotan])
+            out_test, dout_test = mx.vjp(f_test, [a, b], [cotan])
+
+            mx.eval((out_ref, dout_ref, out_test, dout_test))
+
+            self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item())
+
+        def run_test_mask_vjp(a, b, block_size, out_mask, a_mask, b_mask, cotan):
+            def f_ref(a_, b_, a_mask_, b_mask_):
+                return ref_block_masked_mm(
+                    a_, b_, block_size, out_mask, a_mask_, b_mask_
+                )
+
+            def f_test(a_, b_, a_mask_, b_mask_):
+                return mx.block_masked_mm(
+                    a_, b_, block_size, out_mask, a_mask_, b_mask_
+                )
+
+            out_ref, dout_ref = mx.vjp(f_ref, [a, b, a_mask, b_mask], [cotan])
+            out_test, dout_test = mx.vjp(f_test, [a, b, a_mask, b_mask], [cotan])
+
+            mx.eval((out_ref, dout_ref, out_test, dout_test))
+
+            self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item())
+
+            for r, t in zip(dout_ref, dout_test):
+                self.assertEqual(r.shape, t.shape)
+                self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
+
+        def make_mask(tm_, tn_, batch, np_dtype):
+            arr_np_mask = np.random.normal(size=batch + (tm_, tn_)).astype(np_dtype)
+            arr_np_bool_mask = arr_np_mask < 0.0
+            arr_np_mask[arr_np_bool_mask] = 0.0
+
+            return mx.array(arr_np_bool_mask), mx.array(arr_np_mask)
+
         def test_shape(
             M,
             N,
@@ -737,49 +785,49 @@ class TestBlas(mlx_tests.MLXTestCase):
             batch_A=batch_A,
             batch_B=batch_B,
         ):
-            tm = (M + block_size - 1) // block_size
-            tn = (N + block_size - 1) // block_size
-            tk = (K + block_size - 1) // block_size
+            batch_out = np.broadcast_shapes(batch_A, batch_B)
+            cotan = mx.ones(batch_out + (M, N))
 
             a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)
             b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)
 
-            batch_out = np.broadcast_shapes(batch_A, batch_B)
+            a_mx = mx.array(a_np)
+            b_mx = mx.array(b_np)
 
-            a_np_mask = np.random.normal(size=batch_A + (tm, tk)) < 0.0
-            b_np_mask = np.random.normal(size=batch_B + (tk, tn)) < 0.0
-            out_np_mask = np.random.normal(size=batch_out + (tm, tn)) < 0.0
+            tm = (M + block_size - 1) // block_size
+            tn = (N + block_size - 1) // block_size
+            tk = (K + block_size - 1) // block_size
 
-            a_mx, b_mx, a_mx_mask, b_mx_mask, out_mx_mask = map(
-                mx.array, (a_np, b_np, a_np_mask, b_np_mask, out_np_mask)
-            )
+            a_mx_bool_mask, a_mx_mask = make_mask(tm, tk, batch_A, np_dtype)
+            b_mx_bool_mask, b_mx_mask = make_mask(tk, tn, batch_B, np_dtype)
+            out_mx_bool_mask, out_mx_mask = make_mask(tm, tn, batch_out, np_dtype)
 
-            if transpose:
-                b_np = np.random.normal(size=batch_B + (N, K)).astype(np_dtype)
-                b_mx = mx.array(b_np)
-
-                b_np = np.swapaxes(b_np, -2, -1)
-                b_mx = mx.swapaxes(b_mx, -2, -1)
-
-            out_np = np_block_masked_mm(
-                a_np, b_np, block_size, out_np_mask, a_np_mask, b_np_mask
-            )
-            out_mx = mx.block_masked_mm(
-                a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask
-            )
-            self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
-
-            out_np = np_block_masked_mm(a_np, b_np, block_size, out_np_mask)
-            out_mx = mx.block_masked_mm(a_mx, b_mx, block_size, out_mx_mask)
-            self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
-
-            out_np = np_block_masked_mm(
-                a_np, b_np, block_size, None, a_np_mask, b_np_mask
-            )
-            out_mx = mx.block_masked_mm(
-                a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask
-            )
-            self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
+            # Boolean block masks
+            run_test(
+                a_mx,
+                b_mx,
+                block_size,
+                out_mx_bool_mask,
+                a_mx_bool_mask,
+                b_mx_bool_mask,
+                cotan,
+            )
+            run_test(a_mx, b_mx, block_size, out_mx_bool_mask, None, None, cotan)
+            run_test(
+                a_mx, b_mx, block_size, None, a_mx_bool_mask, b_mx_bool_mask, cotan
+            )
+
+            # Float block masks
+            run_test(
+                a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask, cotan
+            )
+            run_test(a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask, cotan)
+            run_test_mask_vjp(
+                a_mx, b_mx, block_size, out_mx_mask, a_mx_mask, b_mx_mask, cotan
+            )
+            run_test_mask_vjp(
+                a_mx, b_mx, block_size, None, a_mx_mask, b_mx_mask, cotan
+            )
 
         shapes = (
             (16, 16, 16, 32),
@@ -789,11 +837,10 @@ class TestBlas(mlx_tests.MLXTestCase):
         )
 
         for M, N, K, block_size in shapes:
-            test_shape(M, N, K, block_size, transpose=False)
-            test_shape(M, N, K, block_size, transpose=True)
+            test_shape(M, N, K, block_size)
 
         # Test broadcasting
-        test_shape(64, 64, 64, 32, transpose=False, batch_A=(1, 2), batch_B=(2, 2))
+        test_shape(64, 64, 64, 32, batch_A=(1, 2), batch_B=(2, 2))
 
         # Test gemv
         a_np = np.random.normal(size=(64, 64)).astype(np.float32)