mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* Use async cuda malloc managed with cuda 13 * add pool threshold * refactor for regular cuda malloc * load eval gpu for cuda * remove use of cuda pool, use cuda free async * fix * fix * fix * fix * fix + comment
78 lines
2.2 KiB
C++
78 lines
2.2 KiB
C++
// Copyright © 2023-2024 Apple Inc.
|
|
#pragma once
|
|
|
|
#include <functional>
#include <iomanip>
#include <ios>
#include <limits>
#include <ostream>
#include <type_traits>

#include "mlx/array.h"
#include "mlx/primitives.h"
|
|
|
|
namespace mlx::core {
|
|
|
|
inline bool is_static_cast(const Primitive& p) {
|
|
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
|
|
}
|
|
|
|
// Returns a textual name for dtype `d`. Only the declaration is visible here.
// print_complex_constant writes this string immediately before
// "(real, imag)". That suggests it is the type's spelling in generated
// source — confirm against the out-of-line definition.
std::string get_type_string(Dtype d);
|
|
|
|
template <typename T>
|
|
void print_float_constant(std::ostream& os, const array& x) {
|
|
auto old_precision = os.precision();
|
|
if constexpr (std::is_same_v<T, double>) {
|
|
os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
|
|
} else {
|
|
os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
|
|
}
|
|
os << x.item<T>() << std::setprecision(old_precision);
|
|
}
|
|
|
|
template <typename T>
|
|
void print_int_constant(std::ostream& os, const array& x) {
|
|
os << x.item<T>();
|
|
}
|
|
|
|
template <typename T>
|
|
void print_complex_constant(std::ostream& os, const array& x) {
|
|
auto old_precision = os.precision();
|
|
T constant = x.item<T>();
|
|
|
|
os << get_type_string(x.dtype()) << "("
|
|
<< std::setprecision(std::numeric_limits<float>::digits10 + 1)
|
|
<< constant.real() << ", " << constant.imag() << ")"
|
|
<< std::setprecision(old_precision);
|
|
}
|
|
|
|
// Writes the constant held by `x` to `os`. Only the declaration is visible
// here. Presumably it dispatches on x.dtype() to the print_*_constant
// templates in this header — confirm against the out-of-line definition.
void print_constant(std::ostream& os, const array& x);
|
|
|
|
inline bool is_scalar(const array& x) {
|
|
return x.ndim() == 0;
|
|
}
|
|
|
|
// Check if we can use a contiguous operation given the inputs and the output
// `shape`. Declaration only; the exact criteria live in the out-of-line
// definition.
bool compiled_check_contiguity(
    const std::vector<array>& inputs,
    const Shape& shape);
|
|
|
|
// Allocate space for the outputs, possibly donating input buffers to them.
//
// inputs      - inputs of the compiled computation; candidates for donation.
// outputs     - arrays that receive storage (taken by non-const reference).
// is_constant - predicate on an input index. Presumably true for inputs that
//               are constants and so are not donated — TODO confirm.
// contiguous  - presumably the result of compiled_check_contiguity; confirm
//               against the definition.
// mallocfn    - allocator used for fresh buffers; defaults to
//               allocator::malloc. It is a parameter, presumably so a backend
//               can supply its own allocation routine.
void compiled_allocate_outputs(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const std::function<bool(size_t)>& is_constant,
    bool contiguous,
    const std::function<allocator::Buffer(size_t)>& mallocfn =
        allocator::malloc);
|
|
|
|
// Collapse contiguous dims, ignoring scalars and constants (inputs for which
// `is_constant(index)` holds).
//
// Returns (bool, Shape, std::vector<Strides>). The meaning of the bool is not
// visible here (presumably whether a collapsed/contiguous layout applies —
// confirm). The Shape is presumably the collapsed shape, and the vector
// presumably holds one Strides entry per non-ignored array — confirm against
// the definition.
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
    const std::vector<array>& inputs,
    const array& out,
    const std::function<bool(size_t)>& is_constant);
|
|
|
|
// Return whether the kernel should use large index. Presumably this means
// 64-bit indexing once element counts exceed what a 32-bit index can address
// — TODO confirm against the definition. `contiguous` is presumably the
// result of compiled_check_contiguity.
bool compiled_use_large_index(
    const std::vector<array>& inputs,
    const std::vector<array>& outputs,
    bool contiguous);
|
|
|
|
} // namespace mlx::core
|