| @@ -2701,14 +2701,13 @@ std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &n | |||||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | if (!IsValueNode<Primitive>(cnode->input(0))) { | ||||
| return param_names; | return param_names; | ||||
| } | } | ||||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||||
| PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>(); | |||||
| if (prim->name() == CAST && cnode->inputs().size() >= 1) { | |||||
| auto cast_input = cnode->inputs()[1]; | |||||
| if (cast_input->isa<Parameter>()) { | |||||
| auto cast_input_parameter = cast_input->cast<ParameterPtr>(); | |||||
| if (cast_input_parameter->has_default() && ParameterRequireGrad(cast_input_parameter)) { | |||||
| param_names.push_back({cast_input_parameter->name(), i}); | |||||
| if ((IsPrimitiveCNode(cnode, prim::kPrimCast) && cnode->inputs().size() >= 1) || | |||||
| IsPrimitiveCNode(cnode, prim::kPrimLoad)) { | |||||
| auto inp = cnode->input(1); | |||||
| if (inp->isa<Parameter>()) { | |||||
| auto inp_param = inp->cast<ParameterPtr>(); | |||||
| if (inp_param->has_default() && ParameterRequireGrad(inp_param)) { | |||||
| param_names.push_back({inp_param->name(), i}); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||