| @@ -103,6 +103,14 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr& func_graph, const py::object& | |||||
| if (para_node == nullptr) { | if (para_node == nullptr) { | ||||
| ParameterPtr node = top_graph->AddWeightParameter(param_name); | ParameterPtr node = top_graph->AddWeightParameter(param_name); | ||||
| node->set_default_param(obj); | node->set_default_param(obj); | ||||
| // set_abstract for parameter | |||||
| auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input")); | |||||
| ValuePtr converted = nullptr; | |||||
| (void)ConvertData(to_convert, &converted); | |||||
| bool broaden = true; | |||||
| node->set_abstract(abstract::FromValue(converted, broaden)); | |||||
| para_node = node; | para_node = node; | ||||
| } | } | ||||
| auto iter = func_graph->make_ref_params().find(para_node); | auto iter = func_graph->make_ref_params().find(para_node); | ||||
| @@ -112,6 +112,13 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) { | |||||
| }); | }); | ||||
| opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); | opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); | ||||
| opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); | opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); | ||||
| opt::irpass::ResolveIRPassLib resolve_irpass; | |||||
| opt::OptPassConfig resolve_pass = opt::OptPassConfig({ | |||||
| resolve_irpass.resolver_resolve_, | |||||
| resolve_irpass.resolver_getattr_, | |||||
| irpass.get_make_ref_eliminate_, | |||||
| }); | |||||
| OptPassGroupMap map_a({{"a_1", a_1}, | OptPassGroupMap map_a({{"a_1", a_1}, | ||||
| {"a_2", a_2}, | {"a_2", a_2}, | ||||
| @@ -120,6 +127,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) { | |||||
| {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)}, | {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)}, | ||||
| {"virtual_dataset", virtual_dataset}, | {"virtual_dataset", virtual_dataset}, | ||||
| {"grad", grad}, | {"grad", grad}, | ||||
| {"resolve", resolve_pass}, | |||||
| {"renormalize", opt::OptPassConfig::Renormalize()}, | {"renormalize", opt::OptPassConfig::Renormalize()}, | ||||
| {"cse", opt::OptPassConfig(opt::CSE(false))}, | {"cse", opt::OptPassConfig(opt::CSE(false))}, | ||||
| {"a_3", a_3}}); | {"a_3", a_3}}); | ||||
| @@ -554,24 +554,6 @@ AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &dat | |||||
| return eng->ForwardConfig(old_conf, fn_conf); | return eng->ForwardConfig(old_conf, fn_conf); | ||||
| } | } | ||||
| AbstractBasePtr GenerateResolveAbstract(const AnfNodeConfigPtr &out_conf, const py::object &obj, | |||||
| const ValuePtr &converted_ret) { | |||||
| if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) { | |||||
| TypePtr cls_ptr = parse::ParseDataClass(converted_ret->cast<std::shared_ptr<parse::PyObjectWrapper>>()->obj()); | |||||
| std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial), NewValueNode(prim::kPrimMakeRecord), | |||||
| NewValueNode(cls_ptr)}; | |||||
| MS_EXCEPTION_IF_NULL(out_conf); | |||||
| FuncGraphPtr func_graph = out_conf->node()->func_graph(); | |||||
| CNodePtr new_cnode = func_graph->NewCNode(input); | |||||
| AnalysisEnginePtr eng = out_conf->engine(); | |||||
| AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, out_conf->context()); | |||||
| return eng->ForwardConfig(out_conf, fn_conf); | |||||
| } else { | |||||
| return ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf); | |||||
| } | |||||
| } | |||||
| AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine, | AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine, | ||||
| const AbstractBasePtrList &args_spec_list, | const AbstractBasePtrList &args_spec_list, | ||||
| const AnfNodeConfigPtr &out_conf) { | const AnfNodeConfigPtr &out_conf) { | ||||
| @@ -602,23 +584,16 @@ AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &eng | |||||
| // item_name to func addr from obj_map | // item_name to func addr from obj_map | ||||
| parse::SymbolPtr symbol = item_v->cast<parse::SymbolPtr>(); | parse::SymbolPtr symbol = item_v->cast<parse::SymbolPtr>(); | ||||
| parse::NameSpacePtr name_space = data_v->cast<parse::NameSpacePtr>(); | parse::NameSpacePtr name_space = data_v->cast<parse::NameSpacePtr>(); | ||||
| FuncGraphPtr func_graph = out_conf->node()->func_graph(); | |||||
| parse::SymbolResolverPtr symbol_resolver = | |||||
| std::make_shared<parse::SymbolResolver>(name_space, symbol, out_conf->node()); | |||||
| if (!symbol_resolver->Resolve()) { | |||||
| auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_conf->node()); | |||||
| if (new_node == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Resolve node failed"; | MS_LOG(EXCEPTION) << "Resolve node failed"; | ||||
| } | } | ||||
| py::object obj = symbol_resolver->result(); | |||||
| ValuePtr converted_ret = nullptr; | |||||
| bool converted = parse::ConvertData(obj, &converted_ret, true); | |||||
| if (!converted) { | |||||
| MS_LOG(EXCEPTION) << "Convert data failed"; | |||||
| } | |||||
| if (converted_ret->isa<FuncGraph>()) { | |||||
| AddToManager(engine, converted_ret->cast<FuncGraphPtr>()); | |||||
| } | |||||
| return GenerateResolveAbstract(out_conf, obj, converted_ret); | |||||
| AnalysisEnginePtr eng = out_conf->engine(); | |||||
| AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context()); | |||||
| return eng->ForwardConfig(out_conf, fn_conf); | |||||
| } | } | ||||
| AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine, | AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine, | ||||
| @@ -17,13 +17,14 @@ import numpy as np | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import functional as F | |||||
| from mindspore.common.api import ms_function | from mindspore.common.api import ms_function | ||||
| from ....mindspore_test_framework.utils.bprop_util import bprop | from ....mindspore_test_framework.utils.bprop_util import bprop | ||||
| from ....mindspore_test_framework.utils.debug_util import PrintShapeTypeCell, PrintGradShapeTypeCell | from ....mindspore_test_framework.utils.debug_util import PrintShapeTypeCell, PrintGradShapeTypeCell | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore import context | from mindspore import context | ||||
| import mindspore | |||||
| def setup_module(module): | def setup_module(module): | ||||
| context.set_context(mode=context.PYNATIVE_MODE) | context.set_context(mode=context.PYNATIVE_MODE) | ||||
| @@ -107,3 +108,36 @@ def test_print_shape_type(): | |||||
| return z | return z | ||||
| bprop(Mul(), Tensor(np.ones([2, 2]).astype(np.float32)), | bprop(Mul(), Tensor(np.ones([2, 2]).astype(np.float32)), | ||||
| Tensor(np.ones([2, 2]).astype(np.float32))) | Tensor(np.ones([2, 2]).astype(np.float32))) | ||||
| def test_cell_assign(): | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||||
| class GradNetWrap(nn.Cell): | |||||
| """ GradNetWrap definition """ | |||||
| def __init__(self, net): | |||||
| super(GradNetWrap, self).__init__() | |||||
| self.net = net | |||||
| self.weights = mindspore.ParameterTuple(net.get_parameters()) | |||||
| def construct(self, x, y): | |||||
| return C.grad_by_list(self.net, self.weights)(x, y) | |||||
| class Mul(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Mul, self).__init__() | |||||
| self.get_g = P.InsertGradientOf(self.save_gradient) | |||||
| self.matrix_w = mindspore.Parameter(Tensor(np.ones([2, 2], np.float32)), name="matrix_w") | |||||
| self.matrix_g = mindspore.Parameter(Tensor(np.ones([2, 2], np.float32)), name="matrix_g") | |||||
| def save_gradient(self, dout): | |||||
| self.matrix_g = dout | |||||
| return dout | |||||
| def construct(self, x, y): | |||||
| z = x * self.matrix_w | |||||
| z = self.get_g(z) | |||||
| z = z * y | |||||
| return z | |||||
| input_x = Tensor(np.ones([2, 2], np.float32)) | |||||
| input_y = Tensor(np.ones([2, 2], np.float32)) | |||||
| GradNetWrap(Mul())(input_x, input_y) | |||||