mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
ExpandDims primitive (#1687)
* add squeeze primitive * simplify squeeze, use in gather * fix * fix * fix * fix * fix no cpu * use squeeze in matmul and friends * expand dims primitive * comment
This commit is contained in:
@@ -43,6 +43,7 @@ DEFAULT(NumberOfElements)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(ExpandDims)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Gather)
|
||||
@@ -76,6 +77,7 @@ DEFAULT(Slice)
|
||||
DEFAULT(SliceUpdate)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT(Squeeze)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT_MULTI(SVD)
|
||||
DEFAULT(Transpose)
|
||||
|
||||
@@ -85,6 +85,16 @@ void Depends::eval(
|
||||
}
|
||||
}
|
||||
|
||||
void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
auto strides = in.strides();
|
||||
for (auto ax : axes_) {
|
||||
strides.insert(strides.begin() + ax, 1);
|
||||
}
|
||||
move_or_copy(in, out, strides, in.flags(), in.data_size());
|
||||
}
|
||||
|
||||
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
@@ -248,6 +258,20 @@ void Split::eval(
|
||||
}
|
||||
}
|
||||
|
||||
void Squeeze::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
Strides strides;
|
||||
for (int i = 0, j = 0; i < in.ndim(); ++i) {
|
||||
if (j < axes_.size() && i == axes_[j]) {
|
||||
j++;
|
||||
} else {
|
||||
strides.push_back(in.strides(i));
|
||||
}
|
||||
}
|
||||
move_or_copy(in, out, strides, in.flags(), in.data_size());
|
||||
}
|
||||
|
||||
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
move_or_copy(inputs[0], out);
|
||||
|
||||
@@ -57,6 +57,7 @@ DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(Exp)
|
||||
DEFAULT(ExpandDims)
|
||||
DEFAULT(Expm1)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
@@ -101,6 +102,7 @@ DEFAULT(Softmax)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Square)
|
||||
DEFAULT(Squeeze)
|
||||
DEFAULT(Sqrt)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT(Subtract)
|
||||
|
||||
@@ -211,6 +211,10 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
copy_gpu(in, out, ctype);
|
||||
}
|
||||
|
||||
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto read_task = [out = out,
|
||||
@@ -381,6 +385,10 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
/* const Stream& s = */ stream());
|
||||
}
|
||||
|
||||
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
@@ -55,6 +55,7 @@ NO_CPU(Equal)
|
||||
NO_CPU(Erf)
|
||||
NO_CPU(ErfInv)
|
||||
NO_CPU(Exp)
|
||||
NO_CPU(ExpandDims)
|
||||
NO_CPU(Expm1)
|
||||
NO_CPU(FFT)
|
||||
NO_CPU(Floor)
|
||||
@@ -104,6 +105,7 @@ NO_CPU(Softmax)
|
||||
NO_CPU(Sort)
|
||||
NO_CPU_MULTI(Split)
|
||||
NO_CPU(Square)
|
||||
NO_CPU(Squeeze)
|
||||
NO_CPU(Sqrt)
|
||||
NO_CPU(StopGradient)
|
||||
NO_CPU(Subtract)
|
||||
|
||||
@@ -55,6 +55,7 @@ NO_GPU(Equal)
|
||||
NO_GPU(Erf)
|
||||
NO_GPU(ErfInv)
|
||||
NO_GPU(Exp)
|
||||
NO_GPU(ExpandDims)
|
||||
NO_GPU(Expm1)
|
||||
NO_GPU(FFT)
|
||||
NO_GPU(Floor)
|
||||
@@ -104,6 +105,7 @@ NO_GPU(Softmax)
|
||||
NO_GPU(Sort)
|
||||
NO_GPU_MULTI(Split)
|
||||
NO_GPU(Square)
|
||||
NO_GPU(Squeeze)
|
||||
NO_GPU(Sqrt)
|
||||
NO_GPU(StopGradient)
|
||||
NO_GPU(Subtract)
|
||||
|
||||
Reference in New Issue
Block a user