mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	 44c1ce5e6a
			
		
	
	44c1ce5e6a
	
	
	
		
			
			* spelling: accumulates Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: across Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: additional Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: against Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: among Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: array Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: at least Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: available Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: axes Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: basically Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: bfloat Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: bounds Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: broadcast Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: buffer Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: class Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: coefficients Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: collision Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: combinations Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: committing Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: computation Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: consider Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: constructing Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: conversions Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: correctly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: corresponding Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: declaration Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: default Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: dependency Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: destination Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: destructor Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: dimensions Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: divided Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: element-wise Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: elements Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: endianness Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: equivalent Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: explicitly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: github Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: indices Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: irregularly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: memory Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: metallib Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: negative Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: notable Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: optional Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: otherwise Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: overridden Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: partially Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: partition Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: perform Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: perturbations Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: positively Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: primitive Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: repeat Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: repeats Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: respect Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: respectively Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: result Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: rounding Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: separate Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: skipping Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: structure Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: the Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: transpose Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unnecessary Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unneeded Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unsupported Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> --------- Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
		
			
				
	
	
		
			100 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			100 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| // Copyright © 2023 Apple Inc.
 | |
| 
 | |
| #include <cassert>
 | |
| #include <iostream>
 | |
| 
 | |
| #include "mlx/mlx.h"
 | |
| 
 | |
| using namespace mlx::core;
 | |
| 
 | |
| void array_basics() {
 | |
|   // Make a scalar array:
 | |
|   array x(1.0);
 | |
| 
 | |
|   // Get the value out of it:
 | |
|   auto s = x.item<float>();
 | |
|   assert(s == 1.0);
 | |
| 
 | |
|   // Scalars have a size of 1:
 | |
|   size_t size = x.size();
 | |
|   assert(size == 1);
 | |
| 
 | |
|   // Scalars have 0 dimensions:
 | |
|   int ndim = x.ndim();
 | |
|   assert(ndim == 0);
 | |
| 
 | |
|   // The shape should be an empty vector:
 | |
|   auto shape = x.shape();
 | |
|   assert(shape.empty());
 | |
| 
 | |
|   // The datatype should be float32:
 | |
|   auto dtype = x.dtype();
 | |
|   assert(dtype == float32);
 | |
| 
 | |
|   // Specify the dtype when constructing the array:
 | |
|   x = array(1, int32);
 | |
|   assert(x.dtype() == int32);
 | |
|   x.item<int>(); // OK
 | |
|   // x.item<float>();  // Undefined!
 | |
| 
 | |
|   // Make a multidimensional array:
 | |
|   x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
 | |
|   // mlx is row-major by default so the first row of this array
 | |
|   // is [1.0, 2.0] and the second row is [3.0, 4.0]
 | |
| 
 | |
|   // Make an array of shape {2, 2} filled with ones:
 | |
|   auto y = ones({2, 2});
 | |
| 
 | |
|   // Pointwise add x and y:
 | |
|   auto z = add(x, y);
 | |
| 
 | |
|   // Same thing:
 | |
|   z = x + y;
 | |
| 
 | |
|   // mlx is lazy by default. At this point `z` only
 | |
|   // has a shape and a type but no actual data:
 | |
|   assert(z.dtype() == float32);
 | |
|   assert(z.shape(0) == 2);
 | |
|   assert(z.shape(1) == 2);
 | |
| 
 | |
|   // To actually run the computation you must evaluate `z`.
 | |
|   // Under the hood, mlx records operations in a graph.
 | |
|   // The variable `z` is a node in the graph which points to its operation
 | |
|   // and inputs. When `eval` is called on an array (or arrays), the array and
 | |
|   // all of its dependencies are recursively evaluated to produce the result.
 | |
|   // Once an array is evaluated, it has data and is detached from its inputs.
 | |
|   eval(z);
 | |
| 
 | |
|   // Of course the array can still be an input to other operations. You can even
 | |
|   // call eval on the array again, this will just be a no-op:
 | |
|   eval(z); // no-op
 | |
| 
 | |
|   // Some functions or methods on arrays implicitly evaluate them. For example
 | |
|   // accessing a value in an array or printing the array implicitly evaluate it:
 | |
|   z = ones({1});
 | |
|   z.item<float>(); // implicit evaluation
 | |
| 
 | |
|   z = ones({2, 2});
 | |
|   std::cout << z << std::endl; // implicit evaluation
 | |
| }
 | |
| 
 | |
| void automatic_differentiation() {
 | |
|   auto fn = [](array x) { return square(x); };
 | |
| 
 | |
|   // Computing the derivative function of a function
 | |
|   auto grad_fn = grad(fn);
 | |
|   // Call grad_fn on the input to get the derivative
 | |
|   auto x = array(1.5);
 | |
|   auto dfdx = grad_fn(x);
 | |
|   // dfdx is 2 * x
 | |
| 
 | |
|   // Get the second derivative by composing grad with grad
 | |
|   auto df2dx2 = grad(grad(fn))(x);
 | |
|   // df2dx2 is 2
 | |
| }
 | |
| 
 | |
| int main() {
 | |
|   array_basics();
 | |
|   automatic_differentiation();
 | |
| }
 |