* start of C++ docs

* fix stream doc

* only include ops for now
This commit is contained in:
Awni Hannun
2024-04-26 12:56:05 -07:00
committed by GitHub
parent 82463e9938
commit 5bfe89bdb1
9 changed files with 87 additions and 20 deletions

View File

@@ -165,6 +165,15 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
return Buffer{nullptr};
}
// More helpful message if maximum buffer length is exceeded
if (size > device_->maxBufferLength()) {
std::ostringstream msg;
msg << "Attempting to allocate " << size << " bytes which is greater than"
<< " the maximum allowed buffer size of " << device_->maxBufferLength()
<< " bytes.";
throw std::runtime_error(msg.str());
}
// Align up memory
if (size > vm_page_size) {
size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);

View File

@@ -204,7 +204,7 @@ class ScaledDotProductAttention : public Custom {
void eval_gpu(const std::vector<array>& inputs, array& out);
bool is_equivalent(const Primitive& other) const override;
DEFINE_PRINT(ScaledDotProductAttention)
DEFINE_PRINT(ScaledDotProductAttention);
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;

View File

@@ -11,7 +11,10 @@
namespace mlx::core {
/** Creation operations */
/**
* \defgroup ops Core array operations
* @{
*/
/**
* A 1D array of numbers starting at `start` (optional),
@@ -115,8 +118,6 @@ inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
array tril(array x, int k = 0, StreamOrDevice s = {});
array triu(array x, int k = 0, StreamOrDevice s = {});
/** array manipulation */
/** Reshape an array to the given shape. */
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
@@ -289,8 +290,6 @@ std::vector<array> broadcast_arrays(
const std::vector<array>& inputs,
StreamOrDevice s = {});
/** Comparison operations */
/** Returns the bool array with (a == b) element-wise. */
array equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator==(const array& a, const array& b) {
@@ -401,8 +400,6 @@ array where(
const array& y,
StreamOrDevice s = {});
/** Reduction operations */
/** True if all elements in the array are true (or non-zero). **/
array all(const array& a, bool keepdims, StreamOrDevice s = {});
inline array all(const array& a, StreamOrDevice s = {}) {
@@ -710,8 +707,6 @@ array logsumexp(
bool keepdims = false,
StreamOrDevice s = {});
/** Simple arithmetic operations */
/** Absolute value of elements in an array. */
array abs(const array& a, StreamOrDevice s = {});
@@ -1076,8 +1071,6 @@ array cummin(
bool inclusive = true,
StreamOrDevice s = {});
/** Convolution operations */
/** General convolution with a filter */
array conv_general(
array input,
@@ -1246,4 +1239,6 @@ array number_of_elements(
Dtype dtype = int32,
StreamOrDevice s = {});
/** @} */
} // namespace mlx::core