Merge pull request !4730 from riemann_penn/fix_grad_operation_api
@@ -462,7 +462,8 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
     std::vector<FuncGraphPtr> graphs{};
     auto graphs_cnode = sw->input(2)->cast<CNodePtr>();
     auto &graphs_inputs = graphs_cnode->inputs();
-    if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(graphs_inputs[1])) {
+    if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && graphs_inputs.size() >= 2 &&
+        IsValueNode<FuncGraph>(graphs_inputs[1])) {
       (void)std::transform(graphs_inputs.begin() + 1, graphs_inputs.end(), std::back_inserter(graphs),
                            [](const AnfNodePtr &vnode) { return GetValueNode<FuncGraphPtr>(vnode); });
     }
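Note: in a CNode's inputs() the element at index 0 is the called primitive itself, so a MakeTuple with no elements still has inputs().size() == 1 and graphs_inputs[1] would read out of range; the new size() >= 2 guard covers exactly that case. For orientation, a minimal user-level pattern that reaches this pass (a hypothetical net, written against the same 0.7-era test API used in the tests further down) could look like:

    import numpy as np
    import mindspore.nn as nn
    import mindspore.ops.composite as C
    from mindspore import Tensor
    from mindspore.common import dtype as mstype

    class PickOne(nn.Cell):
        # Indexing a tuple of Cells with a Tensor index lowers to switch_layer;
        # differentiating the call is what IncorporateEnvGetitemSwitchLayer rewrites.
        def __init__(self, funcs):
            super(PickOne, self).__init__()
            self.funcs = funcs

        def construct(self, i, x):
            return self.funcs[i](x)

    net = PickOne((nn.ReLU(), nn.Tanh()))
    i = Tensor(0, mstype.int32)
    x = Tensor(np.ones([2, 2]).astype(np.float32))
    C.grad_all(net)(i, x)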
@@ -89,6 +89,7 @@ class GetItemTransformACrossGraph {
       ss << idx;
       auto new_fg_outer = TransformableClone(fg, std::make_shared<TraceTransform>(ss.str()));
+      fg->manager()->AddFuncGraph(new_fg_outer);
       auto output_outer = new_fg_outer->output();
       if (!IsValueNode<FuncGraph>(output_outer)) {
         MS_LOG(WARNING) << "Output of outer graph should be a func_graph";
@@ -486,7 +487,7 @@ class IncorporateGetitemSwitchLayerA : public AnfVisitor {
       switch_layer_ = inputs[0];
       (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_));
     }
-    if (is_in_switch_ && cnode->size() > 2) {
+    if (is_in_switch_ && cnode->size() >= 2) {
       auto &inputs = cnode->inputs();
       if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(inputs[1])) {
         (void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_),
@@ -578,7 +579,7 @@ class IncorporateGetitemSwitchLayerB : public AnfVisitor {
       switch_layer_call_ = inputs[0];
       (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(outer_call_args_));
     }
-    if (is_in_switch_ && cnode->size() > 2) {
+    if (is_in_switch_ && cnode->size() >= 2) {
       auto &inputs = cnode->inputs();
       if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(inputs[1])) {
         (void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_),
@@ -36,10 +36,9 @@ class SwitchLayerDeferInline : public AnfVisitor {
     auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->inputs()[2]->abstract());
     for (auto elem : tuple->elements()) {
       auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem);
-      if (abstract == nullptr) {
-        return nullptr;
+      if (abstract != nullptr) {
+        *(abstract->func_graph()->switch_layer_input()) = true;
       }
-      *(abstract->func_graph()->switch_layer_input()) = true;
     }
     return nullptr;
   }
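The inverted check turns a hard bail-out into a skip: a tuple element that is not a FuncGraphAbstractClosure (for example a raw Primitive, which lowers to a primitive closure rather than a func_graph) no longer aborts the loop, so the remaining func_graph branches still get marked for deferred inlining. A hedged sketch of such a branch tuple, in the spirit of the test_switch_layer_with_single_prim test below (whether a mixed Primitive/Cell tuple is accepted end to end is an assumption here):

    import numpy as np
    import mindspore.nn as nn
    import mindspore.ops.operations as P
    from mindspore import Tensor
    from mindspore.common import dtype as mstype

    class PrimBranch(nn.Cell):
        # P.ReLU() is an AbstractFunction but not a FuncGraphAbstractClosure;
        # the old code returned early on it and left every branch unmarked,
        # the new code skips it and still marks the nn.Tanh func_graph branch.
        def __init__(self):
            super(PrimBranch, self).__init__()
            self.funcs = (P.ReLU(), nn.Tanh())  # mixed tuple: assumption, for illustration

        def construct(self, i, x):
            return self.funcs[i](x)

    net = PrimBranch()
    net(Tensor(0, mstype.int32), Tensor(np.ones([2, 2]).astype(np.float32)))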
@@ -137,6 +137,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.arithmetic_simplify2_,
     irpass.same_eliminate_,
     irpass.check_bprop_eliminate_,
+    irpass.switch_layer_defer_inline_,
     irpass.replace_applicator_,
   });
   opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
@@ -16,6 +16,7 @@
 #include "abstract/param_validator.h"
 #include "abstract/infer_functions.h"
 #include "abstract/abstract_function.h"
+#include "abstract/utils.h"
 #include "utils/symbolic.h"
@@ -121,12 +122,18 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
   for (size_t i = 0; i < branches.size(); i++) {
     MS_EXCEPTION_IF_NULL(branches[i]);
     if (!branches[i]->isa<AbstractFunction>()) {
-      MS_LOG(EXCEPTION) << op_name << " requires that the 2th arg be tuple of functions, but got "
-                        << branches[i]->ToString() << " as the " << i << "th element.";
+      MS_EXCEPTION(ValueError) << op_name << " requires that the 2nd arg be a tuple of functions, but got "
+                               << branches[i]->ToString() << " as the " << i << "th element.";
     }
   }
   auto b = branches[0];
+  // Return an AbstractFuncUnion; otherwise the switch_layer will be replaced by branches[0],
+  // which would cancel the out-of-bound check on the index.
+  if (branches.size() == 1) {
+    AbstractFuncAtomPtrList func_list{b->cast<AbstractFuncAtomPtr>()};
+    return std::make_shared<AbstractFuncUnion>(func_list);
+  }
   for (size_t i = 1; i < branches.size(); i++) {
     b = b->Join(branches[i]);
   }
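The single-branch special case matters because a one-element tuple would otherwise infer to exactly branches[0], letting the optimizer fold the switch_layer call away and silently drop the runtime range check on the index; wrapping the lone branch in an AbstractFuncUnion keeps the switch_layer node alive. The test_switch_layer_single_layer test added below exercises this; a condensed sketch of the behavior (hypothetical names, 0.7-era API):

    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import Tensor

    class SingleBranch(nn.Cell):
        # Only one branch: before the fix, inference collapsed the switch_layer
        # to this branch and an out-of-range index was never caught.
        def __init__(self):
            super(SingleBranch, self).__init__()
            self.funcs = (nn.ReLU(),)

        def construct(self, x, index):
            return self.funcs[index](x)

    net = SingleBranch()
    x = Tensor(np.ones((2, 2)), ms.float32)
    net(x, Tensor(0, ms.int32))    # in range: fine
    # net(x, Tensor(3, ms.int32))  # out of range: now rejected by the bounds check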
@@ -444,6 +444,86 @@ def test_index_to_switch_layer():
     C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+
+
+def test_parser_switch_layer_switch_in_bprop():
+    class OneInputBprop(nn.Cell):
+        def __init__(self, funcs):
+            super(OneInputBprop, self).__init__()
+            self.op = P.ReLU()
+            self.funcs = funcs
+
+        def construct(self, i, x):
+            return self.op(x)
+
+        def bprop(self, i, x, out, dout):
+            return i, self.funcs[i](x, dout)
+
+    class Add(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.op = P.TensorAdd()
+
+        def construct(self, x, y):
+            return self.op(x, y)
+
+    class Mul(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.op = P.Mul()
+
+        def construct(self, x, y):
+            return self.op(x, y)
+
+    func1 = Add()
+    func2 = Mul()
+    funcs = (func1, func2)
+    net = OneInputBprop(funcs)
+    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
+    grad = Tensor(np.random.randn(2, 2).astype(np.float32))
+    i = Tensor(1, mstype.int32)
+    grad_net = C.grad_all_with_sens(net)
+    grad_net(i, input1, grad)
+
+
+def test_parser_switch_layer_inputs_tuple():
+    class TwoInputTupleFinalNet(nn.Cell):
+        def __init__(self, funcs):
+            super().__init__()
+            self.funcs = funcs
+
+        def construct(self, i, inputa, inputb):
+            inputs = (inputa, inputb)
+            x = self.funcs[i](inputs)
+            return x
+
+    class Add(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.op = P.TensorAdd()
+
+        def construct(self, x):
+            y = self.op(x[0], x[1])
+            return self.op(x[0], y)
+
+    class Mul(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.op = P.Mul()
+
+        def construct(self, x):
+            y = self.op(x[0], x[1])
+            return self.op(x[0], y)
+
+    func1 = Add()
+    func2 = Mul()
+    funcs = (func1, func2)
+    net = TwoInputTupleFinalNet(funcs)
+    input1 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+    input2 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+    i = Tensor(1, mstype.int32)
+    grad = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+    back_net = C.grad_all_with_sens(net)
+    back_out = back_net(i, input1, input2, grad)
+
+
 def test_switch_layer_with_single_prim():
     class SwitchLayerCell(nn.Cell):
         def __init__(self):
@@ -494,6 +574,35 @@ def test_switch_layer_env_eliminate():
     net2(x, i)
+
+
+def test_switch_layer_single_layer():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.conv = nn.Conv2d(1, 1, 3, pad_mode='same')
+            self.funs = (self.conv,)
+
+        def construct(self, x, index):
+            x = self.funs[index](x)
+            return x
+
+    class NetGrad(nn.Cell):
+        def __init__(self, net):
+            super(NetGrad, self).__init__()
+            self.grad_op = C.GradOperation('grad', get_by_list=True, sens_param=False)
+            self.net = net
+            self.weights = ParameterTuple(self.net.trainable_params())
+
+        def construct(self, x, index):
+            weights = self.weights
+            grad = self.grad_op(self.net, weights)(x, index)
+            return grad
+
+    net = Net()
+    net2 = NetGrad(net)
+    x = Tensor(np.ones((3, 1, 12, 12)), ms.float32)
+    i = Tensor(1, ms.int32)
+    net2(x, i)
+
+
 def test_control_depend_check():
     with pytest.raises(TypeError) as e:
         P.ControlDepend(0.0)