| @@ -616,8 +616,8 @@ using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>; | |||||
| class GatherV2PCost : public OperatorCost { | class GatherV2PCost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| GatherV2PCost() : OperatorCost(true) {} | |||||
| explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related), axis_(0) {} | |||||
| GatherV2PCost() : OperatorCost(true), axis_(0) {} | |||||
| ~GatherV2PCost() override = default; | ~GatherV2PCost() override = default; | ||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| @@ -33,7 +33,10 @@ class GatherV2PInfo : public OperatorInfo { | |||||
| public: | public: | ||||
| GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()) {} | |||||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()), | |||||
| axis_(0), | |||||
| bias_(0), | |||||
| slice_size_(0) {} | |||||
| ~GatherV2PInfo() override = default; | ~GatherV2PInfo() override = default; | ||||
| Status Init(const StrategyPtr &strategy) override; | Status Init(const StrategyPtr &strategy) override; | ||||
| Status InitForCostModel(const StrategyPtr &strategy) override; | Status InitForCostModel(const StrategyPtr &strategy) override; | ||||
| @@ -456,8 +456,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index]; | TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index]; | ||||
| TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout(); | |||||
| SetInputLayout(pre_out_tensor_layout); | |||||
| SetInputLayout(pre_out_tensor_info.tensor_layout()); | |||||
| // infer pre_node output strategy from output_layout. | // infer pre_node output strategy from output_layout. | ||||
| Dimensions stra = pre_out_tensor_info.InferStrategy(); | Dimensions stra = pre_out_tensor_info.InferStrategy(); | ||||
| if (stra.empty()) { | if (stra.empty()) { | ||||
| @@ -481,15 +480,17 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index]; | TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index]; | ||||
| TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout(); | |||||
| SetOutputLayout(next_in_tensor_layout); | |||||
| SetOutputLayout(next_in_tensor_info.tensor_layout()); | |||||
| if (Init(nullptr) == FAILED) { | if (Init(nullptr) == FAILED) { | ||||
| MS_LOG(ERROR) << "Failure:operator reshape init failed"; | |||||
| return FAILED; | |||||
| MS_LOG(DEBUG) << "Failure:operator reshape init failed"; | |||||
| continue; | |||||
| } | } | ||||
| SetCostForReshape(reshape_stra); | SetCostForReshape(reshape_stra); | ||||
| } | } | ||||
| } | } | ||||
| if (strategy_cost_.empty()) { | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace parallel | } // namespace parallel | ||||
| @@ -38,6 +38,8 @@ class ReshapeInfo : public OperatorInfo { | |||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>(false)), | : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>(false)), | ||||
| dev_num_(0), | dev_num_(0), | ||||
| pre_operator_index_(0), | |||||
| next_operator_index_(0), | |||||
| input_layout_set_flag_(false), | input_layout_set_flag_(false), | ||||
| output_layout_set_flag_(false) {} | output_layout_set_flag_(false) {} | ||||
| ~ReshapeInfo() override = default; | ~ReshapeInfo() override = default; | ||||
| @@ -30,9 +30,18 @@ Status ReshapeLayoutTransfer::CheckValidTransfer() { | |||||
| std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::UnifyDeviceArrangementAndTensorShape() const { | std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::UnifyDeviceArrangementAndTensorShape() const { | ||||
| bool is_unified = IsSameTensorShape(); | bool is_unified = IsSameTensorShape(); | ||||
| std::shared_ptr<ReshapeLayoutTransfer> out_layout_ptr = std::make_shared<ReshapeLayoutTransfer>(*this); | std::shared_ptr<ReshapeLayoutTransfer> out_layout_ptr = std::make_shared<ReshapeLayoutTransfer>(*this); | ||||
| if (out_layout_ptr == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| while (!is_unified) { | while (!is_unified) { | ||||
| std::shared_ptr<ReshapeLayoutTransfer> temp_layout_ptr = out_layout_ptr->ExtendFromTensorShapeByTo(); | std::shared_ptr<ReshapeLayoutTransfer> temp_layout_ptr = out_layout_ptr->ExtendFromTensorShapeByTo(); | ||||
| if (temp_layout_ptr == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| out_layout_ptr = temp_layout_ptr->ExtendToTensorShapeByFrom(); | out_layout_ptr = temp_layout_ptr->ExtendToTensorShapeByFrom(); | ||||
| if (out_layout_ptr == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| is_unified = out_layout_ptr->IsSameTensorShape(); | is_unified = out_layout_ptr->IsSameTensorShape(); | ||||
| } | } | ||||
| return out_layout_ptr; | return out_layout_ptr; | ||||
| @@ -91,6 +100,9 @@ std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::ExtendToTensorShap | |||||
| } | } | ||||
| std::shared_ptr<ReshapeLayoutTransfer> exchanged_out = | std::shared_ptr<ReshapeLayoutTransfer> exchanged_out = | ||||
| exchanged_from_and_to_ptr->ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr); | exchanged_from_and_to_ptr->ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr); | ||||
| if (exchanged_out == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| return exchanged_out->ExchangeFromAndTo(); | return exchanged_out->ExchangeFromAndTo(); | ||||
| } | } | ||||