mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
fix array from list for > 32 bit types (#501)
This commit is contained in:
@@ -229,9 +229,28 @@ array array_from_list(
|
||||
return array(vals.begin(), shape, specified_type.value_or(bool_));
|
||||
}
|
||||
case pyint: {
|
||||
std::vector<int> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, specified_type.value_or(int32));
|
||||
auto dtype = specified_type.value_or(int32);
|
||||
if (dtype == int64) {
|
||||
std::vector<int64_t> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, dtype);
|
||||
} else if (dtype == uint64) {
|
||||
std::vector<uint64_t> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, dtype);
|
||||
} else if (dtype == uint32) {
|
||||
std::vector<uint32_t> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, dtype);
|
||||
} else if (is_floating_point(dtype)) {
|
||||
std::vector<float> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, dtype);
|
||||
} else {
|
||||
std::vector<int> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, dtype);
|
||||
}
|
||||
}
|
||||
case pyfloat: {
|
||||
std::vector<float> vals;
|
||||
|
Reference in New Issue
Block a user