mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fix extensions (#1126)
* fix extensions * title * enable circle * fix nanobind tag * fix bug in doc * try to fix config * typo
This commit is contained in:
		| @@ -1,5 +1,5 @@ | ||||
| Developer Documentation | ||||
| ======================= | ||||
| Custom Extensions in MLX | ||||
| ======================== | ||||
|  | ||||
| You can extend MLX with custom operations on the CPU or GPU. This guide | ||||
| explains how to do that with a simple example. | ||||
| @@ -494,7 +494,7 @@ below. | ||||
|         auto kernel = d.get_kernel(kname.str(), "mlx_ext"); | ||||
|  | ||||
|         // Prepare to encode kernel | ||||
|         auto compute_encoder = d.get_command_encoder(s.index); | ||||
|         auto& compute_encoder = d.get_command_encoder(s.index); | ||||
|         compute_encoder->setComputePipelineState(kernel); | ||||
|  | ||||
|         // Kernel parameters are registered with buffer indices corresponding to | ||||
| @@ -503,11 +503,11 @@ below. | ||||
|         size_t nelem = out.size(); | ||||
|  | ||||
|         // Encode input arrays to kernel | ||||
|         set_array_buffer(compute_encoder, x, 0); | ||||
|         set_array_buffer(compute_encoder, y, 1); | ||||
|         compute_encoder.set_input_array(x, 0); | ||||
|         compute_encoder.set_input_array(y, 1); | ||||
|  | ||||
|         // Encode output arrays to kernel | ||||
|         set_array_buffer(compute_encoder, out, 2); | ||||
|         compute_encoder.set_output_array(out, 2); | ||||
|  | ||||
|         // Encode alpha and beta | ||||
|         compute_encoder->setBytes(&alpha_, sizeof(float), 3); | ||||
| @@ -531,7 +531,7 @@ below. | ||||
|  | ||||
|         // Launch the grid with the given number of threads divided among | ||||
|         // the given threadgroups | ||||
|         compute_encoder->dispatchThreads(grid_dims, group_dims); | ||||
|         compute_encoder.dispatchThreads(grid_dims, group_dims); | ||||
|     } | ||||
|  | ||||
| We can now call the :meth:`axpby` operation on both the CPU and the GPU! | ||||
| @@ -825,7 +825,7 @@ Let's look at a simple script and its results: | ||||
|  | ||||
|     print(f"c shape: {c.shape}") | ||||
|     print(f"c dtype: {c.dtype}") | ||||
|     print(f"c correctness: {mx.all(c == 6.0).item()}") | ||||
|     print(f"c correct: {mx.all(c == 6.0).item()}") | ||||
|  | ||||
| Output: | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun