mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Add gemv masked to JIT plus some fixes (#1310)
* add gemv masked to JIT plus some fixes * some cleanup * add utils * fix * fix 2 * more cleaning * fix * remove unused mps matmul support * one more nit * revert
This commit is contained in:
@@ -161,6 +161,8 @@ TEST_CASE("test array types") {
|
||||
// bfloat16
|
||||
{ basic_dtype_test(bfloat16_t, bfloat16); }
|
||||
|
||||
#undef basic_dtype_test
|
||||
|
||||
// uint32
|
||||
{
|
||||
uint32_t val = UINT_MAX;
|
||||
@@ -233,31 +235,6 @@ TEST_CASE("test array types") {
|
||||
CHECK_EQ(x.dtype(), complex64);
|
||||
CHECK_EQ(x.item<complex64_t>(), v);
|
||||
}
|
||||
|
||||
#undef basic_dtype_test
|
||||
|
||||
#define basic_dtype_str_test(s, dtype) \
|
||||
CHECK_EQ(s, dtype_to_array_protocol(dtype)); \
|
||||
CHECK_EQ(dtype_from_array_protocol(s), dtype);
|
||||
|
||||
// To and from str
|
||||
{
|
||||
basic_dtype_str_test("|b1", bool_);
|
||||
basic_dtype_str_test("|u1", uint8);
|
||||
basic_dtype_str_test("<u2", uint16);
|
||||
basic_dtype_str_test("<u4", uint32);
|
||||
basic_dtype_str_test("<u8", uint64);
|
||||
basic_dtype_str_test("|i1", int8);
|
||||
basic_dtype_str_test("<i2", int16);
|
||||
basic_dtype_str_test("<i4", int32);
|
||||
basic_dtype_str_test("<i8", int64);
|
||||
basic_dtype_str_test("<f2", float16);
|
||||
basic_dtype_str_test("<f4", float32);
|
||||
basic_dtype_str_test("<V2", bfloat16);
|
||||
basic_dtype_str_test("<c8", complex64);
|
||||
}
|
||||
|
||||
#undef basic_dtype_str_test
|
||||
}
|
||||
|
||||
TEST_CASE("test array metadata") {
|
||||
|
Reference in New Issue
Block a user