@@ -19,6 +19,7 @@
#include <algorithm>
#include <memory>
#include <vector>
#include <utility>

#include "ir/value.h"
#include "parallel/auto_parallel/costmodel.h"
@@ -544,5 +545,160 @@ Status ExpandDimsInfo::InferMirrorOps() {
  MS_LOG(INFO) << name_ << ": Create mirror ops success, the group name is " << group[0].name();
  return SUCCESS;
}

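// Derive the positive squeeze axes from the 'axis' value tuple and cache them in axis_.
// If the tuple is empty, every dimension of the input whose size is 1 is squeezed,
// e.g. an input shape of [32, 1, 1] yields the axes (1, 2); negative axes are normalized.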
Status SqueezeInfo::InferAxis(const ValueTuplePtr& value_tuple) {
  std::vector<int32_t> axis;
  auto axis_list = value_tuple->value();
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }
  Shape input_shape = inputs_shape_.at(0);
  size_t input_size = input_shape.size();
  // if the axis tuple is empty, squeeze every dimension whose corresponding input shape is 1.
  if (axis_list.empty()) {
    for (size_t i = 0; i < input_size; ++i) {
      if (input_shape[i] == 1) {
        axis.push_back(SizeToInt(i));
      }
    }
    axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
    return SUCCESS;
  }

  // convert negative axis values to positive.
  for (auto& dim : axis_list) {
    if (!dim->isa<Int32Imm>()) {
      MS_LOG(ERROR) << name_ << ": The type of axis is not int";
      return FAILED;
    }
    int32_t dim_value = GetValue<int32_t>(dim);
    int32_t positive_value = (dim_value < 0) ? (dim_value + SizeToInt(input_size)) : dim_value;
    axis.push_back(positive_value);
  }
  axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
  return SUCCESS;
}

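// Fetch the 'axis' attribute, normalize it via InferAxis, and write the normalized tuple
// back to attrs_[AXIS] so that later passes only see positive axes.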
Status SqueezeInfo::GetAttrs() {
  auto iter = attrs_.find(AXIS);
  if (iter == attrs_.end()) {
    MS_LOG(ERROR) << name_ << ": Can't find axis attribute.";
    return FAILED;
  }
  MS_EXCEPTION_IF_NULL(iter->second);
  auto value_tuple = iter->second->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(value_tuple);
  if (InferAxis(value_tuple) != SUCCESS) {
    return FAILED;
  }
  attrs_[AXIS] = axis_;
  return SUCCESS;
}

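// Build the replacement Squeeze operator carrying the normalized axis attribute.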
Status SqueezeInfo::InferReplaceOps(const StrategyPtr& strategy) {
  Attr attr = std::make_pair(AXIS, axis_);
  OperatorAttrs attrs = {attr};
  OperatorParams params;
  OperatorArgs args = std::make_pair(attrs, params);
  replace_op_ = {std::make_pair(SQUEEZE, args)};
  return SUCCESS;
}

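// Map each tensor dimension to a device-matrix dimension: the input keeps one entry per
// dimension, while the output drops the entries of the squeezed axes.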
Status SqueezeInfo::InferTensorMap() {
  // for example: if the shape of the input is [32, 32, 1] and the axis is (2,),
  // then the input_tensor_map is [2, 1, 0] and the output_tensor_map is [2, 1]
  std::vector<int32_t> input_tensor_map, output_tensor_map;
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }
  size_t size = inputs_shape_[0].size();
  std::vector<int32_t> axis = GetValue<const std::vector<int>>(axis_);
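  // build the identity tensor map [size - 1, ..., 1, 0] for the input; the entries of the
  // squeezed axes are simply skipped when building the output tensor map.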
  for (size_t i = 0; i < size; ++i) {
    size_t index = size - i - 1;
    auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i));
    if (iter == axis.end()) {
      output_tensor_map.push_back(SizeToInt(index));
    }
    input_tensor_map.push_back(SizeToInt(index));
  }
  inputs_tensor_map_.push_back(input_tensor_map);
  outputs_tensor_map_.push_back(output_tensor_map);
  MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map)
               << ", and the tensor map of output is " << ShapeToString(output_tensor_map);

  return SUCCESS;
}

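// Assemble the TensorInfo for the input and the output: derive the output strategy by
// dropping the squeezed axes, infer the slice shapes, and build both tensor layouts.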
Status SqueezeInfo::InferTensorInfo() {
  if (inputs_shape_.empty() || outputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty";
    return FAILED;
  }

  if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty";
    return FAILED;
  }

  Shape input_shape = inputs_shape_[0];
  Shape output_shape = outputs_shape_[0];

  // infer slice shape
  Shapes inputs_slice_shape, outputs_slice_shape;
  Strategys inputs_strategy = strategy_->GetInputDim();
  Dimensions output_strategy;
  std::vector<int32_t> axis = GetValue<const std::vector<int>>(axis_);
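  // the output strategy is the input strategy with the entries of the squeezed axes removed.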
  for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
    auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i));
    if (iter == axis.end()) {
      output_strategy.push_back(inputs_strategy[0].at(i));
    }
  }
  Strategys outputs_strategy = {output_strategy};
  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer slice shape failed";
    return FAILED;
  }

  if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) {
    MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty";
    return FAILED;
  }

  Shape input_slice_shape = inputs_slice_shape[0];
  Shape output_slice_shape = outputs_slice_shape[0];

  // infer tensor layout
  TensorLayout input_tensor_layout, output_tensor_layout;
  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed";
    return FAILED;
  }

  if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed";
    return FAILED;
  }

  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
  TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);

  inputs_tensor_info_.push_back(input_tensor_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}

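// Standard initialization flow plus the Squeeze-specific replace-op inference.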
Status SqueezeInfo::Init(const StrategyPtr& strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Init failed.";
    return FAILED;
  }

  if (InferReplaceOps(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Infer replace ops failed";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << " : Init success.";
  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore