| @@ -103,6 +103,14 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr& func_graph, const py::object& | |||
| if (para_node == nullptr) { | |||
| ParameterPtr node = top_graph->AddWeightParameter(param_name); | |||
| node->set_default_param(obj); | |||
| // set_abstract for parameter | |||
| auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input")); | |||
| ValuePtr converted = nullptr; | |||
| (void)ConvertData(to_convert, &converted); | |||
| bool broaden = true; | |||
| node->set_abstract(abstract::FromValue(converted, broaden)); | |||
| para_node = node; | |||
| } | |||
| auto iter = func_graph->make_ref_params().find(para_node); | |||
| @@ -112,6 +112,13 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) { | |||
| }); | |||
| opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); | |||
| opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); | |||
| opt::irpass::ResolveIRPassLib resolve_irpass; | |||
| opt::OptPassConfig resolve_pass = opt::OptPassConfig({ | |||
| resolve_irpass.resolver_resolve_, | |||
| resolve_irpass.resolver_getattr_, | |||
| irpass.get_make_ref_eliminate_, | |||
| }); | |||
| OptPassGroupMap map_a({{"a_1", a_1}, | |||
| {"a_2", a_2}, | |||
| @@ -120,6 +127,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) { | |||
| {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)}, | |||
| {"virtual_dataset", virtual_dataset}, | |||
| {"grad", grad}, | |||
| {"resolve", resolve_pass}, | |||
| {"renormalize", opt::OptPassConfig::Renormalize()}, | |||
| {"cse", opt::OptPassConfig(opt::CSE(false))}, | |||
| {"a_3", a_3}}); | |||
| @@ -554,24 +554,6 @@ AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &dat | |||
| return eng->ForwardConfig(old_conf, fn_conf); | |||
| } | |||
| AbstractBasePtr GenerateResolveAbstract(const AnfNodeConfigPtr &out_conf, const py::object &obj, | |||
| const ValuePtr &converted_ret) { | |||
| if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) { | |||
| TypePtr cls_ptr = parse::ParseDataClass(converted_ret->cast<std::shared_ptr<parse::PyObjectWrapper>>()->obj()); | |||
| std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial), NewValueNode(prim::kPrimMakeRecord), | |||
| NewValueNode(cls_ptr)}; | |||
| MS_EXCEPTION_IF_NULL(out_conf); | |||
| FuncGraphPtr func_graph = out_conf->node()->func_graph(); | |||
| CNodePtr new_cnode = func_graph->NewCNode(input); | |||
| AnalysisEnginePtr eng = out_conf->engine(); | |||
| AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, out_conf->context()); | |||
| return eng->ForwardConfig(out_conf, fn_conf); | |||
| } else { | |||
| return ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf); | |||
| } | |||
| } | |||
| AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine, | |||
| const AbstractBasePtrList &args_spec_list, | |||
| const AnfNodeConfigPtr &out_conf) { | |||
| @@ -602,23 +584,16 @@ AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &eng | |||
| // item_name to func addr from obj_map | |||
| parse::SymbolPtr symbol = item_v->cast<parse::SymbolPtr>(); | |||
| parse::NameSpacePtr name_space = data_v->cast<parse::NameSpacePtr>(); | |||
| FuncGraphPtr func_graph = out_conf->node()->func_graph(); | |||
| parse::SymbolResolverPtr symbol_resolver = | |||
| std::make_shared<parse::SymbolResolver>(name_space, symbol, out_conf->node()); | |||
| if (!symbol_resolver->Resolve()) { | |||
| auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_conf->node()); | |||
| if (new_node == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Resolve node failed"; | |||
| } | |||
| py::object obj = symbol_resolver->result(); | |||
| ValuePtr converted_ret = nullptr; | |||
| bool converted = parse::ConvertData(obj, &converted_ret, true); | |||
| if (!converted) { | |||
| MS_LOG(EXCEPTION) << "Convert data failed"; | |||
| } | |||
| if (converted_ret->isa<FuncGraph>()) { | |||
| AddToManager(engine, converted_ret->cast<FuncGraphPtr>()); | |||
| } | |||
| return GenerateResolveAbstract(out_conf, obj, converted_ret); | |||
| AnalysisEnginePtr eng = out_conf->engine(); | |||
| AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context()); | |||
| return eng->ForwardConfig(out_conf, fn_conf); | |||
| } | |||
| AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine, | |||
| @@ -17,13 +17,14 @@ import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.common.api import ms_function | |||
| from ....mindspore_test_framework.utils.bprop_util import bprop | |||
| from ....mindspore_test_framework.utils.debug_util import PrintShapeTypeCell, PrintGradShapeTypeCell | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| import mindspore | |||
| def setup_module(module): | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| @@ -107,3 +108,36 @@ def test_print_shape_type(): | |||
| return z | |||
| bprop(Mul(), Tensor(np.ones([2, 2]).astype(np.float32)), | |||
| Tensor(np.ones([2, 2]).astype(np.float32))) | |||
| def test_cell_assign(): | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| class GradNetWrap(nn.Cell): | |||
| """ GradNetWrap definition """ | |||
| def __init__(self, net): | |||
| super(GradNetWrap, self).__init__() | |||
| self.net = net | |||
| self.weights = mindspore.ParameterTuple(net.get_parameters()) | |||
| def construct(self, x, y): | |||
| return C.grad_by_list(self.net, self.weights)(x, y) | |||
| class Mul(nn.Cell): | |||
| def __init__(self): | |||
| super(Mul, self).__init__() | |||
| self.get_g = P.InsertGradientOf(self.save_gradient) | |||
| self.matrix_w = mindspore.Parameter(Tensor(np.ones([2, 2], np.float32)), name="matrix_w") | |||
| self.matrix_g = mindspore.Parameter(Tensor(np.ones([2, 2], np.float32)), name="matrix_g") | |||
| def save_gradient(self, dout): | |||
| self.matrix_g = dout | |||
| return dout | |||
| def construct(self, x, y): | |||
| z = x * self.matrix_w | |||
| z = self.get_g(z) | |||
| z = z * y | |||
| return z | |||
| input_x = Tensor(np.ones([2, 2], np.float32)) | |||
| input_y = Tensor(np.ones([2, 2], np.float32)) | |||
| GradNetWrap(Mul())(input_x, input_y) | |||