@@ -45,7 +45,6 @@ using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
 using mindspore::abstract::AbstractAttribute;
 using mindspore::abstract::AbstractBase;
-using mindspore::abstract::AbstractClass;
 using mindspore::abstract::AbstractDictionary;
 using mindspore::abstract::AbstractDictionaryPtr;
 using mindspore::abstract::AbstractEllipsis;
@@ -77,20 +76,12 @@ void HyperMap::Init() {
 }

 HyperMap::HyperMap(bool reverse, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
-    : MetaFuncGraph("hyper_map"),
-      fn_leaf_(fn_leaf),
-      reverse_(reverse),
-      broadcast_(false),
-      nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
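+    // Nonleaf types are the sequence types HyperMap recurses into; any other type is treated as a leaf.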
| : MetaFuncGraph("hyper_map"), fn_leaf_(fn_leaf), reverse_(reverse), nonleaf_({kObjectTypeList, kObjectTypeTuple}) { | |||
| Init(); | |||
| } | |||
| HyperMap::HyperMap(const HyperMap &h) | |||
| : MetaFuncGraph("hyper_map"), | |||
| fn_leaf_(h.fn_leaf_), | |||
| reverse_(h.reverse_), | |||
| broadcast_(h.broadcast_), | |||
| nonleaf_(h.nonleaf_) { | |||
| : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), reverse_(h.reverse_), nonleaf_(h.nonleaf_) { | |||
| Init(); | |||
| } | |||
@@ -247,61 +238,21 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGrap
   return func_graph->NewCNodeInOrder(inputs);
 }

-AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
-                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
-  MS_EXCEPTION_IF_NULL(type);
-  MS_EXCEPTION_IF_NULL(func_graph);
-  std::size_t attrSize = type->GetAttributes().size();
-  constexpr size_t kPrimAndTypeLen = 2;
-  std::vector<AnfNodePtr> inputs;
-  inputs.reserve(attrSize + kPrimAndTypeLen);
-  inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
-  inputs.push_back(NewValueNode(type));
-  // cannot use shared_from_base() also known as this, as it will make a reference cycle on
-  // hypermap and graph generated, it will cause memory leak.
-  auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
-  for (std::size_t i = 0; i < attrSize; i++) {
-    MS_LOG(DEBUG) << "FullMakeClass for the " << i << "th element of the target, reverse_: " << reverse_;
-    std::vector<AnfNodePtr> inputs2;
-    inputs2.push_back(fn_rec);
-    if (fn_arg) {
-      inputs2.push_back(fn_arg);
-    }
-    size_t size = arg_map.size();
-    for (size_t j = 0; j < size; j++) {
-      size_t pos = (reverse_ ? (size - 1 - j) : j);
-      auto &item = arg_map[pos];
-      inputs2.push_back(
-        func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(SizeToLong(pos))}));
-    }
-    auto call_node = func_graph->NewCNodeInOrder(inputs2);
-    if (reverse_) {
-      inputs.insert(inputs.begin() + kPrimAndTypeLen, call_node);
-    } else {
-      inputs.emplace_back(call_node);
-    }
-  }
-  return func_graph->NewCNodeInOrder(inputs);
-}
-
 AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
-  bool found = false;
+  bool is_leaf = false;
   TypeId id = kObjectTypeEnd;
   std::pair<AnfNodePtr, TypePtr> pair;
   for (auto &item : arg_map) {
     pair = item;
     id = item.second->type_id();
-    if (nonleaf_.count(id)) {
-      found = true;
+    // The graph building reaches the leaf situation when there exists a type that cannot be divided any further.
+    if (!nonleaf_.count(id)) {
+      is_leaf = true;
       break;
     }
   }

-  if (found) {
+  if (!is_leaf) {
     // In a nonleaf situation, all arguments must have the same generic.
     bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
       if (item.first != pair.first) {
@@ -328,7 +279,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_a
       ++idx;
       oss << "The type of the " << str_index << " argument in HyperMap is " << item.second->ToString() << ".\n";
     }
-    MS_LOG(EXCEPTION) << "The types of arguments in HyperMap must be consistent, "
+    MS_LOG(EXCEPTION) << "In a nonleaf situation, the types of arguments in HyperMap must be consistent, "
                       << "but the types of arguments are inconsistent.\n"
                       << oss.str();
   }
@@ -343,36 +294,11 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_a
       auto type = std::static_pointer_cast<Tuple>(pair.second);
       return FullMake(type, func_graph, fn_arg, arg_map);
     }
-    case kObjectTypeClass: {
-      auto type = std::static_pointer_cast<Class>(pair.second);
-      return FullMake(type, func_graph, fn_arg, arg_map);
-    }
     default:
       return FullMake(func_graph, fn_arg, arg_map);
   }
 }

-ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) {
-  TypePtr type_tensor = std::make_shared<TensorType>();
-  bool flag = std::any_of(
-    args_spec_list.begin(), args_spec_list.end(),
-    [type_tensor](const std::pair<AnfNodePtr, TypePtr> &item) { return IsSubType(item.second, type_tensor); });
-  if (flag && broadcast_) {
-    ArgsPairList ret;
-    for (auto &item : args_spec_list) {
-      if (!IsSubType(item.second, type_tensor)) {
-        TypePtr type_tensor_ele = std::make_shared<TensorType>(item.second);
-        ret.push_back(std::make_pair(func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimScalarToArray), item.first}),
-                                     type_tensor_ele));
-      } else {
-        ret.push_back(std::make_pair(item.first, item.second));
-      }
-    }
-    return ret;
-  }
-  return args_spec_list;
-}
-
 FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
   FuncGraphPtr ptr_graph = std::make_shared<FuncGraph>();
   ptr_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
@@ -382,7 +308,6 @@ FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
   AnfNodePtr ptrFnArg = nullptr;
   std::size_t i = 0;
   ArgsPairList argmap;
-  ArgsPairList argmap2;
   if (fn_leaf_ == nullptr) {
     ptrFnArg = ptr_graph->add_parameter();
     i = 1;
@@ -393,8 +318,7 @@ FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
     argmap.push_back(std::make_pair(ptr_graph->add_parameter(), args_spec_list[i]));
   }

-  argmap2 = Harmonize(ptr_graph, argmap);
-  ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap2));
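+  // Build the graph output by mapping directly over the collected (parameter, type) pairs.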
+  ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap));
   return ptr_graph;
 }
@@ -56,7 +56,6 @@ class HyperMap : public MetaFuncGraph {
     if (this != &h) {
       fn_leaf_ = h.fn_leaf_;
       reverse_ = h.reverse_;
-      broadcast_ = h.broadcast_;
       nonleaf_ = h.nonleaf_;
       if (fn_leaf_) {
         name_ = "hyper_map[" + fn_leaf_->name() + "]";
@@ -77,15 +76,11 @@ class HyperMap : public MetaFuncGraph {
                       const ArgsPairList &arg_map);
   AnfNodePtr FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
                       const ArgsPairList &arg_map);
-  AnfNodePtr FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
-                      const ArgsPairList &arg_map);
   AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map);
-  ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list);
   std::pair<std::string, std::string> GetHyperMapInputIndex(size_t num);

   MultitypeFuncGraphPtr fn_leaf_;
   bool reverse_;
-  bool broadcast_;
   std::set<TypeId> nonleaf_;
 };
 using HyperMapPtr = std::shared_ptr<HyperMap>;
@@ -206,48 +206,6 @@ AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGrap
   return func_graph->NewCNodeInOrder(inputs);
 }

-AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
-                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
-  MS_EXCEPTION_IF_NULL(type);
-  MS_EXCEPTION_IF_NULL(func_graph);
-  size_t attrSize = type->GetAttributes().size();
-  constexpr size_t kPrimAndTypeLen = 2;
-  std::vector<AnfNodePtr> inputs;
-  inputs.reserve(attrSize + kPrimAndTypeLen);
-  inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
-  inputs.push_back(NewValueNode(type));
-  for (size_t i = 0; i < attrSize; i++) {
-    MS_LOG(DEBUG) << "FullMakeClass for the " << i << "th element of the inputs, reverse_: " << reverse_ << ".";
-    auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
-    auto fn = NewValueNode(ptrGraph);
-    std::vector<AnfNodePtr> inputs2;
-    inputs2.push_back(fn);
-    if (fn_arg != nullptr) {
-      inputs2.push_back(fn_arg);
-    }
-    size_t size = arg_pairs.size();
-    for (size_t j = 0; j < size; j++) {
-      size_t pos = (reverse_ ? (size - 1 - j) : j);
-      auto &item = arg_pairs[pos];
-      inputs2.push_back(
-        func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(SizeToLong(pos))}));
-    }
-    auto call_node = func_graph->NewCNodeInOrder(inputs2);
-    if (reverse_) {
-      constexpr auto kCallNodePosition = 2;
-      (void)inputs.insert(inputs.begin() + kCallNodePosition, call_node);
-    } else {
-      inputs.emplace_back(call_node);
-    }
-  }
-  return func_graph->NewCNodeInOrder(inputs);
-}
-
 AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
   if (arg_pairs.empty()) {
     MS_EXCEPTION(TypeError) << "The Map operator must have at least two arguments. But the size of arguments is "
@@ -308,13 +266,8 @@ AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, c
       auto type = std::static_pointer_cast<Tuple>(pair.second);
       return FullMakeTuple(type, func_graph, fn_arg, arg_pairs);
     }
-    case kObjectTypeClass: {
-      auto type = std::static_pointer_cast<Class>(pair.second);
-      return FullMakeClass(type, func_graph, fn_arg, arg_pairs);
-    }
     default:
-      MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple and class, but got " << pair.second->ToString()
-                        << ".";
+      MS_LOG(EXCEPTION) << "Map can only be applied to list and tuple, but got " << pair.second->ToString() << ".";
   }
 }
@@ -39,7 +39,7 @@ class Map : public MetaFuncGraph {
       fn_leaf_(fn_leaf),
       reverse_(reverse),
       broadcast_(false),
-      nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
+      nonleaf_({kObjectTypeList, kObjectTypeTuple}) {
     Init();
   }

   Map(const Map &map)
@@ -75,8 +75,6 @@ class Map : public MetaFuncGraph {
                            const ArgsPairList &arg_pairs);
   AnfNodePtr FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
                            const ArgsPairList &arg_pairs);
-  AnfNodePtr FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
-                           const ArgsPairList &arg_pairs);
   AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs);
   std::pair<std::string, std::string> GetMapInputIndex(size_t num);
   void Init() {
@@ -612,6 +612,9 @@ class HyperMap(HyperMap_):
           If `ops` is `None`, the first input is the operation, and the others are inputs.

     Note:
         Except for the operation input, the number of inputs should be equal to the number of inputs to `ops`.

+    Outputs:
+        Sequence or nested sequence, the output sequence after applying the function,
+        e.g. `operation(args[0][i], args[1][i])`.
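+
+    Examples:
+        >>> # Apply a registered leaf function to every tensor in a tuple; the result
+        >>> # keeps the same nested structure (sketch based on the tests in this change).
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> from mindspore.ops import composite as C
+        >>> from mindspore.ops import operations as P
+        >>> square = C.MultitypeFuncGraph("square")
+        >>> @square.register("Tensor")
+        ... def square_tensor(x):
+        ...     return P.Square()(x)
+        >>> common_map = C.HyperMap()
+        >>> output = common_map(square, (Tensor(np.array([1, 2]), mstype.float32),
+        ...                              Tensor(np.array([3, 4]), mstype.float32)))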
@@ -0,0 +1,90 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import pytest
+import numpy as np
+from mindspore import context, nn, Tensor
+from mindspore import dtype as mstype
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE)
+
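+# Unary leaf function: HyperMap applies it to each Tensor leaf of the inputs.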
+single_element_fg = C.MultitypeFuncGraph("single_element_fg")
+
+
+@single_element_fg.register("Tensor")
+def single_element_fg_for_tensor(x):
+    return P.Square()(x)
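+
+
+# Binary leaf function: tiles a Tensor leaf by the matching tuple of multiples.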
+double_elements_fg = C.MultitypeFuncGraph("double_elements_fg")
+
+
+@double_elements_fg.register("Tensor", "Tuple")
+def double_elements_fg_for_tensor_tuple(x, y):
+    return P.Tile()(x, y)
+
+
+class HyperMapNet(nn.Cell):
+    def __init__(self, fg):
+        super(HyperMapNet, self).__init__()
+        self.common_map = C.HyperMap()
+        self.fg = fg
+
+    def construct(self, nest_tensor_list):
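+        # Unpack the nested inputs so HyperMap zips the sequences leaf by leaf.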
+        output = self.common_map(self.fg, *nest_tensor_list)
+        return output
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
+def test_single_element_hypermap():
+    """
+    Feature: HyperMap
+    Description: Test whether the HyperMap with single tensor input can run successfully.
+    Expectation: success.
+    """
+    x = (Tensor(np.array([1, 2, 3]), mstype.float32), Tensor(np.array([4, 5, 6]), mstype.float32))
+    common_map = HyperMapNet(single_element_fg)
+    output = common_map((x,))
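+    # Squaring the leaves: [1, 2, 3] -> [1, 4, 9] and [4, 5, 6] -> [16, 25, 36].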
+    expect_output_1 = np.array([1.0, 4.0, 9.0])
+    expect_output_2 = np.array([16.0, 25.0, 36.0])
+    assert isinstance(output, tuple)
+    assert len(output) == 2
+    assert isinstance(output[0], Tensor)
+    assert isinstance(output[1], Tensor)
+    assert np.allclose(output[0].asnumpy(), expect_output_1)
+    assert np.allclose(output[1].asnumpy(), expect_output_2)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
+def test_double_elements_hypermap():
+    """
+    Feature: HyperMap
+    Description: Test whether the HyperMap with tensor and tuple inputs can run successfully.
+    Expectation: success.
+    """
+    x = (Tensor(np.array([1, 2, 3]), mstype.float32), Tensor(np.array([4, 5, 6]), mstype.float32))
+    y = ((1, 2), (2, 1))
+    common_map = HyperMapNet(double_elements_fg)
+    output = common_map((x, y))
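+    # Tile([1, 2, 3], (1, 2)) repeats the vector side by side; Tile([4, 5, 6], (2, 1)) stacks two copies.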
+    expect_output_1 = np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0])
+    expect_output_2 = np.array([[4.0, 5.0, 6.0], [4.0, 5.0, 6.0]])
+    assert isinstance(output, tuple)
+    assert len(output) == 2
+    assert isinstance(output[0], Tensor)
+    assert isinstance(output[1], Tensor)
+    assert np.allclose(output[0].asnumpy(), expect_output_1)
+    assert np.allclose(output[1].asnumpy(), expect_output_2)
@@ -38,7 +38,7 @@ def test_hypermap_noleaf_tuple_list_mix():
     """
     tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
     tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
-    with pytest.raises(Exception, match="The types of arguments in HyperMap must be consistent"):
+    with pytest.raises(Exception, match="the types of arguments in HyperMap must be consistent"):
         main_noleaf((tensor1, 1), [tensor2, 2])
@@ -74,7 +74,7 @@ def test_hypermap_noleaf_list_tuple():
     """
     tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
     tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
-    with pytest.raises(Exception, match="The types of arguments in HyperMap must be consistent"):
+    with pytest.raises(Exception, match="the types of arguments in HyperMap must be consistent"):
         main_noleaf([tensor1], (tensor2, tensor2))
@@ -87,14 +87,14 @@ def test_tuple_slice_stop_index():
     class TupleSliceNet(Cell):
         def __init__(self):
             super(TupleSliceNet, self).__init__()
-            self.addN = P.AddN()
+            self.addn = P.AddN()
             self.index_0 = Tensor(3)

         def construct(self, tensor_tuple):
             tensor_tuple_slice0 = tensor_tuple[:]
             tensor_tuple_slice1 = tensor_tuple[self.index_0:"str"]  # slice should be Scalar or None, rather than string
-            sum0 = self.addN(tensor_tuple_slice0)
-            sum1 = self.addN(tensor_tuple_slice1)
+            sum0 = self.addn(tensor_tuple_slice0)
+            sum1 = self.addn(tensor_tuple_slice1)
             ret = sum0 + sum1
             return ret
@@ -120,7 +120,7 @@ def test_tuple_slice_start_index():
     class TupleSliceNet(Cell):
         def __init__(self):
             super(TupleSliceNet, self).__init__()
-            self.addN = P.AddN()
+            self.addn = P.AddN()
             self.index_0 = Tensor(3)
             self.index_1 = Tensor([5])
             self.index_3 = Tensor([True])
@@ -130,10 +130,10 @@ def test_tuple_slice_start_index():
             tensor_tuple_slice1 = tensor_tuple["str":self.index_0]
             tensor_tuple_slice2 = tensor_tuple[self.index_3:]
             tensor_tuple_slice3 = tensor_tuple[2:self.index_1:]
-            sum0 = self.addN(tensor_tuple_slice0)
-            sum1 = self.addN(tensor_tuple_slice1)
-            sum2 = self.addN(tensor_tuple_slice2)
-            sum3 = self.addN(tensor_tuple_slice3)
+            sum0 = self.addn(tensor_tuple_slice0)
+            sum1 = self.addn(tensor_tuple_slice1)
+            sum2 = self.addn(tensor_tuple_slice2)
+            sum3 = self.addn(tensor_tuple_slice3)
             ret = sum0 + sum1 + sum2 + sum3
             return ret
@@ -159,7 +159,7 @@ def test_tuple_slice_step():
     class TupleSliceNet(Cell):
         def __init__(self):
             super(TupleSliceNet, self).__init__()
-            self.addN = P.AddN()
+            self.addn = P.AddN()
             self.index_0 = Tensor(3)
             self.index_1 = Tensor([5])
             self.index_3 = Tensor([True])
@@ -169,10 +169,10 @@ def test_tuple_slice_step():
             tensor_tuple_slice1 = tensor_tuple[:self.index_0]
             tensor_tuple_slice2 = tensor_tuple[self.index_3:]
             tensor_tuple_slice3 = tensor_tuple[2:self.index_1:0]
-            sum0 = self.addN(tensor_tuple_slice0)
-            sum1 = self.addN(tensor_tuple_slice1)
-            sum2 = self.addN(tensor_tuple_slice2)
-            sum3 = self.addN(tensor_tuple_slice3)
+            sum0 = self.addn(tensor_tuple_slice0)
+            sum1 = self.addn(tensor_tuple_slice1)
+            sum2 = self.addn(tensor_tuple_slice2)
+            sum3 = self.addn(tensor_tuple_slice3)
             ret = sum0 + sum1 + sum2 + sum3
             return ret