Merge pull request !31159 from chenfei_mindspore/master-developr1.7
| @@ -683,6 +683,7 @@ void DumpSubgraph(const OrderedMap<FuncGraphPtr, std::shared_ptr<SubGraphIRInfo> | |||
| } | |||
| void SetDumpConfigByString(const std::string &str, DumpConfig *dump_config) { | |||
| MS_LOG(INFO) << "Set dump config:" << str; | |||
| static mindspore::HashMap<std::string, enum LocDumpMode> dump_level_map = { | |||
| {kDumpConfigLineLevel0, kOff}, {kDumpConfigLineLevel1, kTopStack}, {kDumpConfigLineLevel2, kWholeStack}}; | |||
| auto it = dump_level_map.find(str); | |||
| @@ -700,11 +701,64 @@ void SetDumpConfigByString(const std::string &str, DumpConfig *dump_config) { | |||
| } | |||
| } | |||
| std::shared_ptr<OrderedSet<std::string>> GetAllConfigStrings(const std::string &config_full_string) { | |||
| size_t start_pos = 0; | |||
| auto config_strings = std::make_shared<OrderedSet<std::string>>(); | |||
| // if '#' is the last char of str, the str is legal, so we use '<=' but not '<'. | |||
| while (start_pos <= config_full_string.size()) { | |||
| auto pos = config_full_string.find('#', start_pos); | |||
| if (pos == std::string::npos) { | |||
| pos = config_full_string.size(); | |||
| } | |||
| auto substr = config_full_string.substr(start_pos, pos - start_pos); | |||
| // Skip the '#' | |||
| start_pos = pos + 1; | |||
| if (substr.empty()) { | |||
| continue; | |||
| } | |||
| (void)config_strings->insert(substr); | |||
| } | |||
| return config_strings; | |||
| } | |||
| bool ConfigsAreLegal(const std::shared_ptr<OrderedSet<std::string>> &config_strings) { | |||
| // Value 'int' is used to mark config group id | |||
| HashMap<std::string, int> config_white_list = {{kDumpConfigLineLevel0, 0}, | |||
| {kDumpConfigLineLevel1, 0}, | |||
| {kDumpConfigLineLevel2, 0}, | |||
| {kDumpConfigDisableBackend, 1}, | |||
| {kDumpConfigEnablePassIR, 2}}; | |||
| // Key 'int' is config group id, value is the config. | |||
| HashMap<int, std::string> config_groups; | |||
| for (const auto &config_string : *config_strings) { | |||
| auto config_white_list_it = config_white_list.find(config_string); | |||
| if (config_white_list_it == config_white_list.end()) { | |||
| std::ostringstream buffer; | |||
| buffer << "Support configs:\n" | |||
| << "[0]: " << kDumpConfigLineLevel0 << "\n" | |||
| << "[1]: " << kDumpConfigLineLevel1 << "\n" | |||
| << "[2]: " << kDumpConfigLineLevel2 << "\n" | |||
| << "[3]: " << kDumpConfigDisableBackend << "\n" | |||
| << "[4]: " << kDumpConfigEnablePassIR; | |||
| MS_LOG(WARNING) << "Illegal dump config:\n" << config_string << "\n" << buffer.str(); | |||
| return false; | |||
| } | |||
| auto group_id = config_white_list_it->second; | |||
| // Check conflict configs. | |||
| auto config_groups_it = config_groups.find(group_id); | |||
| if (config_groups_it != config_groups.end()) { | |||
| const auto &record_config = config_groups_it->second; | |||
| MS_LOG(WARNING) << "Dump configs are conflict. Conflict configs: [" << record_config << "] and [" << config_string | |||
| << "].\n" | |||
| << "Please keep only one of them."; | |||
| return false; | |||
| } | |||
| config_groups[group_id] = config_string; | |||
| } | |||
| return true; | |||
| } | |||
| DumpConfig GetDumpConfig() { | |||
| static std::vector<HashSet<std::string>> config_white_list = { | |||
| {kDumpConfigLineLevel0, kDumpConfigLineLevel1, kDumpConfigLineLevel2}, | |||
| {kDumpConfigDisableBackend}, | |||
| {kDumpConfigEnablePassIR}}; | |||
| static DumpConfig dump_config = DumpConfig(); | |||
| static bool parsed = false; | |||
| if (parsed) { | |||
| @@ -713,9 +767,6 @@ DumpConfig GetDumpConfig() { | |||
| parsed = true; | |||
| // Start parse config. | |||
| std::string str(common::GetEnv("MS_DEV_DUMP_IR_CONFIG")); | |||
| std::vector<std::shared_ptr<HashSet<std::string>>> configs = {std::make_shared<HashSet<std::string>>(), | |||
| std::make_shared<HashSet<std::string>>(), | |||
| std::make_shared<HashSet<std::string>>()}; | |||
| auto constexpr max_string_len = 100; | |||
| if (str.size() > max_string_len) { | |||
| MS_LOG(WARNING) << "Dump ir config length exceed max length: " << max_string_len; | |||
| @@ -724,45 +775,12 @@ DumpConfig GetDumpConfig() { | |||
| if (str.empty()) { | |||
| return dump_config; | |||
| } | |||
| size_t start_pos = 0; | |||
| // if '#' is the last char of str, the str is illegal, so we use '<=' but not '<'. | |||
| while (start_pos <= str.size()) { | |||
| auto pos = str.find('#', start_pos); | |||
| if (pos == std::string::npos) { | |||
| pos = str.size(); | |||
| } | |||
| auto substr = str.substr(start_pos, pos - start_pos); | |||
| start_pos = pos + 1; | |||
| bool is_illegal_config = true; | |||
| for (size_t i = 0; i < config_white_list.size(); i++) { | |||
| if (config_white_list[i].find(substr) != config_white_list[i].end()) { | |||
| is_illegal_config = false; | |||
| (void)configs[i]->insert(substr); | |||
| if (configs[i]->size() > 1) { | |||
| std::ostringstream buffer; | |||
| (void)std::for_each(configs[i]->begin(), configs[i]->end(), [&buffer](const std::string &config) { | |||
| buffer << "\n" << config; | |||
| }); | |||
| MS_LOG(WARNING) << "Dump configs are conflict. Conflict configs: " << buffer.str() << "\n" | |||
| << "Please keep only one of them."; | |||
| return dump_config; | |||
| } | |||
| } | |||
| } | |||
| if (is_illegal_config) { | |||
| std::ostringstream buffer; | |||
| buffer << "Support configs:\n" | |||
| << "[0]: " << kDumpConfigLineLevel0 << "\n" | |||
| << "[1]: " << kDumpConfigLineLevel1 << "\n" | |||
| << "[2]: " << kDumpConfigLineLevel2 << "\n" | |||
| << "[3]: " << kDumpConfigDisableBackend << "\n" | |||
| << "[4]: " << kDumpConfigEnablePassIR; | |||
| MS_LOG(WARNING) << "Illegal dump config:\n" << substr << "\n" << buffer.str(); | |||
| return {}; | |||
| } | |||
| auto config_strings = GetAllConfigStrings(str); | |||
| if (!ConfigsAreLegal(config_strings)) { | |||
| return dump_config; | |||
| } | |||
| for (auto &config : configs) { | |||
| SetDumpConfigByString(*config->begin(), &dump_config); | |||
| for (const auto &config : *config_strings) { | |||
| SetDumpConfigByString(config, &dump_config); | |||
| } | |||
| return dump_config; | |||
| } | |||
| @@ -87,7 +87,7 @@ class GraphTupleTransform : public AnfVisitor { | |||
| GraphTupleParamTransform graph_transform_; | |||
| }; | |||
| // {,kPrimPartial, G, Tuple_Xs} | |||
| // {PrimPartial, G, Tuple_Xs} | |||
| // => | |||
| // {kPrimPartial, G, TupleGetItem{Tuple_Xs,0}, TupleGetItem{Tuple_Xs,1}, ..., TupleGetItem{Tuple_Xs,n}} | |||
| // transform partial's tuple binding args to flat inputs. | |||
| @@ -102,12 +102,12 @@ class PartialTupleArgTransform : public AnfVisitor { | |||
| auto partial = node->cast<CNodePtr>(); | |||
| const auto &partial_inputs = partial->inputs(); | |||
| const auto &fg = partial->func_graph(); | |||
| // And primitive and function value node into args. | |||
| constexpr auto kPartialFirstArgIndex = 2; | |||
| auto new_args = AnfNodePtrList(partial_inputs.begin(), partial_inputs.begin() + kPartialFirstArgIndex); | |||
| auto change = FlattenArgs(fg, partial_inputs, kPartialFirstArgIndex, &new_args); | |||
| // Put ValueNode<kPrimPartial> and ValueNode<FuncGraph> into new_inputs. | |||
| auto new_inputs = AnfNodePtrList(partial_inputs.begin(), partial_inputs.begin() + kPartialFirstArgIndex); | |||
| auto change = FlattenArgs(fg, partial_inputs, kPartialFirstArgIndex, &new_inputs); | |||
| if (change) { | |||
| auto new_partial = fg->NewCNode(new_args); | |||
| auto new_partial = fg->NewCNode(new_inputs); | |||
| new_partial->set_abstract(partial->abstract()); | |||
| return new_partial; | |||
| } | |||
| @@ -132,11 +132,11 @@ class CallTupleArgTransform : public AnfVisitor { | |||
| const auto &call_inputs = call_node->inputs(); | |||
| const auto &fg = call_node->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| // Add function value node into args. | |||
| auto new_args = AnfNodePtrList(call_inputs.begin(), call_inputs.begin() + 1); | |||
| auto change = FlattenArgs(fg, call_inputs, 1, &new_args); | |||
| // Put ValueNode<FuncGraph> into inputs. | |||
| auto new_inputs = AnfNodePtrList(call_inputs.begin(), call_inputs.begin() + 1); | |||
| auto change = FlattenArgs(fg, call_inputs, 1, &new_inputs); | |||
| if (change) { | |||
| auto new_call = fg->NewCNode(new_args); | |||
| auto new_call = fg->NewCNode(new_inputs); | |||
| new_call->set_abstract(call_node->abstract()); | |||
| return new_call; | |||
| } | |||
| @@ -0,0 +1,99 @@ | |||
| # Copyright 2022 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import mindspore as ms | |||
| from mindspore import context | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.api import ms_function | |||
| from mindspore.common.tensor import Tensor | |||
| import mindspore.nn as nn | |||
| import numpy as np | |||
| import pytest | |||
| class MAPPOCriticNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.linear1_actor = nn.Dense(54, # input local obs shape | |||
| 64, | |||
| weight_init='XavierUniform', | |||
| # paper uses orthogonal with gain 5/3 for every dense123 | |||
| has_bias=False, | |||
| activation=nn.Tanh()) | |||
| def construct(self, x): | |||
| # Feature Extraction | |||
| x = self.linear1_actor(x) | |||
| return x | |||
| class MAPPOActor(nn.Cell): | |||
| def __init__(self, actor_net): | |||
| super().__init__() | |||
| self.actor_net = actor_net | |||
| def construct(self, inputs_data): | |||
| _, global_obs = inputs_data | |||
| out = self.actor_net(global_obs) | |||
| return out | |||
| class TestClass(nn.Cell): | |||
| def __init__(self, actor_list): | |||
| super().__init__() | |||
| self.zero = Tensor(0, ms.int32) | |||
| self.actor_list = actor_list | |||
| self.less = P.Less() | |||
| self.zeros = P.Zeros() | |||
| def train(self): | |||
| state = Tensor(np.random.random((3, 128, 18)), ms.float32) | |||
| init_global_obs = self.zeros((128, 54), ms.float32) | |||
| out = self.test(state, init_global_obs) | |||
| return out | |||
| @ms_function | |||
| def test(self, state, init_global_obs): | |||
| num_agent = self.zero | |||
| while self.less(num_agent, 3): | |||
| samples = (state[num_agent], init_global_obs) | |||
| self.actor_list[num_agent](samples) | |||
| num_agent += 1 | |||
| return num_agent | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_net(): | |||
| """ | |||
| Feature: Tuple arg transform. | |||
| Description: Test the pass: transform tuple arg to tensor arg. | |||
| Expectation: Compile done without error. | |||
| """ | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=False, save_graphs_path="./graph_ir") | |||
| actor_list = nn.CellList() | |||
| for _ in range(3): | |||
| net = MAPPOCriticNet() | |||
| actor = MAPPOActor(net) | |||
| actor_list.append(actor) | |||
| test = TestClass(actor_list) | |||
| graph_out = test.train() | |||
| assert np.allclose(graph_out.asnumpy(), graph_out.asnumpy(), 0.0001, 0.0001) | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -22,10 +22,12 @@ | |||
| #include "ir/anf.h" | |||
| #include "ir/visitor.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "frontend/optimizer/optimizer.h" | |||
| #include "frontend/optimizer/opt.h" | |||
| #include "frontend/optimizer/anf_visitor.h" | |||
| #include "frontend/optimizer/irpass.h" | |||
| #include "frontend/optimizer/irpass/arithmetic_simplify.h" | |||
| #include "pipeline/jit/action.h" | |||
| #include "debug/draw.h" | |||
| #include "frontend/operator/ops.h" | |||
| @@ -107,6 +109,8 @@ class TestOptOpt : public UT::Common { | |||
| FuncGraphPairMapEquiv equiv_graph; | |||
| NodeMapEquiv equiv_node; | |||
| irpass::OptimizeIRPassLib irpass_lib; | |||
| static const PrimitivePtr P; | |||
| static const PrimitivePtr Q; | |||
| static const PrimitivePtr R; | |||
| @@ -115,6 +119,7 @@ class TestOptOpt : public UT::Common { | |||
| SubstitutionPtr elim_R; | |||
| SubstitutionPtr idempotent_P; | |||
| SubstitutionPtr Qct_to_P; | |||
| SubstitutionPtr tuple_flatten = irpass_lib.call_graph_tuple_transform_; | |||
| }; | |||
| const PrimitivePtr TestOptOpt::P = std::make_shared<Primitive>("P"); | |||
| @@ -148,8 +153,8 @@ TEST_F(TestOptOpt, ElimTwo) { | |||
| } | |||
| TEST_F(TestOptOpt, ElimR) { | |||
| FuncGraphPtr before = getPyFun.CallAndParseRet("test_elimR", "before_1"); | |||
| FuncGraphPtr after = getPyFun.CallAndParseRet("test_elimR", "after"); | |||
| FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_r", "before_1"); | |||
| FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_r", "after"); | |||
| ASSERT_TRUE(nullptr != before); | |||
| ASSERT_TRUE(nullptr != after); | |||
| @@ -208,5 +213,125 @@ TEST_F(TestOptOpt, CSE) { | |||
| ASSERT_EQ(manager2->all_nodes().size(), 12); | |||
| } | |||
| size_t TupleArgAndParamSum(const FuncGraphPtr &func_graph) { | |||
| // Check tuple params and tuple args. | |||
| auto all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude); | |||
| size_t tuple_arg_param_num = 0; | |||
| auto tuple_accumulate_func = [](size_t prev_num, const AnfNodePtr &node) -> size_t { | |||
| auto abs = node->abstract(); | |||
| MS_EXCEPTION_IF_NULL(abs); | |||
| return abs->isa<abstract::AbstractTuple>() ? prev_num + 1 : prev_num; | |||
| }; | |||
| for (const auto &node : all_nodes) { | |||
| // Count func graph call tuple args. | |||
| if (node->isa<CNode>() && !IsValueNode<Primitive>(node->cast<CNodePtr>()->input(0))) { | |||
| auto call_node = node->cast<CNodePtr>(); | |||
| tuple_arg_param_num = std::accumulate(call_node->inputs().begin() + 1, call_node->inputs().end(), | |||
| tuple_arg_param_num, tuple_accumulate_func); | |||
| } | |||
| // Count partial tuple args. | |||
| if (IsPrimitiveCNode(node, prim::kPrimPartial)) { | |||
| auto partial = node->cast<CNodePtr>(); | |||
| constexpr auto kPartialFirstArgIdx = 2; | |||
| tuple_arg_param_num = std::accumulate(partial->inputs().begin() + kPartialFirstArgIdx, partial->inputs().end(), | |||
| tuple_arg_param_num, tuple_accumulate_func); | |||
| } | |||
| // Count tuple params. | |||
| if (IsValueNode<FuncGraph>(node)) { | |||
| auto fg = GetValueNode<FuncGraphPtr>(node); | |||
| tuple_arg_param_num = | |||
| std::accumulate(fg->parameters().begin(), fg->parameters().end(), tuple_arg_param_num, tuple_accumulate_func); | |||
| } | |||
| } | |||
| return tuple_arg_param_num; | |||
| } | |||
| // Feature: Switch call tuple arg transform. | |||
| // Description: Test switch call's tuple arg transform.This case include partial's tuple arg and the call's tuple arg in | |||
| // the same time. | |||
| // Expectation: All tuple args are correctly transformed to tensor args. | |||
| TEST_F(TestOptOpt, SwitchPartialTupleTrans) { | |||
| FuncGraphPtr test_graph = getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_switch_partial_arg"); | |||
| ASSERT_TRUE(nullptr != test_graph); | |||
| FuncGraphManagerPtr manager1 = Manage(test_graph); | |||
| pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>(); | |||
| std::vector<AbstractBasePtr> args_spec; | |||
| // Renormalize firstly. | |||
| auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec); | |||
| ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0); | |||
| // Flatten tuple param and args. | |||
| OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res); | |||
| SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten})); | |||
| transform(renormalized_fg, optimizer); | |||
| // Renormalize again. | |||
| auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec); | |||
| ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0); | |||
| abstract::AnalysisResultCacheMgr::GetInstance().Clear(); | |||
| abstract::AnalysisContext::ClearContext(); | |||
| } | |||
| // Feature: Switch layer call tuple arg transform. | |||
| // Description: Test switch layer call's tuple arg transform.This case include partial's tuple arg and the partial's | |||
| // tensor arg in the same time. | |||
| // Expectation: All tuple args are correctly transformed to tensor args. | |||
| TEST_F(TestOptOpt, SwitchLayerPartialTupleTrans) { | |||
| FuncGraphPtr test_graph = getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_switch_layer_partial_arg"); | |||
| ASSERT_TRUE(nullptr != test_graph); | |||
| FuncGraphManagerPtr manager1 = Manage(test_graph); | |||
| pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>(); | |||
| std::vector<AbstractBasePtr> args_spec; | |||
| // Renormalize firstly. | |||
| auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec); | |||
| ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0); | |||
| // Flatten tuple param and args. | |||
| OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res); | |||
| SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten})); | |||
| transform(renormalized_fg, optimizer); | |||
| // Renormalize again. | |||
| auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec); | |||
| ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0); | |||
| abstract::AnalysisResultCacheMgr::GetInstance().Clear(); | |||
| abstract::AnalysisContext::ClearContext(); | |||
| } | |||
| // Feature: Single graph call tuple arg transform. | |||
| // Description: Test single graph call's tuple arg transform.This case include tuple in tuple args. | |||
| // Expectation: All tuple args are correctly transformed to tensor args. | |||
| TEST_F(TestOptOpt, SimpleCallTupleTupleTrans) { | |||
| FuncGraphPtr test_graph = | |||
| getPyFun.CallAndParseRet("test_tuple_flatten", "test_flatten_simple_call_tuple_in_tuple_arg"); | |||
| ASSERT_TRUE(nullptr != test_graph); | |||
| FuncGraphManagerPtr manager1 = Manage(test_graph); | |||
| pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>(); | |||
| std::vector<AbstractBasePtr> args_spec; | |||
| // Renormalize firstly. | |||
| auto renormalized_fg = pipeline::Renormalize(res, test_graph, args_spec); | |||
| ASSERT_TRUE(TupleArgAndParamSum(renormalized_fg) != 0); | |||
| // Flatten tuple param and args. | |||
| OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", res); | |||
| SubstitutionList transform(std::vector<SubstitutionPtr>({tuple_flatten})); | |||
| transform(renormalized_fg, optimizer); | |||
| // Renormalize again. | |||
| auto transformed_fg = pipeline::Renormalize(res, renormalized_fg, args_spec); | |||
| ASSERT_TRUE(TupleArgAndParamSum(transformed_fg) == 0); | |||
| abstract::AnalysisResultCacheMgr::GetInstance().Clear(); | |||
| abstract::AnalysisContext::ClearContext(); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -16,9 +16,11 @@ | |||
| import numpy as np | |||
| from mindspore import Tensor | |||
| from mindspore import dtype as mstype | |||
| from mindspore.ops import Primitive | |||
| from mindspore.ops import _constants as Constants | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| # pylint: disable=unused-variable | |||
| @@ -68,8 +70,12 @@ def test_add_zero(tag): | |||
| return fns[tag] | |||
| def test_elimR(tag): | |||
| """ test_elimR """ | |||
| def test_elim_r(tag): | |||
| """ | |||
| Feature: optimizer. | |||
| Description: test elimi R. | |||
| Expectation: run case with no exception. | |||
| """ | |||
| R = Primitive('R') | |||
| fns = FnDict() | |||
| @@ -495,6 +501,7 @@ def test_elim_transpose(tag): | |||
| return fns[tag] | |||
| def test_elim_depend_value(tag): | |||
| """ test_elim_depend_value """ | |||
| fns = FnDict() | |||
| @@ -1203,3 +1210,85 @@ def test_sparse_tensor(tag): | |||
| return z | |||
| return fns[tag] | |||
| # Test ut for file: call_graph_tuple_transform.h. | |||
| def test_tuple_flatten(tag): | |||
| """ | |||
| Feature: optimizer. | |||
| Description: test cases for pass: graph_tuple_transform. | |||
| Expectation: the tuple args and parameters are successfully flattened by the pass. | |||
| """ | |||
| fns = FnDict() | |||
| w = Tensor(np.random.randn(64, 3, 7, 7).astype(np.float32)) | |||
| x = Tensor(np.random.randn(32, 3, 224, 224).astype(np.float32)) | |||
| y = Tensor(np.random.randn(32, 3, 224, 224).astype(np.float32)) | |||
| p = Tensor(3, mstype.float32) | |||
| out_channel = 64 | |||
| kernel_size = 7 | |||
| conv = P.Conv2D(out_channel, | |||
| kernel_size, | |||
| mode=1, | |||
| pad_mode="valid", | |||
| pad=0, | |||
| stride=1, | |||
| dilation=1, | |||
| group=1) | |||
| pow_ops = P.Pow() | |||
| @fns | |||
| def test_flatten_switch_partial_arg(): | |||
| def called_graph_with_tuple(tuple_x, tuple_y): | |||
| return conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1)) + conv(F.tuple_getitem(tuple_y, 0), | |||
| F.tuple_getitem(tuple_y, 1)) | |||
| # Add tuple args in partial args. | |||
| func1 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p))) | |||
| func2 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p))) | |||
| cond = x < y | |||
| switch_node = F.switch(cond, func1, func2) | |||
| # Add tuple args in call args. | |||
| return switch_node((pow_ops(x, p), pow_ops(w, p))) | |||
| index = Tensor(1, mstype.int32) | |||
| @fns | |||
| def test_flatten_switch_layer_partial_arg(): | |||
| def called_graph_with_tuple(tuple_x): | |||
| return conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1)) | |||
| def called_graph_no_tuple(param1, param2): | |||
| return conv(param1, param2) | |||
| # Add tuple args in partial | |||
| func1 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p))) | |||
| func2 = F.partial(called_graph_with_tuple, (pow_ops(x, p), pow_ops(w, p))) | |||
| # Add tensor args in partial | |||
| func3 = F.partial(called_graph_no_tuple, pow_ops(x, p), pow_ops(w, p)) | |||
| switch_node = F.switch_layer(pow_ops(index, index), (func1, func2, func3)) | |||
| return switch_node() | |||
| @fns | |||
| def test_flatten_simple_call_tuple_in_tuple_arg(): | |||
| def called_graph_with_tuple(tuple_x, tuple_tuple_y, tensor_z): | |||
| result1 = conv(F.tuple_getitem(tuple_x, 0), F.tuple_getitem(tuple_x, 1)) | |||
| tuple_0 = F.tuple_getitem(tuple_tuple_y, 0) | |||
| result2 = conv(F.tuple_getitem(tuple_0, 0), F.tuple_getitem(tuple_0, 1)) | |||
| tensor_1 = F.tuple_getitem(tuple_tuple_y, 1) | |||
| result3 = conv(tensor_1, tensor_z) | |||
| return result1 + result2 + result3 | |||
| # Tuple arg. | |||
| tuple_x_arg = (pow_ops(x, p), pow_ops(w, p)) | |||
| # TupleTuple arg. | |||
| tuple_0_arg = (pow_ops(x, p), pow_ops(w, p)) | |||
| tensor_1_arg = pow_ops(x, p) | |||
| tuple_tuple_y_arg = (tuple_0_arg, tensor_1_arg) | |||
| # TensorArg | |||
| tensor_z_arg = pow_ops(w, p) | |||
| return called_graph_with_tuple(tuple_x_arg, tuple_tuple_y_arg, tensor_z_arg) | |||
| return fns[tag] | |||