mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	Spelling (#342)
* spelling: accumulates Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: across Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: additional Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: against Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: among Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: array Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: at least Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: available Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: axes Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: basically Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: bfloat Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: bounds Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: broadcast Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: buffer Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: class Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: coefficients Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: collision Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: combinations Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: committing Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: computation Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: consider Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: constructing Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: conversions Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: correctly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: corresponding Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: declaration Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: default Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: dependency Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: destination Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: destructor Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: dimensions Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: divided Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: element-wise Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: elements Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: endianness Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: equivalent Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: explicitly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: github Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: indices Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: irregularly Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: memory Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: metallib Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: negative Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: notable Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: optional Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: otherwise Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: overridden Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: partially Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: partition Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: perform Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: perturbations Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: positively Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: primitive Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: repeat Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: repeats Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: respect Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: respectively Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: result Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: rounding Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: separate Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: skipping Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: structure Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: the Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: transpose Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unnecessary Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unneeded Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> * spelling: unsupported Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com> --------- Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
This commit is contained in:
		@@ -26,7 +26,7 @@ namespace mlx::core {
 | 
			
		||||
///////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 *  Scale and sum two vectors elementwise
 | 
			
		||||
 *  Scale and sum two vectors element-wise
 | 
			
		||||
 *  z = alpha * x + beta * y
 | 
			
		||||
 *
 | 
			
		||||
 *  Follow numpy style broadcasting between x and y
 | 
			
		||||
@@ -91,21 +91,21 @@ void axpby_impl(
 | 
			
		||||
  T alpha = static_cast<T>(alpha_);
 | 
			
		||||
  T beta = static_cast<T>(beta_);
 | 
			
		||||
 | 
			
		||||
  // Do the elementwise operation for each output
 | 
			
		||||
  // Do the element-wise operation for each output
 | 
			
		||||
  for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
 | 
			
		||||
    // Map linear indices to offsets in x and y
 | 
			
		||||
    auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
 | 
			
		||||
    auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
 | 
			
		||||
 | 
			
		||||
    // We allocate the output to be contiguous and regularly strided
 | 
			
		||||
    // (defaults to row major) and hence it doesn't need additonal mapping
 | 
			
		||||
    // (defaults to row major) and hence it doesn't need additional mapping
 | 
			
		||||
    out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/** Fall back implementation for evaluation on CPU */
 | 
			
		||||
void Axpby::eval(const std::vector<array>& inputs, array& out) {
 | 
			
		||||
  // Check the inputs (registered in the op while contructing the out array)
 | 
			
		||||
  // Check the inputs (registered in the op while constructing the out array)
 | 
			
		||||
  assert(inputs.size() == 2);
 | 
			
		||||
  auto& x = inputs[0];
 | 
			
		||||
  auto& y = inputs[1];
 | 
			
		||||
@@ -192,7 +192,7 @@ void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
 | 
			
		||||
  eval(inputs, out);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#else // Accelerate not avaliable
 | 
			
		||||
#else // Accelerate not available
 | 
			
		||||
 | 
			
		||||
/** Evaluate primitive on CPU falling back to common backend */
 | 
			
		||||
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
 | 
			
		||||
@@ -254,7 +254,7 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
			
		||||
  compute_encoder->setComputePipelineState(kernel);
 | 
			
		||||
 | 
			
		||||
  // Kernel parameters are registered with buffer indices corresponding to
 | 
			
		||||
  // those in the kernel decelaration at axpby.metal
 | 
			
		||||
  // those in the kernel declaration at axpby.metal
 | 
			
		||||
  int ndim = out.ndim();
 | 
			
		||||
  size_t nelem = out.size();
 | 
			
		||||
 | 
			
		||||
@@ -287,7 +287,7 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
			
		||||
  // Fix the 3D size of the launch grid (in terms of threads)
 | 
			
		||||
  MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
 | 
			
		||||
 | 
			
		||||
  // Launch the grid with the given number of threads divded among
 | 
			
		||||
  // Launch the grid with the given number of threads divided among
 | 
			
		||||
  // the given threadgroups
 | 
			
		||||
  compute_encoder->dispatchThreads(grid_dims, group_dims);
 | 
			
		||||
}
 | 
			
		||||
@@ -311,8 +311,8 @@ array Axpby::jvp(
 | 
			
		||||
    const std::vector<array>& tangents,
 | 
			
		||||
    const std::vector<int>& argnums) {
 | 
			
		||||
  // Forward mode diff that pushes along the tangents
 | 
			
		||||
  // The jvp transform on the the primitive can built with ops
 | 
			
		||||
  // that are scheduled on the same stream as the primtive
 | 
			
		||||
  // The jvp transform on the primitive can built with ops
 | 
			
		||||
  // that are scheduled on the same stream as the primitive
 | 
			
		||||
 | 
			
		||||
  // If argnums = {0}, we only push along x in which case the
 | 
			
		||||
  // jvp is just the tangent scaled by alpha
 | 
			
		||||
@@ -345,7 +345,7 @@ std::vector<array> Axpby::vjp(
 | 
			
		||||
  return vjps;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/** Vectorize primitve along given axis */
 | 
			
		||||
/** Vectorize primitive along given axis */
 | 
			
		||||
std::pair<array, int> Axpby::vmap(
 | 
			
		||||
    const std::vector<array>& inputs,
 | 
			
		||||
    const std::vector<int>& axes) {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user