|
|
|
@@ -29,30 +29,16 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
strides_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES); |
|
|
|
end_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, END); |
|
|
|
TransArg(); |
|
|
|
for (size_t i = 0; i < begin_.size(); i++) { |
|
|
|
while (begin_[i] < 0) { |
|
|
|
begin_[i] = begin_[i] + input_shape_[i]; |
|
|
|
} |
|
|
|
if (begin_[i] > SizeToInt(input_shape_[i])) { |
|
|
|
begin_[i] = input_shape_[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
ClipBegin(); |
|
|
|
} else { |
|
|
|
auto sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, SIZE); |
|
|
|
if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "begin|size|input size must be equal"; |
|
|
|
} |
|
|
|
for (size_t i = 0; i < begin_.size(); i++) { |
|
|
|
while (begin_[i] < 0) { |
|
|
|
begin_[i] = begin_[i] + input_shape_[i]; |
|
|
|
} |
|
|
|
if (begin_[i] > SizeToInt(input_shape_[i])) { |
|
|
|
begin_[i] = input_shape_[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
ClipBegin(); |
|
|
|
for (size_t i = 0; i < sizes.size(); ++i) { |
|
|
|
while (sizes[i] < 0) { |
|
|
|
sizes[i] = sizes[i] + input_shape_[i]; |
|
|
|
sizes[i] = sizes[i] + SizeToInt(input_shape_[i]); |
|
|
|
} |
|
|
|
strides_.emplace_back(1); |
|
|
|
end_.emplace_back(begin_[i] + sizes[i]); |
|
|
|
@@ -62,7 +48,17 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_); |
|
|
|
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); |
|
|
|
} |
|
|
|
|
|
|
|
void SliceCPUKernel::ClipBegin() { |
|
|
|
for (size_t i = 0; i < begin_.size(); i++) { |
|
|
|
if (begin_[i] < 0) { |
|
|
|
auto k = begin_[i] + SizeToInt(input_shape_[i]); |
|
|
|
begin_[i] = k < 0 ? 0 : k; |
|
|
|
} |
|
|
|
if (begin_[i] > SizeToInt(input_shape_[i])) { |
|
|
|
begin_[i] = SizeToInt(input_shape_[i]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
void SliceCPUKernel::ExpandAllMemberDims() { |
|
|
|
auto input_len = input_shape_.size(); |
|
|
|
if (input_len < 4) { |
|
|
|
@@ -178,13 +174,13 @@ void SliceCPUKernel::TransArg() { |
|
|
|
MS_LOG(EXCEPTION) << "slice stride cannot be zero"; |
|
|
|
} |
|
|
|
if (end_[i] == 0 && begin_[i] < 0) { |
|
|
|
end_[i] = end_[i] + input_shape_[i]; |
|
|
|
end_[i] = end_[i] + SizeToInt(input_shape_[i]); |
|
|
|
} |
|
|
|
while (end_[i] < 0) { |
|
|
|
end_[i] = end_[i] + input_shape_[i]; |
|
|
|
if (end_[i] < 0) { |
|
|
|
end_[i] = end_[i] + SizeToInt(input_shape_[i]) < 0 ? 0 : end_[i] + SizeToInt(input_shape_[i]); |
|
|
|
} |
|
|
|
if (end_[i] > SizeToInt(input_shape_[i])) { |
|
|
|
end_[i] = input_shape_[i]; |
|
|
|
end_[i] = SizeToInt(input_shape_[i]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|