* start of C++ docs

* fix stream doc

* only include ops for now
This commit is contained in:
Awni Hannun 2024-04-26 12:56:05 -07:00 committed by GitHub
parent 82463e9938
commit 5bfe89bdb1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 87 additions and 20 deletions

50
docs/Doxyfile Normal file
View File

@ -0,0 +1,50 @@
################################################################################
# Primary project setup. #
################################################################################
PROJECT_NAME = "MLX"
OUTPUT_DIRECTORY = build
XML_OUTPUT = xml
HTML_OUTPUT = html
STRIP_FROM_PATH = ../
INPUT = ../mlx
FILE_PATTERNS = *.h
EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = YES
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES
################################################################################
# Doxygen preprocessor / parser control. #
################################################################################
ENABLE_PREPROCESSING = YES
MACRO_EXPANSION = YES
EXPAND_ONLY_PREDEF = NO
SKIP_FUNCTION_MACROS = NO
################################################################################
# Compound extraction control. #
################################################################################
EXTRACT_ALL = YES
EXTRACT_PACKAGE = YES
EXTRACT_STATIC = YES
CASE_SENSE_NAMES = NO
################################################################################
# Docstring control / customization. #
################################################################################
JAVADOC_AUTOBRIEF = YES
################################################################################
# Warning suppression. #
################################################################################
QUIET = YES
WARN_IF_UNDOCUMENTED = NO

View File

@ -2,12 +2,16 @@
### Setup (do once)
Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
for example with `conda`:
Install Doxygen:
```
conda install sphinx
pip install sphinx-book-theme
brew install doxygen
```
Install Python packages:
```
pip install -r requirements.txt
```
### Build
@ -15,7 +19,7 @@ pip install sphinx-book-theme
Build the docs from `mlx/docs/`
```
make html
doxygen && make html
```
View the docs by running a server in `mlx/docs/build/html/`:

3
docs/requirements.txt Normal file
View File

@ -0,0 +1,3 @@
sphinx
breathe
sphinx-book-theme

View File

@ -22,6 +22,7 @@ extensions = [
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
"breathe",
]
python_use_unqualified_type_names = True
@ -33,6 +34,9 @@ intersphinx_mapping = {
"numpy": ("https://numpy.org/doc/stable/", None),
}
breathe_projects = {"mlx": "../build/xml"}
breathe_default_project = "mlx"
templates_path = ["_templates"]
html_static_path = ["_static"]
source_suffix = ".rst"

View File

@ -3,4 +3,5 @@
Operations
==========
.. doxygengroup:: ops
:content-only:

View File

@ -165,6 +165,15 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
return Buffer{nullptr};
}
// More helpful message if maximum buffer length is exceeded
if (size > device_->maxBufferLength()) {
std::ostringstream msg;
msg << "Attempting to allocate " << size << " bytes which is greater than"
<< " the maximum allowed buffer size of " << device_->maxBufferLength()
<< " bytes.";
throw std::runtime_error(msg.str());
}
// Align up memory
if (size > vm_page_size) {
size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);

View File

@ -204,7 +204,7 @@ class ScaledDotProductAttention : public Custom {
void eval_gpu(const std::vector<array>& inputs, array& out);
bool is_equivalent(const Primitive& other) const override;
DEFINE_PRINT(ScaledDotProductAttention)
DEFINE_PRINT(ScaledDotProductAttention);
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;

View File

@ -11,7 +11,10 @@
namespace mlx::core {
/** Creation operations */
/**
* \defgroup ops Core array operations
* @{
*/
/**
* A 1D array of numbers starting at `start` (optional),
@ -115,8 +118,6 @@ inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
array tril(array x, int k = 0, StreamOrDevice s = {});
array triu(array x, int k = 0, StreamOrDevice s = {});
/** array manipulation */
/** Reshape an array to the given shape. */
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
@ -289,8 +290,6 @@ std::vector<array> broadcast_arrays(
const std::vector<array>& inputs,
StreamOrDevice s = {});
/** Comparison operations */
/** Returns the bool array with (a == b) element-wise. */
array equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator==(const array& a, const array& b) {
@ -401,8 +400,6 @@ array where(
const array& y,
StreamOrDevice s = {});
/** Reduction operations */
/** True if all elements in the array are true (or non-zero). **/
array all(const array& a, bool keepdims, StreamOrDevice s = {});
inline array all(const array& a, StreamOrDevice s = {}) {
@ -710,8 +707,6 @@ array logsumexp(
bool keepdims = false,
StreamOrDevice s = {});
/** Simple arithmetic operations */
/** Absolute value of elements in an array. */
array abs(const array& a, StreamOrDevice s = {});
@ -1076,8 +1071,6 @@ array cummin(
bool inclusive = true,
StreamOrDevice s = {});
/** Convolution operations */
/** General convolution with a filter */
array conv_general(
array input,
@ -1246,4 +1239,6 @@ array number_of_elements(
Dtype dtype = int32,
StreamOrDevice s = {});
/** @} */
} // namespace mlx::core

View File

@ -139,7 +139,8 @@ void init_stream(nb::module_& m) {
Synchronize with the given stream.
Args:
(Stream, optional): The stream to synchronize with. If ``None`` then
the default stream of the default device is used. Default: ``None``.
stream (Stream, optional): The stream to synchronize with. If ``None``
then the default stream of the default device is used.
Default: ``None``.
)pbdoc");
}