Add gemv masked to JIT plus some fixes (#1310)

* add gemv masked to JIT plus some fixes

* some cleanup

* add utils

* fix

* fix 2

* more cleaning

* fix

* remove unused mps matmul support

* one more nit

* revert
This commit is contained in:
Awni Hannun
2024-08-07 13:38:07 -07:00
committed by GitHub
parent 635ccd9e25
commit 30bbea2f08
25 changed files with 1230 additions and 1702 deletions

View File

@@ -161,6 +161,8 @@ TEST_CASE("test array types") {
// bfloat16
{ basic_dtype_test(bfloat16_t, bfloat16); }
#undef basic_dtype_test
// uint32
{
uint32_t val = UINT_MAX;
@@ -233,31 +235,6 @@ TEST_CASE("test array types") {
CHECK_EQ(x.dtype(), complex64);
CHECK_EQ(x.item<complex64_t>(), v);
}
#undef basic_dtype_test
#define basic_dtype_str_test(s, dtype) \
CHECK_EQ(s, dtype_to_array_protocol(dtype)); \
CHECK_EQ(dtype_from_array_protocol(s), dtype);
// To and from str
{
basic_dtype_str_test("|b1", bool_);
basic_dtype_str_test("|u1", uint8);
basic_dtype_str_test("<u2", uint16);
basic_dtype_str_test("<u4", uint32);
basic_dtype_str_test("<u8", uint64);
basic_dtype_str_test("|i1", int8);
basic_dtype_str_test("<i2", int16);
basic_dtype_str_test("<i4", int32);
basic_dtype_str_test("<i8", int64);
basic_dtype_str_test("<f2", float16);
basic_dtype_str_test("<f4", float32);
basic_dtype_str_test("<V2", bfloat16);
basic_dtype_str_test("<c8", complex64);
}
#undef basic_dtype_str_test
}
TEST_CASE("test array metadata") {