Browse Source

fix-bug-of-repeated-graph-id-when-run-dynamic-op-in-pynative

tags/v1.1.0
lvliang 5 years ago
parent
commit
9b8f0e3b5e
3 changed files with 6 additions and 5 deletions
  1. +3
    -1
      mindspore/ccsrc/backend/session/ascend_session.cc
  2. +2
    -2
      mindspore/ccsrc/backend/session/session_basic.cc
  3. +1
    -2
      mindspore/ccsrc/backend/session/session_basic.h

+ 3
- 1
mindspore/ccsrc/backend/session/ascend_session.cc View File

@@ -402,7 +402,9 @@ void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &gra
// malloc mem
RunOpMemoryAlloc(op_run_info.value, input_tensors, graph.get());
// Build dynamic kernel
BuildDynamicKernel(graph);
if (op_run_info.is_dynamic_shape) {
BuildDynamicKernel(graph);
}
// load input data to device
LoadInputData(graph, input_tensors);
// run op


+ 2
- 2
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -1322,8 +1322,8 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) {
auto graph = std::make_shared<KernelGraph>();
graph->set_graph_id(run_op_graph_id_);
run_op_graph_id_++;
graph->set_graph_id(graph_sum_);
graph_sum_++;
std::vector<AnfNodePtr> inputs;
// set input[0]
PrimitivePtr op_prim = op_run_info.primitive;


+ 1
- 2
mindspore/ccsrc/backend/session/session_basic.h View File

@@ -56,7 +56,7 @@ using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
class Executor;
class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
public:
SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0), run_op_graph_id_(0) {
SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {
#if !defined(_WIN32) && !defined(_WIN64)
debugger_ = nullptr;
#endif
@@ -184,7 +184,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
CallBackFunc summary_callback_;
static GraphId graph_sum_;
uint32_t device_id_;
uint32_t run_op_graph_id_;
std::shared_ptr<Executor> executor_;
#if !defined(_WIN32) && !defined(_WIN64)
std::shared_ptr<Debugger> debugger_;


Loading…
Cancel
Save