mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
fix extension (#1740)
This commit is contained in:
parent
491fa95b1f
commit
a64a8dfe45
@ -58,9 +58,9 @@ mx::array axpby(
|
|||||||
// Construct the array as the output of the Axpby primitive
|
// Construct the array as the output of the Axpby primitive
|
||||||
// with the broadcasted and upcasted arrays as inputs
|
// with the broadcasted and upcasted arrays as inputs
|
||||||
return mx::array(
|
return mx::array(
|
||||||
/* const std::vector<int>& shape = */ out_shape,
|
/* const mx::Shape& shape = */ out_shape,
|
||||||
/* mx::Dtype dtype = */ out_dtype,
|
/* mx::Dtype dtype = */ out_dtype,
|
||||||
/* std::unique_ptr<mx::Primitive> primitive = */
|
/* std::shared_ptr<mx::Primitive> primitive = */
|
||||||
std::make_shared<Axpby>(to_stream(s), alpha, beta),
|
std::make_shared<Axpby>(to_stream(s), alpha, beta),
|
||||||
/* const std::vector<mx::array>& inputs = */ broadcasted_inputs);
|
/* const std::vector<mx::array>& inputs = */ broadcasted_inputs);
|
||||||
}
|
}
|
||||||
@ -279,7 +279,7 @@ void Axpby::eval_gpu(
|
|||||||
if (!contiguous_kernel) {
|
if (!contiguous_kernel) {
|
||||||
compute_encoder.set_vector_bytes(x.shape(), 5);
|
compute_encoder.set_vector_bytes(x.shape(), 5);
|
||||||
compute_encoder.set_vector_bytes(x.strides(), 6);
|
compute_encoder.set_vector_bytes(x.strides(), 6);
|
||||||
compute_encoder.set_bytes(y.strides(), 7);
|
compute_encoder.set_vector_bytes(y.strides(), 7);
|
||||||
compute_encoder.set_bytes(ndim, 8);
|
compute_encoder.set_bytes(ndim, 8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user