mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
More fixes for arrays with large sizes (#1405)
* compile works for big arrays when contiguous
* style
* nits in docs
* a bunch more stuff
* update jit
* update jit
* use constant for shapes and strides and remove elem_to_loc overload
* use kernel instantiation
* docs nits
* update binary and ternary
* comments
This commit is contained in:
@@ -8,7 +8,7 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 5;
|
||||
constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 3;
|
||||
|
||||
void ternary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
@@ -26,11 +26,21 @@ void ternary_op_gpu_inplace(
|
||||
}
|
||||
|
||||
// Try to collapse contiguous dims
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
|
||||
auto& strides_a = strides[0];
|
||||
auto& strides_b = strides[1];
|
||||
auto& strides_c = strides[2];
|
||||
auto& strides_out = strides[3];
|
||||
auto maybe_collapse = [topt, &a, &b, &c, &out]() {
|
||||
if (topt == TernaryOpType::General) {
|
||||
// The size cap here should ideally be `UINT32_MAX` but we are
|
||||
// limited by the shape being an int.
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
{a, b, c, out},
|
||||
/* size_cap = */ INT32_MAX);
|
||||
return std::make_tuple(
|
||||
shape, strides[0], strides[1], strides[2], strides[3]);
|
||||
} else {
|
||||
std::vector<size_t> e;
|
||||
return std::make_tuple(std::vector<int>{}, e, e, e, e);
|
||||
}
|
||||
};
|
||||
auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();
|
||||
|
||||
bool use_2d = out.data_size() > UINT_MAX;
|
||||
std::string kernel_name;
|
||||
@@ -88,7 +98,7 @@ void ternary_op_gpu_inplace(
|
||||
size_t rest = out.size() / (dim0 * dim1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||
throw std::runtime_error("[Metal::ternary] Must use 1024 sized block");
|
||||
}
|
||||
MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
|
||||
Reference in New Issue
Block a user