|
|
|
@@ -29,6 +29,7 @@ |
|
|
|
#include "backend/kernel_compiler/kernel.h" |
|
|
|
#include "backend/session/session_factory.h" |
|
|
|
#include "backend/session/ascend_control_parser.h" |
|
|
|
#include "runtime/device/ascend/ascend_memory_pool.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace session { |
|
|
|
@@ -37,7 +38,7 @@ enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2, |
|
|
|
class AscendSession : public SessionBasic { |
|
|
|
public: |
|
|
|
AscendSession() { final_graph_id_ = kInvalidGraphId; } |
|
|
|
~AscendSession() override = default; |
|
|
|
~AscendSession() override { mindspore::device::ascend::AscendMemoryPool::GetInstance().ResetIdleMemBuf(); } |
|
|
|
void Init(uint32_t device_id) override { |
|
|
|
SessionBasic::Init(device_id); |
|
|
|
context_ = std::make_shared<Context>(kAscendDevice, device_id); |
|
|
|
|