|
|
|
@@ -2701,14 +2701,13 @@ std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &n |
|
|
|
if (!IsValueNode<Primitive>(cnode->input(0))) { |
|
|
|
return param_names; |
|
|
|
} |
|
|
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); |
|
|
|
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>(); |
|
|
|
if (prim->name() == CAST && cnode->inputs().size() >= 1) { |
|
|
|
auto cast_input = cnode->inputs()[1]; |
|
|
|
if (cast_input->isa<Parameter>()) { |
|
|
|
auto cast_input_parameter = cast_input->cast<ParameterPtr>(); |
|
|
|
if (cast_input_parameter->has_default() && ParameterRequireGrad(cast_input_parameter)) { |
|
|
|
param_names.push_back({cast_input_parameter->name(), i}); |
|
|
|
if ((IsPrimitiveCNode(cnode, prim::kPrimCast) && cnode->inputs().size() >= 1) || |
|
|
|
IsPrimitiveCNode(cnode, prim::kPrimLoad)) { |
|
|
|
auto inp = cnode->input(1); |
|
|
|
if (inp->isa<Parameter>()) { |
|
|
|
auto inp_param = inp->cast<ParameterPtr>(); |
|
|
|
if (inp_param->has_default() && ParameterRequireGrad(inp_param)) { |
|
|
|
param_names.push_back({inp_param->name(), i}); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|