mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 09:33:16 +08:00
Resolves build issues with the extension example (#419)
* resolved extension build issues and added test to ci * missing gguflib * rebased * force mlx install from fix branch * linux build issue * point to git install and comment out ci tests
This commit is contained in:
@@ -104,7 +104,10 @@ void axpby_impl(
|
||||
}
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void Axpby::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Axpby::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& out_arr) {
|
||||
auto out = out_arr[0];
|
||||
// Check the inputs (registered in the op while constructing the out array)
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
@@ -175,7 +178,10 @@ void axpby_impl_accelerate(
|
||||
}
|
||||
|
||||
/** Evaluate primitive on CPU using accelerate specializations */
|
||||
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outarr) {
|
||||
auto out = outarr[0];
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
@@ -189,13 +195,15 @@ void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
// Fall back to common backend if specializations are not available
|
||||
eval(inputs, out);
|
||||
eval(inputs, outarr);
|
||||
}
|
||||
|
||||
#else // Accelerate not available
|
||||
|
||||
/** Evaluate primitive on CPU falling back to common backend */
|
||||
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
@@ -208,8 +216,11 @@ void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
#ifdef _METAL_
|
||||
|
||||
/** Evaluate primitive on GPU */
|
||||
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Axpby::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outarr) {
|
||||
// Prepare inputs
|
||||
auto out = outarr[0];
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
@@ -295,7 +306,9 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
#else // Metal is not available
|
||||
|
||||
/** Fail evaluation on GPU */
|
||||
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Axpby::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& out) {
|
||||
throw std::runtime_error("Axpby has no GPU implementation.");
|
||||
}
|
||||
|
||||
@@ -306,7 +319,7 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
array Axpby::jvp(
|
||||
std::vector<array> Axpby::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
@@ -321,32 +334,33 @@ array Axpby::jvp(
|
||||
if (argnums.size() > 1) {
|
||||
auto scale = argnums[0] == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, tangents[0].dtype());
|
||||
return multiply(scale_arr, tangents[0], stream());
|
||||
return {multiply(scale_arr, tangents[0], stream())};
|
||||
}
|
||||
// If, argnums = {0, 1}, we take contributions from both
|
||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||
else {
|
||||
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
|
||||
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
|
||||
}
|
||||
}
|
||||
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> Axpby::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
// Reverse mode diff
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
auto scale = arg == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, cotan.dtype());
|
||||
vjps.push_back(multiply(scale_arr, cotan, stream()));
|
||||
auto scale_arr = array(scale, cotangents[0].dtype());
|
||||
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
/** Vectorize primitive along given axis */
|
||||
std::pair<array, int> Axpby::vmap(
|
||||
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
throw std::runtime_error("Axpby has no vmap implementation.");
|
||||
|
Reference in New Issue
Block a user