Merge pull request !7345 from yao_yf/reshape_redistribution_all_scene_support_add
@@ -770,7 +770,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
   }
 }
-bool FindReshape(const CNodePtr &cnode) {
+bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache) {
   if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
     return false;
   }
@@ -780,7 +780,16 @@ bool FindReshape(const CNodePtr &cnode) {
   ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
   PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
   MS_EXCEPTION_IF_NULL(prim);
-  return (prim->name() == RESHAPE);
+  if (prim->name() == RESHAPE) {
+    auto operator_info = cnode->user_data<OperatorInfo>();
+    std::string op_info_name = operator_info->name();
+    if (op_cache->find(op_info_name) != op_cache->end()) {
+      return false;
+    }
+    op_cache->insert(op_info_name);
+    return true;
+  }
+  return false;
 }
 // find previous node, then obtain its strategy_cost_ vector to get its layout vector.
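Note on the hunk above: FindReshape now takes an operator-name cache owned by its caller, so each Reshape operator is costed at most once per pass. A minimal standalone sketch of that dedup pattern, with illustrative names only (AcceptOnce is not part of this change):

#include <string>
#include <unordered_set>

// Accepts a name the first time it is seen and rejects every later occurrence;
// the caller owns the cache and reuses it across all calls in one pass.
bool AcceptOnce(const std::string &op_name, std::unordered_set<std::string> *cache) {
  if (cache->find(op_name) != cache->end()) {
    return false;  // already handled, skip duplicate work
  }
  cache->insert(op_name);
  return true;
}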
@@ -871,9 +880,10 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator
 }
 void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
+  std::unordered_set<std::string> op_cache;
   for (auto node : all_nodes) {
     auto cnode = node->cast<CNodePtr>();
-    if (!FindReshape(cnode)) {
+    if (!FindReshape(cnode, &op_cache)) {
       continue;
     }
     MS_ASSERT(cnode->inputs().size() == 3);
@@ -36,11 +36,14 @@ std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::UnifyDeviceArrange
   while (!is_unified) {
     std::shared_ptr<ReshapeLayoutTransfer> temp_layout_ptr = out_layout_ptr->ExtendFromTensorShapeByTo();
     if (temp_layout_ptr == nullptr) {
-      return nullptr;
+      out_layout_ptr->SetExpandAble(false);
+      return out_layout_ptr;
     }
     out_layout_ptr = temp_layout_ptr->ExtendToTensorShapeByFrom();
     if (out_layout_ptr == nullptr) {
-      return nullptr;
+      std::shared_ptr<ReshapeLayoutTransfer> layout_ptr = std::make_shared<ReshapeLayoutTransfer>(*this);
+      layout_ptr->SetExpandAble(false);
+      return layout_ptr;
     }
     is_unified = out_layout_ptr->IsSameTensorShape();
   }
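The hunk above changes the failure contract of the layout-unification loop: when extension fails, it returns the layout it already has, flagged as not expandable, instead of nullptr, so the redistribution path can fall back to a repeat layout rather than abort. A hedged sketch of that return-flagged-result-instead-of-null pattern, using illustrative types only (LayoutSketch, TryExtendFn, and UnifyOrFlag are not MindSpore names):

#include <functional>
#include <memory>

// Illustrative stand-in for the layout class.
struct LayoutSketch {
  bool expand_able = true;
  void SetExpandAble(bool flag) { expand_able = flag; }
};

using TryExtendFn = std::function<std::shared_ptr<LayoutSketch>(const std::shared_ptr<LayoutSketch> &)>;

// On extension failure, keep the partial result and mark it non-expandable
// instead of returning nullptr, so the caller can pick a fallback strategy.
std::shared_ptr<LayoutSketch> UnifyOrFlag(const std::shared_ptr<LayoutSketch> &current, const TryExtendFn &try_extend) {
  auto extended = try_extend(current);
  if (extended == nullptr) {
    current->SetExpandAble(false);
    return current;
  }
  return extended;
}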
@@ -58,7 +58,11 @@ Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tens
     MS_LOG(DEBUG) << "standard tensor layout " << this->StandardToString();
     return Status::SUCCESS;
   } else {
-    MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString();
+    if (layout_transfer_) {
+      MS_LOG(WARNING) << "invalid origin tensor layout " << this->OriginToString();
+    } else {
+      MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString();
+    }
     return Status::FAILED;
   }
 }
@@ -90,7 +94,11 @@ bool TensorLayout::IsValidTensorLayout() const {
     return false;
   }
   if (!TensorShapeDimensionIsDividedBySplitDeviceDimension()) {
-    MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
+    if (layout_transfer_) {
+      MS_LOG(WARNING) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
+    } else {
+      MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
+    }
     return false;
   }
   return true;
@@ -214,6 +222,7 @@ std::shared_ptr<TensorLayout> TensorLayout::ExpandTensorShapeWithoutExtendDevice
     return nullptr;
   }
   TensorLayout tensor_layout_new;
+  tensor_layout_new.set_layout_transfer(true);
   Status status = tensor_layout_new.Init(device_arrangement_, *tensor_map_new_ptr, expanded_shape);
   if (status != Status::SUCCESS) {
     return nullptr;
@@ -391,9 +400,9 @@ TensorLayout TensorLayout::SqueezeShape() const {
 }
 TensorLayout TensorLayout::TransferRepeatLayout() const {
-  Shape dev_mat(device_arrangement_.array());
-  Shape tensor_map(tensor_map_.GetDimSize(), -1);
-  Shape tensor_shape(tensor_shape_.array());
+  Shape dev_mat(device_arrangement_origin_.array());
+  Shape tensor_map(tensor_map_origin_.GetDimSize(), -1);
+  Shape tensor_shape(tensor_shape_origin_.array());
   TensorLayout repeat;
   repeat.InitFromVector(dev_mat, tensor_map, tensor_shape);
   return repeat;
@@ -46,6 +46,10 @@ class TensorLayout {
   void set_skip_redistribution(bool flag) { skip_redistribution_ = flag; }
+  bool layout_transfer() const { return layout_transfer_; }
+  void set_layout_transfer(bool flag) { layout_transfer_ = flag; }
   int32_t get_field_size() const { return field_size_; }
   void set_field_size(int32_t field_size) { field_size_ = field_size; }
@@ -113,14 +117,15 @@ class TensorLayout {
   int32_t GetTensorDimensionIndexByDeviceDimensionIndex(int64_t idx) const;
   Arrangement device_arrangement_origin_;
+  Map tensor_map_origin_;
   Arrangement tensor_shape_origin_;
   Arrangement device_arrangement_;
+  Map tensor_map_;
   Arrangement tensor_shape_;
-  Map tensor_map_;
-  Map tensor_map_origin_;
   bool skip_redistribution_ = false;
-  int32_t field_size_ = 0;
   bool uniform_split_ = true;
+  bool layout_transfer_ = false;
+  int32_t field_size_ = 0;
   Shape opt_shard_slice_shape_;
   std::string opt_shard_group_ = "";
 };
@@ -43,7 +43,7 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL
   TensorLayout from_repeat = from_origin_.TransferRepeatLayout();
   TensorLayout to_repeat = to_origin_.TransferRepeatLayout();
   MS_LOG(DEBUG) << "reshape from_repeat " << from_repeat.ToString();
-  MS_LOG(DEBUG) << "reshape to_layout " << to_repeat.ToString();
+  MS_LOG(DEBUG) << "reshape to_repeat " << to_repeat.ToString();
   MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString();
   MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString();
   MS_LOG(DEBUG) << "reshape from_ " << from_.ToString();
@@ -204,3 +204,35 @@ def test_reshape_unexpand_6():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
     net.set_auto_parallel()
     _executor.compile(net, x)
+
+def test_reshape_unexpand_7():
+    class Net(nn.Cell):
+        def __init__(self, in_channel=3, out_channel=8, axis=1, input_shape=(32, 4, 110, -1),
+                     mul_size=(32, 1, 220, 220)):
+            super().__init__()
+            mul_np = np.full(mul_size, 0.5, dtype=np.float32)
+            self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
+            self.mul = P.Mul()
+            self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
+                                  kernel_size=5, has_bias=True, weight_init='ones',
+                                  bias_init='ones', pad_mode='valid')
+            self.softmax = nn.Softmax(axis=axis)
+            self.relu = nn.ReLU()
+            self.reshape = P.Reshape()
+            self.input_shape = input_shape
+
+        def construct(self, inputs):
+            x = self.conv(inputs)
+            x = self.softmax(x)
+            x = self.relu(x)
+            x = self.mul(x, self.mul_weight)
+            x = self.reshape(x, self.input_shape)
+            return x
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    x = Tensor(np.ones([32, 3, 224, 224]), dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net()))
+    net.set_auto_parallel()
+    _executor.compile(net, x)