More fixes for arrays with large sizes (#1405)

* compile works for big arrays when contiguous

* style

* nits in docs

* a bunch more stuff

* update jit

* update jit

* use constant for shapes and strides and remove elem_to_loc overload

* use kernel instantiation

* docs nits

* update binary and ternary

* comments
Author: Awni Hannun
Date: 2024-09-17 12:46:31 -07:00
Committed by: GitHub
Parent: c6739ba7f3
Commit: 4f46e9c997

26 changed files with 325 additions and 611 deletions

@@ -8,7 +8,7 @@
 namespace mlx::core {
-constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 5;
+constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 3;
 void ternary_op_gpu_inplace(
     const std::vector<array>& inputs,
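With the specialization limit dropped from 5 to 3, ternary ops on arrays with more than three (collapsed) dimensions fall through to the general kernel, which is handed the shape and strides and computes each element's location at run time. A hedged sketch of that per-element index arithmetic, written as plain host-side C++ for illustration only (the name elem_to_loc_general and its signature are assumptions, not the kernel code touched by this commit):

#include <cstddef>
#include <vector>

// Illustrative only: map a flat element index to a strided location using
// shape and strides, the kind of loop a general N-D kernel runs per thread.
size_t elem_to_loc_general(
    size_t elem,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}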
@@ -26,11 +26,21 @@ void ternary_op_gpu_inplace(
   }
   // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
-  auto& strides_a = strides[0];
-  auto& strides_b = strides[1];
-  auto& strides_c = strides[2];
-  auto& strides_out = strides[3];
+  auto maybe_collapse = [topt, &a, &b, &c, &out]() {
+    if (topt == TernaryOpType::General) {
+      // The size cap here should ideally be `UINT32_MAX` but we are
+      // limited by the shape being an int.
+      auto [shape, strides] = collapse_contiguous_dims(
+          {a, b, c, out},
+          /* size_cap = */ INT32_MAX);
+      return std::make_tuple(
+          shape, strides[0], strides[1], strides[2], strides[3]);
+    } else {
+      std::vector<size_t> e;
+      return std::make_tuple(std::vector<int>{}, e, e, e, e);
+    }
+  };
+  auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();
   bool use_2d = out.data_size() > UINT_MAX;
   std::string kernel_name;
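The interesting part of this hunk is the size_cap argument: adjacent dimensions are only merged while the collapsed extent still fits in a 32-bit int, because that is the type the shape is stored in. A rough single-array sketch of a cap-aware collapse (collapse_with_cap is a hypothetical helper for illustration; the real collapse_contiguous_dims operates on all four arrays at once):

#include <cstdint>
#include <utility>
#include <vector>

// Illustrative sketch: merge adjacent dims that are jointly contiguous, but
// stop merging once the combined extent would exceed size_cap.
std::pair<std::vector<int>, std::vector<size_t>> collapse_with_cap(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides,
    int64_t size_cap) {
  std::vector<int> out_shape;
  std::vector<size_t> out_strides;
  size_t i = 0;
  while (i < shape.size()) {
    int64_t dim = shape[i];
    size_t j = i;
    // Dims j and j+1 can merge when stride[j] == stride[j+1] * shape[j+1],
    // i.e. they describe one contiguous run of elements.
    while (j + 1 < shape.size() &&
           strides[j] == strides[j + 1] * static_cast<size_t>(shape[j + 1]) &&
           dim * shape[j + 1] <= size_cap) {
      ++j;
      dim *= shape[j];
    }
    out_shape.push_back(static_cast<int>(dim));
    out_strides.push_back(strides[j]);
    i = j + 1;
  }
  return {out_shape, out_strides};
}

Note that the cap only matters on the General path; fully contiguous cases take the else branch above and skip collapsing entirely.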
@@ -88,7 +98,7 @@ void ternary_op_gpu_inplace(
     size_t rest = out.size() / (dim0 * dim1);
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size != 1024) {
-      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
+      throw std::runtime_error("[Metal::ternary] Must use 1024 sized block");
     }
     MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
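The dispatch mirrors the variables in this hunk: the two innermost collapsed extents become the x and y grid dimensions and everything else is folded into rest, so no single Metal grid coordinate has to carry the full element count of a very large array. A hedged sketch of that split (make_grid is illustrative; the names dim0, dim1, and rest follow the diff):

#include <cstddef>
#include <vector>

struct LaunchGrid {
  size_t dim0, dim1, rest;
};

// Illustrative: split a collapsed shape into a 3D launch grid so each
// coordinate stays small even when the total size exceeds 32 bits.
LaunchGrid make_grid(const std::vector<int>& shape, size_t total_size) {
  int ndim = static_cast<int>(shape.size());
  size_t dim0 = ndim > 0 ? static_cast<size_t>(shape[ndim - 1]) : 1;
  size_t dim1 = ndim > 1 ? static_cast<size_t>(shape[ndim - 2]) : 1;
  size_t rest = total_size / (dim0 * dim1);
  return {dim0, dim1, rest};
}

get_block_dims then picks a threadgroup shape that tiles this grid, which is why the code above insists on a 1024-thread maximum block size before dispatching.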