mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Merge branch 'main' into stft
This commit is contained in:
@@ -97,8 +97,7 @@ TEST_CASE("test export primitives with state") {
|
||||
TEST_CASE("test export functions with kwargs") {
|
||||
std::string file_path = get_temp_file("model.mlxfn");
|
||||
|
||||
auto fun =
|
||||
[](const std::map<std::string, array>& kwargs) -> std::vector<array> {
|
||||
auto fun = [](const Kwargs& kwargs) -> std::vector<array> {
|
||||
return {kwargs.at("x") + kwargs.at("y")};
|
||||
};
|
||||
|
||||
|
||||
@@ -3874,3 +3874,41 @@ TEST_CASE("test contiguous") {
|
||||
CHECK(x.flags().col_contiguous);
|
||||
CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});
|
||||
}
|
||||
|
||||
TEST_CASE("test bitwise shift operations") {
|
||||
std::vector<Dtype> dtypes = {
|
||||
int8, int16, int32, int64, uint8, uint16, uint32, uint64};
|
||||
|
||||
for (const auto& dtype : dtypes) {
|
||||
array x = full({4}, 1, dtype);
|
||||
array y = full({4}, 2, dtype);
|
||||
|
||||
auto left_shift_result = left_shift(x, y);
|
||||
CHECK_EQ(left_shift_result.dtype(), dtype);
|
||||
CHECK(array_equal(left_shift_result, array({4, 4, 4, 4}, dtype))
|
||||
.item<bool>());
|
||||
|
||||
auto right_shift_result = right_shift(full({4}, 4, dtype), y);
|
||||
CHECK_EQ(right_shift_result.dtype(), dtype);
|
||||
CHECK(array_equal(right_shift_result, full({4}, 1, dtype)).item<bool>());
|
||||
}
|
||||
|
||||
array x = array({127, -128}, int8);
|
||||
array y = array({1, 1}, int8);
|
||||
auto left_shift_result = left_shift(x, y);
|
||||
auto right_shift_result = right_shift(x, y);
|
||||
|
||||
CHECK(array_equal(left_shift_result, array({-2, 0}, int8)).item<bool>());
|
||||
CHECK(array_equal(right_shift_result, array({63, -64}, int8)).item<bool>());
|
||||
|
||||
array x_bool = full({4}, true, bool_);
|
||||
array y_bool = full({4}, true, bool_);
|
||||
auto left_shift_bool_result = left_shift(x_bool, y_bool);
|
||||
auto right_shift_bool_result = right_shift(x_bool, y_bool);
|
||||
|
||||
CHECK_EQ(left_shift_bool_result.dtype(), uint8);
|
||||
CHECK(array_equal(left_shift_bool_result, full({4}, 2, uint8)).item<bool>());
|
||||
|
||||
CHECK_EQ(right_shift_bool_result.dtype(), uint8);
|
||||
CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>());
|
||||
}
|
||||
Reference in New Issue
Block a user