From b0921c15e9d0c89b3eaf7be3fd9c1e2d03be9cb5 Mon Sep 17 00:00:00 2001
From: yao_yf
Date: Mon, 11 May 2020 21:47:22 +0800
Subject: [PATCH] xreshape tensor_redistribution bug fix

---
 .../parallel/auto_parallel/operator_costmodel.h   |  4 ++--
 .../ccsrc/parallel/ops_info/gather_v2_p_info.h    |  5 ++++-
 mindspore/ccsrc/parallel/ops_info/reshape_info.cc | 13 +++++++------
 mindspore/ccsrc/parallel/ops_info/reshape_info.h  |  2 ++
 .../tensor_layout/reshape_layout_transfer.cc      | 12 ++++++++++++
 5 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
index 7d966845d1..4b48282612 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
@@ -616,8 +616,8 @@ using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>;
 
 class GatherV2PCost : public OperatorCost {
  public:
-  explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  GatherV2PCost() : OperatorCost(true) {}
+  explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related), axis_(0) {}
+  GatherV2PCost() : OperatorCost(true), axis_(0) {}
   ~GatherV2PCost() override = default;
 
   double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
index 03cbd70e8d..62553b5588 100644
--- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
@@ -33,7 +33,10 @@ class GatherV2PInfo : public OperatorInfo {
  public:
   GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()),
+        axis_(0),
+        bias_(0),
+        slice_size_(0) {}
   ~GatherV2PInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/parallel/ops_info/reshape_info.cc
index 40b8b79c4c..f663eaa6a2 100644
--- a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/reshape_info.cc
@@ -456,8 +456,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector
diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.h b/mindspore/ccsrc/parallel/ops_info/reshape_info.h
--- a/mindspore/ccsrc/parallel/ops_info/reshape_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/reshape_info.h
@@ ... @@
       : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>(false)),
         dev_num_(0),
+        pre_operator_index_(0),
+        next_operator_index_(0),
         input_layout_set_flag_(false),
         output_layout_set_flag_(false) {}
   ~ReshapeInfo() override = default;
diff --git a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc
index f6c90e9d46..4c66befd78 100644
--- a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc
+++ b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc
@@ -30,9 +30,18 @@ Status ReshapeLayoutTransfer::CheckValidTransfer() {
 std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::UnifyDeviceArrangementAndTensorShape() const {
   bool is_unified = IsSameTensorShape();
   std::shared_ptr<ReshapeLayoutTransfer> out_layout_ptr = std::make_shared<ReshapeLayoutTransfer>(*this);
+  if (out_layout_ptr == nullptr) {
+    return nullptr;
+  }
   while (!is_unified) {
     std::shared_ptr<ReshapeLayoutTransfer> temp_layout_ptr = out_layout_ptr->ExtendFromTensorShapeByTo();
+    if (temp_layout_ptr == nullptr) {
+      return nullptr;
+    }
     out_layout_ptr = temp_layout_ptr->ExtendToTensorShapeByFrom();
+    if (out_layout_ptr == nullptr) {
+      return nullptr;
+    }
     is_unified = out_layout_ptr->IsSameTensorShape();
   }
   return out_layout_ptr;
@@ -91,6 +100,9 @@ std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::ExtendToTensorShap
   }
   std::shared_ptr<ReshapeLayoutTransfer> exchanged_out =
     exchanged_from_and_to_ptr->ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr);
+  if (exchanged_out == nullptr) {
+    return nullptr;
+  }
   return exchanged_out->ExchangeFromAndTo();
 }