diff --git a/docs/README.md b/docs/README.md
index d0c35a31c..f197ecf43 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -7,7 +7,7 @@ for example with `conda`:
```
conda install sphinx
-pip install sphinx-rtd-theme
+pip install sphinx-book-theme
```
### Build
diff --git a/docs/src/_static/mlx_logo.png b/docs/src/_static/mlx_logo.png
new file mode 100644
index 000000000..49400bd8d
Binary files /dev/null and b/docs/src/_static/mlx_logo.png differ
diff --git a/docs/src/conf.py b/docs/src/conf.py
index 5d3be4e57..132a85fb2 100644
--- a/docs/src/conf.py
+++ b/docs/src/conf.py
@@ -39,7 +39,17 @@ pygments_style = "sphinx"
# -- Options for HTML output -------------------------------------------------
-html_theme = "sphinx_rtd_theme"
+html_theme = "sphinx_book_theme"
+
+html_theme_options = {
+ "show_toc_level": 2,
+ "repository_url": "https://github.com/ml-explore/mlx",
+ "use_repository_button": True,
+ "navigation_with_keys": False,
+}
+
+html_logo = "_static/mlx_logo.png"
+
# -- Options for HTMLHelp output ---------------------------------------------
diff --git a/docs/src/index.rst b/docs/src/index.rst
index 3c57db7de..445970370 100644
--- a/docs/src/index.rst
+++ b/docs/src/index.rst
@@ -1,6 +1,30 @@
MLX
===
+MLX is a NumPy-like array framework designed for efficient and flexible
+machine learning on Apple silicon.
+
+The Python API closely follows NumPy with a few exceptions. MLX also has a
+fully featured C++ API which closely follows the Python API.
+
+The main differences between MLX and NumPy are:
+
+ - **Composable function transformations**: MLX has composable function
+ transformations for automatic differentiation, automatic vectorization,
+ and computation graph optimization.
+ - **Lazy computation**: Computations in MLX are lazy. Arrays are only
+ materialized when needed.
+ - **Multi-device**: Operations can run on any of the supported devices (CPU,
+ GPU, ...).
+
+The design of MLX is strongly inspired by frameworks like `PyTorch
+`_, `JAX `_, and
+`ArrayFire `_. A notable difference between these
+frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
+memory. Operations on MLX arrays can be performed on any of the supported
+device types without performing data copies. Currently supported device types
+are the CPU and GPU.
+
.. toctree::
:caption: Install
:maxdepth: 1
diff --git a/docs/src/quick_start.rst b/docs/src/quick_start.rst
index 3439580ba..ae7912487 100644
--- a/docs/src/quick_start.rst
+++ b/docs/src/quick_start.rst
@@ -1,28 +1,6 @@
Quick Start Guide
=================
-MLX is a NumPy-like array framework designed for efficient and flexible
-machine learning on Apple silicon. The Python API closely follows NumPy with
-a few exceptions. MLX also has a fully featured C++ API which closely follows
-the Python API.
-
-The main differences between MLX and NumPy are:
-
- - **Composable function transformations**: MLX has composable function
- transformations for automatic differentiation, automatic vectorization,
- and computation graph optimization.
- - **Lazy computation**: Computations in MLX are lazy. Arrays are only
- materialized when needed.
- - **Multi-device**: Operations can run on any of the supported devices (CPU,
- GPU, ...)
-
-The design of MLX is strongly inspired by frameworks like `PyTorch
-`_, `Jax `_, and
-`ArrayFire `_. A noteable difference from these
-frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
-memory. Operations on MLX arrays can be performed on any of the supported
-device types without performing data copies. Currently supported device types
-are the CPU and GPU.
Basics
------