From: @liangzhibo Reviewed-by: @kingxian,@jpc_chenjianping Signed-off-by: @kingxiantags/v1.2.0-rc1
| @@ -92,7 +92,7 @@ def zeros_like_tensor(x): | |||||
| return value | return value | ||||
| def switch(c, x, y): | |||||
| def Switch(c, x, y): | |||||
| """Implement `switch`.""" | """Implement `switch`.""" | ||||
| return x if c else y | return x if c else y | ||||
| @@ -47,7 +47,7 @@ namespace abstract { | |||||
| using mindspore::parse::PyObjectWrapper; | using mindspore::parse::PyObjectWrapper; | ||||
| std::unordered_set<std::string> prims_to_skip_undetermined_infer{ | std::unordered_set<std::string> prims_to_skip_undetermined_infer{ | ||||
| "MakeTuple", "make_list", "switch", "env_setitem", "env_getitem", "Load", "UpdateState"}; | |||||
| "MakeTuple", "make_list", "Switch", "env_setitem", "env_getitem", "Load", "UpdateState"}; | |||||
| EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | ||||
| AnfNodeConfigPtr out_conf) { | AnfNodeConfigPtr out_conf) { | ||||
| @@ -184,7 +184,7 @@ constexpr auto kLabelGotoOpName = "LabelGoto"; | |||||
| constexpr auto kBNInferGradOpName = "BNInferGrad"; | constexpr auto kBNInferGradOpName = "BNInferGrad"; | ||||
| constexpr auto kCallOpName = "call"; | constexpr auto kCallOpName = "call"; | ||||
| constexpr auto kPartialOpName = "partial"; | constexpr auto kPartialOpName = "partial"; | ||||
| constexpr auto kSwitchOpName = "switch"; | |||||
| constexpr auto kSwitchOpName = "Switch"; | |||||
| constexpr auto kReturnOpName = "Return"; | constexpr auto kReturnOpName = "Return"; | ||||
| constexpr auto kLarsV2OpName = "LarsV2"; | constexpr auto kLarsV2OpName = "LarsV2"; | ||||
| constexpr auto kLarsV2UpdateOpName = "LarsV2Update"; | constexpr auto kLarsV2UpdateOpName = "LarsV2Update"; | ||||
| @@ -63,7 +63,7 @@ using InstType = std::pair<Instruction, VectorRef>; | |||||
| using InstSet = std::vector<InstType>; | using InstSet = std::vector<InstType>; | ||||
| using InstFunctionMap = std::map<Instruction, std::function<void(const VectorRef &)>>; | using InstFunctionMap = std::map<Instruction, std::function<void(const VectorRef &)>>; | ||||
| const std::vector<std::string> inst_str{"call", "tail_call", "Return", "partial", "switch", | |||||
| const std::vector<std::string> inst_str{"call", "tail_call", "Return", "partial", "Switch", | |||||
| "switch_return", "tuple", "input", "external", "push", | "switch_return", "tuple", "input", "external", "push", | ||||
| "primitive", "graph", "pad_stack", "switch_layer"}; | "primitive", "graph", "pad_stack", "switch_layer"}; | ||||
| class StructPartial : public Base { | class StructPartial : public Base { | ||||
| @@ -404,7 +404,7 @@ inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where"); | |||||
| // Statements | // Statements | ||||
| inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return"); | inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return"); | ||||
| inline const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch"); | |||||
| inline const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("Switch"); | |||||
| inline const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer"); | inline const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer"); | ||||
| inline const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign"); | inline const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign"); | ||||
| inline const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd"); | inline const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd"); | ||||
| @@ -267,7 +267,7 @@ def bprop_control_depend(x, y, out, dout): | |||||
| return C.zeros_like(x), C.zeros_like(y) | return C.zeros_like(x), C.zeros_like(y) | ||||
| @bprops.register("switch") | |||||
| @bprops.register("Switch") | |||||
| def bprop_switch(cond, tb, fb, out, dout): | def bprop_switch(cond, tb, fb, out, dout): | ||||
| """Backpropagator for primitive `switch`.""" | """Backpropagator for primitive `switch`.""" | ||||
| return C.zeros_like(cond), F.switch(cond, dout, C.zeros_like(tb)), \ | return C.zeros_like(cond), F.switch(cond, dout, C.zeros_like(tb)), \ | ||||
| @@ -181,7 +181,7 @@ env_setitem = Primitive('env_setitem') | |||||
| env_getitem = Primitive('env_getitem') | env_getitem = Primitive('env_getitem') | ||||
| env_add = Primitive('env_add') | env_add = Primitive('env_add') | ||||
| J = Primitive('J') | J = Primitive('J') | ||||
| switch = Primitive('switch') | |||||
| switch = Primitive('Switch') | |||||
| switch_layer = Primitive('switch_layer') | switch_layer = Primitive('switch_layer') | ||||
| # for sum bprop | # for sum bprop | ||||
| reduced_shape = Primitive("reduced_shape") | reduced_shape = Primitive("reduced_shape") | ||||
| @@ -310,7 +310,7 @@ TEST_F(TestOps, Col2ImV1Test) { | |||||
| // Statements | // Statements | ||||
| TEST_F(TestOps, SwitchTest) { | TEST_F(TestOps, SwitchTest) { | ||||
| auto prim = std::make_shared<Primitive>("switch"); | |||||
| auto prim = std::make_shared<Primitive>("Switch"); | |||||
| ASSERT_EQ(prim->name(), kPrimSwitch->name()); | ASSERT_EQ(prim->name(), kPrimSwitch->name()); | ||||
| } | } | ||||
| @@ -294,7 +294,7 @@ TEST_F(TestPrim, test_J_2) { | |||||
| // tail half | // tail half | ||||
| TEST_F(TestPrim, test_switch1) { | TEST_F(TestPrim, test_switch1) { | ||||
| PrimitivePtr switch_ = std::make_shared<Primitive>("switch"); | |||||
| PrimitivePtr switch_ = std::make_shared<Primitive>("Switch"); | |||||
| FuncGraphPtr func_graph = MakeFuncGraph(switch_, 3); | FuncGraphPtr func_graph = MakeFuncGraph(switch_, 3); | ||||
| AbstractBasePtr arg0 = FromValue(true, false); | AbstractBasePtr arg0 = FromValue(true, false); | ||||
| @@ -307,7 +307,7 @@ TEST_F(TestPrim, test_switch1) { | |||||
| } | } | ||||
| TEST_F(TestPrim, test_switch2) { | TEST_F(TestPrim, test_switch2) { | ||||
| PrimitivePtr switch_ = std::make_shared<Primitive>("switch"); | |||||
| PrimitivePtr switch_ = std::make_shared<Primitive>("Switch"); | |||||
| FuncGraphPtr func_graph = MakeFuncGraph(switch_, 3); | FuncGraphPtr func_graph = MakeFuncGraph(switch_, 3); | ||||
| AbstractBasePtr arg0 = FromValue(false, false); | AbstractBasePtr arg0 = FromValue(false, false); | ||||
| @@ -30,7 +30,7 @@ from mindspore.ops.operations import _grad_ops as G | |||||
| scalar_add = Primitive(Constants.kScalarAdd) | scalar_add = Primitive(Constants.kScalarAdd) | ||||
| scalar_mul = Primitive(Constants.kScalarMul) | scalar_mul = Primitive(Constants.kScalarMul) | ||||
| tuple_getitem = Primitive(Constants.kTupleGetItem) | tuple_getitem = Primitive(Constants.kTupleGetItem) | ||||
| switch = Primitive('switch') | |||||
| switch = Primitive('Switch') | |||||
| def test_sexp_conversion(): | def test_sexp_conversion(): | ||||