|
|
|
@@ -31,7 +31,6 @@ |
|
|
|
#include "backend/kernel_compiler/kernel.h" |
|
|
|
#include "backend/session/session_factory.h" |
|
|
|
#include "backend/session/ascend_control_parser.h" |
|
|
|
#include "runtime/context.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace session { |
|
|
|
@@ -40,28 +39,8 @@ enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2, |
|
|
|
class AscendSession : public SessionBasic { |
|
|
|
public: |
|
|
|
AscendSession() { final_graph_id_ = kInvalidGraphId; } |
|
|
|
~AscendSession() { |
|
|
|
if (rt_context_ != nullptr) { |
|
|
|
auto ret = rtCtxDestroy(rt_context_); |
|
|
|
if (ret != RT_ERROR_NONE) { |
|
|
|
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxDestroy, ret[" << ret << "]"; |
|
|
|
} |
|
|
|
rt_context_ = nullptr; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void Init(uint32_t device_id) override { |
|
|
|
InitDevice(kAscendDevice, device_id); |
|
|
|
auto ret = rtCtxCreate(&rt_context_, 0, device_id); |
|
|
|
if (ret != RT_ERROR_NONE) { |
|
|
|
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]"; |
|
|
|
} |
|
|
|
ret = rtCtxSetCurrent(rt_context_); |
|
|
|
if (ret != RT_ERROR_NONE) { |
|
|
|
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]"; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
~AscendSession() = default; |
|
|
|
void Init(uint32_t device_id) override; |
|
|
|
// get graph id of final graph |
|
|
|
GraphId GetFinalRunGraph() const override { return final_graph_id_; } |
|
|
|
|
|
|
|
@@ -136,8 +115,6 @@ class AscendSession : public SessionBasic { |
|
|
|
std::map<std::pair<GraphId, size_t>, tensor::TensorPtr> initial_tenosrs_; |
|
|
|
// final_graph_id is used in every root graph has it's own session situation |
|
|
|
GraphId final_graph_id_; |
|
|
|
// ascend runtime context |
|
|
|
rtContext_t rt_context_{nullptr}; |
|
|
|
}; |
|
|
|
MS_REG_SESSION(kAscendDevice, AscendSession); |
|
|
|
} // namespace session |
|
|
|
|