diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 2f287227a3..3395115c6b 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -2701,14 +2701,13 @@ std::vector> NodeParameterName(const CNodePtr &n if (!IsValueNode(cnode->input(0))) { return param_names; } - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - PrimitivePtr prim = prim_anf_node->value()->cast(); - if (prim->name() == CAST && cnode->inputs().size() >= 1) { - auto cast_input = cnode->inputs()[1]; - if (cast_input->isa()) { - auto cast_input_parameter = cast_input->cast(); - if (cast_input_parameter->has_default() && ParameterRequireGrad(cast_input_parameter)) { - param_names.push_back({cast_input_parameter->name(), i}); + if ((IsPrimitiveCNode(cnode, prim::kPrimCast) && cnode->inputs().size() >= 1) || + IsPrimitiveCNode(cnode, prim::kPrimLoad)) { + auto inp = cnode->input(1); + if (inp->isa()) { + auto inp_param = inp->cast(); + if (inp_param->has_default() && ParameterRequireGrad(inp_param)) { + param_names.push_back({inp_param->name(), i}); } } }