Browse Source

fix-bug-of-repeated-graph-id-when-run-dynamic-op-in-pynative

tags/v1.1.0
lvliang 5 years ago
parent
commit
9b8f0e3b5e
3 changed files with 6 additions and 5 deletions
  1. +3
    -1
      mindspore/ccsrc/backend/session/ascend_session.cc
  2. +2
    -2
      mindspore/ccsrc/backend/session/session_basic.cc
  3. +1
    -2
      mindspore/ccsrc/backend/session/session_basic.h

+ 3
- 1
mindspore/ccsrc/backend/session/ascend_session.cc View File

@@ -402,7 +402,9 @@ void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &gra
// malloc mem // malloc mem
RunOpMemoryAlloc(op_run_info.value, input_tensors, graph.get()); RunOpMemoryAlloc(op_run_info.value, input_tensors, graph.get());
// Build dynamic kernel // Build dynamic kernel
BuildDynamicKernel(graph);
if (op_run_info.is_dynamic_shape) {
BuildDynamicKernel(graph);
}
// load input data to device // load input data to device
LoadInputData(graph, input_tensors); LoadInputData(graph, input_tensors);
// run op // run op


+ 2
- 2
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -1322,8 +1322,8 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) { const std::vector<int64_t> &tensors_mask) {
auto graph = std::make_shared<KernelGraph>(); auto graph = std::make_shared<KernelGraph>();
graph->set_graph_id(run_op_graph_id_);
run_op_graph_id_++;
graph->set_graph_id(graph_sum_);
graph_sum_++;
std::vector<AnfNodePtr> inputs; std::vector<AnfNodePtr> inputs;
// set input[0] // set input[0]
PrimitivePtr op_prim = op_run_info.primitive; PrimitivePtr op_prim = op_run_info.primitive;


+ 1
- 2
mindspore/ccsrc/backend/session/session_basic.h View File

@@ -56,7 +56,7 @@ using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
class Executor; class Executor;
class SessionBasic : public std::enable_shared_from_this<SessionBasic> { class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
public: public:
SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0), run_op_graph_id_(0) {
SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {
#if !defined(_WIN32) && !defined(_WIN64) #if !defined(_WIN32) && !defined(_WIN64)
debugger_ = nullptr; debugger_ = nullptr;
#endif #endif
@@ -184,7 +184,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
CallBackFunc summary_callback_; CallBackFunc summary_callback_;
static GraphId graph_sum_; static GraphId graph_sum_;
uint32_t device_id_; uint32_t device_id_;
uint32_t run_op_graph_id_;
std::shared_ptr<Executor> executor_; std::shared_ptr<Executor> executor_;
#if !defined(_WIN32) && !defined(_WIN64) #if !defined(_WIN32) && !defined(_WIN64)
std::shared_ptr<Debugger> debugger_; std::shared_ptr<Debugger> debugger_;


Loading…
Cancel
Save