mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Cpp docs (#1036)
* start of C++ docs * fix stream doc * only include ops for now
This commit is contained in:
@@ -165,6 +165,15 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
return Buffer{nullptr};
|
||||
}
|
||||
|
||||
// More helpful message if maximum buffer length is exceeded
|
||||
if (size > device_->maxBufferLength()) {
|
||||
std::ostringstream msg;
|
||||
msg << "Attempting to allocate " << size << " bytes which is greater than"
|
||||
<< " the maximum allowed buffer size of " << device_->maxBufferLength()
|
||||
<< " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Align up memory
|
||||
if (size > vm_page_size) {
|
||||
size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
|
||||
|
@@ -204,7 +204,7 @@ class ScaledDotProductAttention : public Custom {
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out);
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
DEFINE_PRINT(ScaledDotProductAttention)
|
||||
DEFINE_PRINT(ScaledDotProductAttention);
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
|
17
mlx/ops.h
17
mlx/ops.h
@@ -11,7 +11,10 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
/** Creation operations */
|
||||
/**
|
||||
* \defgroup ops Core array operations
|
||||
* @{
|
||||
*/
|
||||
|
||||
/**
|
||||
* A 1D array of numbers starting at `start` (optional),
|
||||
@@ -115,8 +118,6 @@ inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
|
||||
array tril(array x, int k = 0, StreamOrDevice s = {});
|
||||
array triu(array x, int k = 0, StreamOrDevice s = {});
|
||||
|
||||
/** array manipulation */
|
||||
|
||||
/** Reshape an array to the given shape. */
|
||||
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
|
||||
|
||||
@@ -289,8 +290,6 @@ std::vector<array> broadcast_arrays(
|
||||
const std::vector<array>& inputs,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Comparison operations */
|
||||
|
||||
/** Returns the bool array with (a == b) element-wise. */
|
||||
array equal(const array& a, const array& b, StreamOrDevice s = {});
|
||||
inline array operator==(const array& a, const array& b) {
|
||||
@@ -401,8 +400,6 @@ array where(
|
||||
const array& y,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Reduction operations */
|
||||
|
||||
/** True if all elements in the array are true (or non-zero). **/
|
||||
array all(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||
inline array all(const array& a, StreamOrDevice s = {}) {
|
||||
@@ -710,8 +707,6 @@ array logsumexp(
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Simple arithmetic operations */
|
||||
|
||||
/** Absolute value of elements in an array. */
|
||||
array abs(const array& a, StreamOrDevice s = {});
|
||||
|
||||
@@ -1076,8 +1071,6 @@ array cummin(
|
||||
bool inclusive = true,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Convolution operations */
|
||||
|
||||
/** General convolution with a filter */
|
||||
array conv_general(
|
||||
array input,
|
||||
@@ -1246,4 +1239,6 @@ array number_of_elements(
|
||||
Dtype dtype = int32,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** @} */
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Reference in New Issue
Block a user