mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
add back conditionaltype (#1655)
This commit is contained in:
parent
bc2a29f033
commit
95c4a2e3af
@ -854,15 +854,17 @@ METAL_FUNC void qvm_impl(
|
|||||||
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||||
constexpr int num_simdgroups = 2;
|
constexpr int num_simdgroups = 2;
|
||||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||||
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
|
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
||||||
constexpr int tn = 32 / pack_factor;
|
constexpr int tn = 32 / pack_factor;
|
||||||
constexpr int block_size = SIMD_SIZE;
|
constexpr int block_size = SIMD_SIZE;
|
||||||
|
|
||||||
const device uint8_t* ws = (const device uint8_t*)w;
|
using W_T =
|
||||||
|
typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
|
||||||
|
const device W_T* ws = (const device W_T*)w;
|
||||||
|
|
||||||
typedef float U;
|
typedef float U;
|
||||||
typedef struct {
|
typedef struct {
|
||||||
uint8_t wi[tn * bytes_per_pack];
|
W_T wi[tn * bytes_per_pack];
|
||||||
} vec_w;
|
} vec_w;
|
||||||
|
|
||||||
thread vec_w w_local;
|
thread vec_w w_local;
|
||||||
|
@ -421,3 +421,14 @@ inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {
|
|||||||
return complex64_t(
|
return complex64_t(
|
||||||
simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane));
|
simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// std::conditional is not included with Metal
|
||||||
|
template <bool condition, typename T, typename U>
|
||||||
|
struct ConditionalType {
|
||||||
|
using type = U;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
struct ConditionalType<true, T, U> {
|
||||||
|
using type = T;
|
||||||
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user