mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Cpp docs (#1036)
* start of C++ docs * fix stream doc * only include ops for now
This commit is contained in:
parent
82463e9938
commit
5bfe89bdb1
50
docs/Doxyfile
Normal file
50
docs/Doxyfile
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
################################################################################
|
||||||
|
# Primary project setup. #
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
PROJECT_NAME = "MLX"
|
||||||
|
OUTPUT_DIRECTORY = build
|
||||||
|
XML_OUTPUT = xml
|
||||||
|
HTML_OUTPUT = html
|
||||||
|
STRIP_FROM_PATH = ../
|
||||||
|
INPUT = ../mlx
|
||||||
|
FILE_PATTERNS = *.h
|
||||||
|
EXCLUDE_PATTERNS = */private/*
|
||||||
|
CREATE_SUBDIRS = NO
|
||||||
|
FULL_PATH_NAMES = YES
|
||||||
|
RECURSIVE = YES
|
||||||
|
GENERATE_HTML = YES
|
||||||
|
GENERATE_LATEX = NO
|
||||||
|
GENERATE_XML = YES
|
||||||
|
XML_PROGRAMLISTING = YES
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Doxygen preprocessor / parser control. #
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
ENABLE_PREPROCESSING = YES
|
||||||
|
MACRO_EXPANSION = YES
|
||||||
|
EXPAND_ONLY_PREDEF = NO
|
||||||
|
SKIP_FUNCTION_MACROS = NO
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Compound extraction control. #
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
EXTRACT_ALL = YES
|
||||||
|
EXTRACT_PACKAGE = YES
|
||||||
|
EXTRACT_STATIC = YES
|
||||||
|
CASE_SENSE_NAMES = NO
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Docstring control / customization. #
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
JAVADOC_AUTOBRIEF = YES
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Warning suppression. #
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
QUIET = YES
|
||||||
|
WARN_IF_UNDOCUMENTED = NO
|
@ -2,12 +2,16 @@
|
|||||||
|
|
||||||
### Setup (do once)
|
### Setup (do once)
|
||||||
|
|
||||||
Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
|
Install Doxygen:
|
||||||
for example with `conda`:
|
|
||||||
|
|
||||||
```
|
```
|
||||||
conda install sphinx
|
brew install doxygen
|
||||||
pip install sphinx-book-theme
|
```
|
||||||
|
|
||||||
|
Install Python packages:
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
### Build
|
### Build
|
||||||
@ -15,7 +19,7 @@ pip install sphinx-book-theme
|
|||||||
Build the docs from `mlx/docs/`
|
Build the docs from `mlx/docs/`
|
||||||
|
|
||||||
```
|
```
|
||||||
make html
|
doxygen && make html
|
||||||
```
|
```
|
||||||
|
|
||||||
View the docs by running a server in `mlx/docs/build/html/`:
|
View the docs by running a server in `mlx/docs/build/html/`:
|
||||||
|
3
docs/requirements.txt
Normal file
3
docs/requirements.txt
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
sphinx
|
||||||
|
breathe
|
||||||
|
sphinx-book-theme
|
@ -22,6 +22,7 @@ extensions = [
|
|||||||
"sphinx.ext.autosummary",
|
"sphinx.ext.autosummary",
|
||||||
"sphinx.ext.intersphinx",
|
"sphinx.ext.intersphinx",
|
||||||
"sphinx.ext.napoleon",
|
"sphinx.ext.napoleon",
|
||||||
|
"breathe",
|
||||||
]
|
]
|
||||||
|
|
||||||
python_use_unqualified_type_names = True
|
python_use_unqualified_type_names = True
|
||||||
@ -33,6 +34,9 @@ intersphinx_mapping = {
|
|||||||
"numpy": ("https://numpy.org/doc/stable/", None),
|
"numpy": ("https://numpy.org/doc/stable/", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
breathe_projects = {"mlx": "../build/xml"}
|
||||||
|
breathe_default_project = "mlx"
|
||||||
|
|
||||||
templates_path = ["_templates"]
|
templates_path = ["_templates"]
|
||||||
html_static_path = ["_static"]
|
html_static_path = ["_static"]
|
||||||
source_suffix = ".rst"
|
source_suffix = ".rst"
|
||||||
|
@ -3,4 +3,5 @@
|
|||||||
Operations
|
Operations
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
.. doxygengroup:: ops
|
||||||
|
:content-only:
|
||||||
|
@ -165,6 +165,15 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
|||||||
return Buffer{nullptr};
|
return Buffer{nullptr};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// More helpful message if maximum buffer length is exceeded
|
||||||
|
if (size > device_->maxBufferLength()) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Attempting to allocate " << size << " bytes which is greater than"
|
||||||
|
<< " the maximum allowed buffer size of " << device_->maxBufferLength()
|
||||||
|
<< " bytes.";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
// Align up memory
|
// Align up memory
|
||||||
if (size > vm_page_size) {
|
if (size > vm_page_size) {
|
||||||
size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
|
size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
|
||||||
|
@ -204,7 +204,7 @@ class ScaledDotProductAttention : public Custom {
|
|||||||
void eval_gpu(const std::vector<array>& inputs, array& out);
|
void eval_gpu(const std::vector<array>& inputs, array& out);
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
DEFINE_PRINT(ScaledDotProductAttention)
|
DEFINE_PRINT(ScaledDotProductAttention);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||||
|
17
mlx/ops.h
17
mlx/ops.h
@ -11,7 +11,10 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
/** Creation operations */
|
/**
|
||||||
|
* \defgroup ops Core array operations
|
||||||
|
* @{
|
||||||
|
*/
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A 1D array of numbers starting at `start` (optional),
|
* A 1D array of numbers starting at `start` (optional),
|
||||||
@ -115,8 +118,6 @@ inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
|
|||||||
array tril(array x, int k = 0, StreamOrDevice s = {});
|
array tril(array x, int k = 0, StreamOrDevice s = {});
|
||||||
array triu(array x, int k = 0, StreamOrDevice s = {});
|
array triu(array x, int k = 0, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** array manipulation */
|
|
||||||
|
|
||||||
/** Reshape an array to the given shape. */
|
/** Reshape an array to the given shape. */
|
||||||
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
|
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
|
||||||
|
|
||||||
@ -289,8 +290,6 @@ std::vector<array> broadcast_arrays(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Comparison operations */
|
|
||||||
|
|
||||||
/** Returns the bool array with (a == b) element-wise. */
|
/** Returns the bool array with (a == b) element-wise. */
|
||||||
array equal(const array& a, const array& b, StreamOrDevice s = {});
|
array equal(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
inline array operator==(const array& a, const array& b) {
|
inline array operator==(const array& a, const array& b) {
|
||||||
@ -401,8 +400,6 @@ array where(
|
|||||||
const array& y,
|
const array& y,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Reduction operations */
|
|
||||||
|
|
||||||
/** True if all elements in the array are true (or non-zero). **/
|
/** True if all elements in the array are true (or non-zero). **/
|
||||||
array all(const array& a, bool keepdims, StreamOrDevice s = {});
|
array all(const array& a, bool keepdims, StreamOrDevice s = {});
|
||||||
inline array all(const array& a, StreamOrDevice s = {}) {
|
inline array all(const array& a, StreamOrDevice s = {}) {
|
||||||
@ -710,8 +707,6 @@ array logsumexp(
|
|||||||
bool keepdims = false,
|
bool keepdims = false,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Simple arithmetic operations */
|
|
||||||
|
|
||||||
/** Absolute value of elements in an array. */
|
/** Absolute value of elements in an array. */
|
||||||
array abs(const array& a, StreamOrDevice s = {});
|
array abs(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
@ -1076,8 +1071,6 @@ array cummin(
|
|||||||
bool inclusive = true,
|
bool inclusive = true,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Convolution operations */
|
|
||||||
|
|
||||||
/** General convolution with a filter */
|
/** General convolution with a filter */
|
||||||
array conv_general(
|
array conv_general(
|
||||||
array input,
|
array input,
|
||||||
@ -1246,4 +1239,6 @@ array number_of_elements(
|
|||||||
Dtype dtype = int32,
|
Dtype dtype = int32,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** @} */
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -139,7 +139,8 @@ void init_stream(nb::module_& m) {
|
|||||||
Synchronize with the given stream.
|
Synchronize with the given stream.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
(Stream, optional): The stream to synchronize with. If ``None`` then
|
stream (Stream, optional): The stream to synchronize with. If ``None``
|
||||||
the default stream of the default device is used. Default: ``None``.
|
then the default stream of the default device is used.
|
||||||
|
Default: ``None``.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user