diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h index e3dedf6d1f..88ddd7b322 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h @@ -31,7 +31,6 @@ namespace mindspore { namespace opt { namespace irpass { - // {prim::kPrimJ, C} class ExpandJPrim { public: diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc index 0c193a7401..8d1d6aeb43 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc @@ -129,28 +129,7 @@ AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &nod if (!new_node->isa()) { MS_LOG(EXCEPTION) << "new_node must be a CNode, but is " << new_node->DebugString() << "."; } - auto c_node = node->cast(); - MS_EXCEPTION_IF_NULL(c_node); - auto inputs = c_node->inputs(); - std::vector new_inputs; - (void)std::transform( - inputs.begin(), inputs.end(), std::back_inserter(new_inputs), [this](const AnfNodePtr &inp) -> AnfNodePtr { - auto new_inp = ReplicateDisconnectedNode(inp); - // Refer the comments in BuildReplacedNode. - if (inp->isa()) { - auto c_inp = inp->cast(); - MS_EXCEPTION_IF_NULL(c_inp); - auto c_new_inp = new_inp->cast(); - MS_EXCEPTION_IF_NULL(c_new_inp); - MS_LOG(DEBUG) << "Replace in order, inp node: " << inp->DebugString() << " -> " << new_inp->DebugString(); - c_new_inp->func_graph()->ReplaceInOrder(c_inp, c_new_inp); - } - return new_inp; - }); - - auto c_new_node = new_node->cast(); - MS_EXCEPTION_IF_NULL(c_new_node); - c_new_node->set_inputs(new_inputs); + UpdateNewCNodeInputs(node, new_node); } iter = specializer->repl_node_->find(node); @@ -164,6 +143,31 @@ AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &nod return new_node; } +void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node) { + auto c_node = node->cast(); + MS_EXCEPTION_IF_NULL(c_node); + auto inputs = c_node->inputs(); + std::vector new_inputs; + (void)std::transform( + inputs.begin(), inputs.end(), std::back_inserter(new_inputs), [this](const AnfNodePtr &inp) -> AnfNodePtr { + auto new_inp = ReplicateDisconnectedNode(inp); + // Refer the comments in BuildReplacedNode. + if (inp->isa()) { + auto c_inp = inp->cast(); + MS_EXCEPTION_IF_NULL(c_inp); + auto c_new_inp = new_inp->cast(); + MS_EXCEPTION_IF_NULL(c_new_inp); + MS_LOG(DEBUG) << "Replace in order, inp node: " << inp->DebugString() << " -> " << new_inp->DebugString(); + c_new_inp->func_graph()->ReplaceInOrder(c_inp, c_new_inp); + } + return new_inp; + }); + + auto c_new_node = new_node->cast(); + MS_EXCEPTION_IF_NULL(c_new_node); + c_new_node->set_inputs(new_inputs); +} + AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); FuncGraphPtr fg = node->func_graph(); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h index 27e5cb14db..ba7ba08569 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h @@ -130,6 +130,7 @@ class FuncGraphSpecializer : public std::enable_shared_from_this BuildFromBroadedArgsVal(const EvaluatorPtr &eval); + void UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node); }; } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc index c4012b5fe8..06654aad6d 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc @@ -667,7 +667,7 @@ bool AscendDeviceAddress::DumpMemToFile(const std::string &filepath, const std:: std::string file_extension = ".bin"; if (trans_flag) { std::string path = - filepath + '_' + shape + '_' + TypeIdToType(type_id_)->ToString() + '_' + host_fmt + file_extension; + filepath + '_' + shape + '_' + TypeIdToType(host_type)->ToString() + '_' + host_fmt + file_extension; MS_LOG(INFO) << "E2E Dump path is " << path; mindspore::tensor::TensorPtr out_tensor = std::make_shared(host_type, host_shape); size_t host_size = out_tensor->data().nbytes();