Merge pull request !7345 from yao_yf/reshape_redistribution_all_scene_support_add
@@ -770,7 +770,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
   }
 }
 
-bool FindReshape(const CNodePtr &cnode) {
+bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache) {
   if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
     return false;
   }
@@ -780,7 +780,16 @@ bool FindReshape(const CNodePtr &cnode) {
   ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
   PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
   MS_EXCEPTION_IF_NULL(prim);
-  return (prim->name() == RESHAPE);
+  if (prim->name() == RESHAPE) {
+    auto operator_info = cnode->user_data<OperatorInfo>();
+    std::string op_info_name = operator_info->name();
+    if (op_cache->find(op_info_name) != op_cache->end()) {
+      return false;
+    }
+    op_cache->insert(op_info_name);
+    return true;
+  }
+  return false;
 }
 
 // find previous node, then obtain its strategy_cost_ vector to get its layout vector.
@@ -871,9 +880,10 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator
 }
 
 void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
+  std::unordered_set<std::string> op_cache;
   for (auto node : all_nodes) {
     auto cnode = node->cast<CNodePtr>();
-    if (!FindReshape(cnode)) {
+    if (!FindReshape(cnode, &op_cache)) {
       continue;
     }
     MS_ASSERT(cnode->inputs().size() == 3);
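FindReshape now takes an `std::unordered_set<std::string>` owned by ReshapeCostCompute, so a Reshape operator (keyed by its OperatorInfo name) is only picked up the first time it is seen in a pass. A minimal sketch of that caching pattern, with plain strings standing in for MindSpore's CNodePtr/OperatorInfo types (illustrative only, not the real API):

```cpp
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Returns true only the first time a given operator name is seen; the caller
// owns the cache and passes it by pointer, mirroring FindReshape's signature.
bool SeenFirstTime(const std::string &op_info_name, std::unordered_set<std::string> *op_cache) {
  if (op_cache->find(op_info_name) != op_cache->end()) {
    return false;  // already handled once, skip it this pass
  }
  op_cache->insert(op_info_name);
  return true;
}

int main() {
  std::unordered_set<std::string> op_cache;
  std::vector<std::string> ops = {"Reshape-op0", "Reshape-op1", "Reshape-op0"};
  for (const auto &name : ops) {
    std::cout << name << (SeenFirstTime(name, &op_cache) ? ": compute cost" : ": cached, skip") << "\n";
  }
  return 0;
}
```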
@@ -36,11 +36,14 @@ std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::UnifyDeviceArrange
   while (!is_unified) {
     std::shared_ptr<ReshapeLayoutTransfer> temp_layout_ptr = out_layout_ptr->ExtendFromTensorShapeByTo();
     if (temp_layout_ptr == nullptr) {
-      return nullptr;
+      out_layout_ptr->SetExpandAble(false);
+      return out_layout_ptr;
     }
     out_layout_ptr = temp_layout_ptr->ExtendToTensorShapeByFrom();
     if (out_layout_ptr == nullptr) {
-      return nullptr;
+      std::shared_ptr<ReshapeLayoutTransfer> layout_ptr = std::make_shared<ReshapeLayoutTransfer>(*this);
+      layout_ptr->SetExpandAble(false);
+      return layout_ptr;
     }
     is_unified = out_layout_ptr->IsSameTensorShape();
   }
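Where this routine previously returned nullptr when a tensor-shape extension step failed, it now returns a layout flagged with SetExpandAble(false), so the caller still gets an object it can branch on (the repeat-layout path in TensorRedistribution further down is presumably one such fallback). A minimal sketch of that pattern, using a simplified stand-in type rather than the real ReshapeLayoutTransfer:

```cpp
#include <iostream>
#include <memory>

// Simplified stand-in for ReshapeLayoutTransfer: only the expand-able flag.
struct LayoutStub {
  bool expand_able = true;
  void SetExpandAble(bool flag) { expand_able = flag; }
  bool ExpandAble() const { return expand_able; }
};

// Stand-in for ExtendFromTensorShapeByTo(); here it always "fails".
std::shared_ptr<LayoutStub> TryExtend(const std::shared_ptr<LayoutStub> &) { return nullptr; }

// Instead of returning nullptr on failure, return the layout flagged as
// non-expandable so the caller can choose an alternative path.
std::shared_ptr<LayoutStub> Unify(std::shared_ptr<LayoutStub> out_layout_ptr) {
  auto temp_layout_ptr = TryExtend(out_layout_ptr);
  if (temp_layout_ptr == nullptr) {
    out_layout_ptr->SetExpandAble(false);
    return out_layout_ptr;
  }
  return temp_layout_ptr;
}

int main() {
  auto result = Unify(std::make_shared<LayoutStub>());
  std::cout << (result->ExpandAble() ? "unified" : "not expandable, caller falls back") << "\n";
  return 0;
}
```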
@@ -58,7 +58,11 @@ Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tens
     MS_LOG(DEBUG) << "standard tensor layout " << this->StandardToString();
     return Status::SUCCESS;
   } else {
-    MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString();
+    if (layout_transfer_) {
+      MS_LOG(WARNING) << "invalid origin tensor layout " << this->OriginToString();
+    } else {
+      MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString();
+    }
     return Status::FAILED;
   }
 }
@@ -90,7 +94,11 @@ bool TensorLayout::IsValidTensorLayout() const {
     return false;
   }
   if (!TensorShapeDimensionIsDividedBySplitDeviceDimension()) {
-    MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
+    if (layout_transfer_) {
+      MS_LOG(WARNING) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
+    } else {
+      MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
+    }
     return false;
   }
   return true;
@@ -214,6 +222,7 @@ std::shared_ptr<TensorLayout> TensorLayout::ExpandTensorShapeWithoutExtendDevice
     return nullptr;
   }
   TensorLayout tensor_layout_new;
+  tensor_layout_new.set_layout_transfer(true);
   Status status = tensor_layout_new.Init(device_arrangement_, *tensor_map_new_ptr, expanded_shape);
   if (status != Status::SUCCESS) {
     return nullptr;
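The new layout_transfer_ flag ties these hunks together: the expanded layout built here is marked with set_layout_transfer(true) before Init, and the validation paths above then log WARNING instead of ERROR, since a failed expansion during layout transfer is treated as recoverable rather than fatal. A compilable sketch of that severity gating, with MS_LOG replaced by a trivial stand-in:

```cpp
#include <iostream>

// Stand-in for MS_LOG: prints the severity name followed by the message.
#define LOG(level, msg) (std::cout << #level << ": " << (msg) << "\n")

class LayoutStub {
 public:
  void set_layout_transfer(bool flag) { layout_transfer_ = flag; }
  bool layout_transfer() const { return layout_transfer_; }

  // Mirrors the gating above: the same failure logs WARNING when the layout
  // was created as part of a layout-transfer attempt, ERROR otherwise.
  bool Validate(bool divisible) const {
    if (!divisible) {
      if (layout_transfer_) {
        LOG(WARNING, "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!");
      } else {
        LOG(ERROR, "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!");
      }
      return false;
    }
    return true;
  }

 private:
  bool layout_transfer_ = false;
};

int main() {
  LayoutStub normal, transfer;
  transfer.set_layout_transfer(true);  // as done before Init in the hunk above
  normal.Validate(false);    // logs ERROR
  transfer.Validate(false);  // logs WARNING
  return 0;
}
```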
@@ -391,9 +400,9 @@ TensorLayout TensorLayout::SqueezeShape() const {
 }
 
 TensorLayout TensorLayout::TransferRepeatLayout() const {
-  Shape dev_mat(device_arrangement_.array());
-  Shape tensor_map(tensor_map_.GetDimSize(), -1);
-  Shape tensor_shape(tensor_shape_.array());
+  Shape dev_mat(device_arrangement_origin_.array());
+  Shape tensor_map(tensor_map_origin_.GetDimSize(), -1);
+  Shape tensor_shape(tensor_shape_origin_.array());
   TensorLayout repeat;
   repeat.InitFromVector(dev_mat, tensor_map, tensor_shape);
   return repeat;
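TransferRepeatLayout now builds the repeat (fully replicated) layout from the origin device arrangement, tensor map and tensor shape rather than the possibly-expanded ones; a tensor map of all -1 means no tensor dimension is mapped onto any device dimension. A small sketch of what such a repeat layout contains, assuming Shape is a vector of int64_t and using made-up example values (the real code sizes the map from tensor_map_origin_.GetDimSize(), which equals the tensor's rank):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// Helper to print a shape-like vector.
void Print(const char *label, const Shape &s) {
  std::cout << label << ":";
  for (auto v : s) std::cout << " " << v;
  std::cout << "\n";
}

int main() {
  Shape device_arrangement_origin = {2, 4};      // hypothetical 8-device arrangement
  Shape tensor_shape_origin = {32, 4, 110, 56};  // hypothetical origin tensor shape

  // Repeat layout: same device matrix and tensor shape as the origin layout,
  // but every tensor-map entry is -1, i.e. nothing is sharded anywhere.
  Shape dev_mat(device_arrangement_origin);
  Shape tensor_map(tensor_shape_origin.size(), -1);
  Shape tensor_shape(tensor_shape_origin);

  Print("dev_mat", dev_mat);
  Print("tensor_map", tensor_map);
  Print("tensor_shape", tensor_shape);
  return 0;
}
```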
@@ -46,6 +46,10 @@ class TensorLayout {
 
   void set_skip_redistribution(bool flag) { skip_redistribution_ = flag; }
 
+  bool layout_transfer() const { return layout_transfer_; }
+
+  void set_layout_transfer(bool flag) { layout_transfer_ = flag; }
+
   int32_t get_field_size() const { return field_size_; }
 
   void set_field_size(int32_t field_size) { field_size_ = field_size; }
@@ -113,14 +117,15 @@ class TensorLayout {
   int32_t GetTensorDimensionIndexByDeviceDimensionIndex(int64_t idx) const;
 
   Arrangement device_arrangement_origin_;
-  Map tensor_map_origin_;
   Arrangement tensor_shape_origin_;
   Arrangement device_arrangement_;
-  Map tensor_map_;
   Arrangement tensor_shape_;
+  Map tensor_map_;
+  Map tensor_map_origin_;
   bool skip_redistribution_ = false;
-  int32_t field_size_ = 0;
   bool uniform_split_ = true;
+  bool layout_transfer_ = false;
+  int32_t field_size_ = 0;
   Shape opt_shard_slice_shape_;
   std::string opt_shard_group_ = "";
 };
@@ -43,7 +43,7 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL
   TensorLayout from_repeat = from_origin_.TransferRepeatLayout();
   TensorLayout to_repeat = to_origin_.TransferRepeatLayout();
   MS_LOG(DEBUG) << "reshape from_repeat " << from_repeat.ToString();
-  MS_LOG(DEBUG) << "reshape to_layout " << to_repeat.ToString();
+  MS_LOG(DEBUG) << "reshape to_repeat " << to_repeat.ToString();
   MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString();
   MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString();
   MS_LOG(DEBUG) << "reshape from_ " << from_.ToString();
@@ -204,3 +204,35 @@ def test_reshape_unexpand_6():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
     net.set_auto_parallel()
     _executor.compile(net, x)
+
+def test_reshape_unexpand_7():
+    class Net(nn.Cell):
+        def __init__(self, in_channel=3, out_channel=8, axis=1, input_shape=(32, 4, 110, -1),
+                     mul_size=(32, 1, 220, 220)):
+            super().__init__()
+            mul_np = np.full(mul_size, 0.5, dtype=np.float32)
+            self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
+            self.mul = P.Mul()
+            self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
+                                  kernel_size=5, has_bias=True, weight_init='ones',
+                                  bias_init='ones', pad_mode='valid')
+            self.softmax = nn.Softmax(axis=axis)
+            self.relu = nn.ReLU()
+            self.reshape = P.Reshape()
+            self.input_shape = input_shape
+
+        def construct(self, inputs):
+            x = self.conv(inputs)
+            x = self.softmax(x)
+            x = self.relu(x)
+            x = self.mul(x, self.mul_weight)
+            x = self.reshape(x, self.input_shape)
+            return x
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    x = Tensor(np.ones([32, 3, 224, 224]), dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net()))
+    net.set_auto_parallel()
+    _executor.compile(net, x)