Browse Source

fix reshape output and clearres error

tags/v0.2.0-alpha
kswang 5 years ago
parent
commit
b8a7e73f7d
4 changed files with 28 additions and 65 deletions
  1. +4
    -2
      mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
  2. +3
    -2
      mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc
  3. +19
    -61
      mindspore/ccsrc/session/anf_runtime_algorithm.cc
  4. +2
    -0
      mindspore/ccsrc/session/anf_runtime_algorithm.h

+ 4
- 2
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc View File

@@ -85,8 +85,10 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
MS_EXCEPTION(DeviceProcessError) << "rtSetDevice, ret[" << static_cast<int>(ret) << "]"; MS_EXCEPTION(DeviceProcessError) << "rtSetDevice, ret[" << static_cast<int>(ret) << "]";
} }


MS_EXCEPTION_IF_NULL(mem_manager_);
mem_manager_->FreeDeviceMemory();
if (mem_manager_ != nullptr) {
mem_manager_->FreeDeviceMemory();
}

(void)DestroyHccl(); (void)DestroyHccl();
(void)ResetDevice(); (void)ResetDevice();
(void)ProfilingManager::GetInstance().StopProfiling(); (void)ProfilingManager::GetInstance().StopProfiling();


+ 3
- 2
mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc View File

@@ -101,8 +101,9 @@ void GPUKernelRuntime::ReleaseDeviceRes() {
CHECK_OP_RET_WITH_EXCEPT(GpuBufferMgr::GetInstance().Destroy(), "Could not destroy gpu data queue."); CHECK_OP_RET_WITH_EXCEPT(GpuBufferMgr::GetInstance().Destroy(), "Could not destroy gpu data queue.");
} }
GPUDeviceManager::GetInstance().ReleaseDevice(); GPUDeviceManager::GetInstance().ReleaseDevice();
MS_EXCEPTION_IF_NULL(mem_manager_);
mem_manager_->FreeDeviceMemory();
if (mem_manager_ != nullptr) {
mem_manager_->FreeDeviceMemory();
}
} }
void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {


+ 19
- 61
mindspore/ccsrc/session/anf_runtime_algorithm.cc View File

@@ -112,6 +112,12 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr
return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx)); return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx));
} else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0); return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0);
} else if (opt::IsNopNode(cnode)) {
if (cnode->inputs().size() == 2) {
return VisitKernelWithReturnType(cnode->input(1), 0);
} else {
MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node";
}
} else { } else {
return std::make_pair(anf_node, index); return std::make_pair(anf_node, index);
} }
@@ -299,20 +305,23 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
return build_info->GetInputFormat(input_idx); return build_info->GetInputFormat(input_idx);
} }


std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
if (!anf_node->isa<CNode>()) { if (!anf_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "anf_node is not CNode.";
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
} }
auto cnode = anf_node->cast<CNodePtr>(); auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
if (input_idx + 1 >= cnode->inputs().size()) { if (input_idx + 1 >= cnode->inputs().size()) {
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode)
<< ".";
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
} }
auto node = cnode->input(input_idx + 1); auto node = cnode->input(input_idx + 1);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
KernelWithIndex kernel_with_index = VisitKernel(node, 0);
return VisitKernel(node, 0);
}

std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
} }


@@ -346,18 +355,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n
} }


std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) { std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "anf_node is not CNode.";
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (input_idx + 1 >= cnode->inputs().size()) {
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode)
<< ".";
}
auto input_node = cnode->input(input_idx + 1);
KernelWithIndex kernel_with_index = VisitKernel(input_node, 0);
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second); return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
} }


@@ -459,17 +457,7 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_
} }


TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) { TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
MS_LOG(EXCEPTION) << node->DebugString() << "is not a CNode";
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (input_idx + 1 >= cnode->inputs().size()) {
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
}
auto input_node = cnode->input(input_idx + 1);
KernelWithIndex kernel_with_index = VisitKernel(input_node, 0);
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second); return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
} }


@@ -492,17 +480,7 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
} }


TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) { TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
if (!anf_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (input_idx + 1 >= cnode->inputs().size()) {
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
}
auto node = cnode->input(input_idx + 1);
MS_EXCEPTION_IF_NULL(node);
KernelWithIndex kernel_with_index = VisitKernel(node, 0);
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second); return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
} }


@@ -558,32 +536,12 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_
} }


const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) { const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
if (!anf_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf node is not a CNode";
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (input_idx + 1 >= cnode->inputs().size()) {
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
}
auto node = cnode->input(input_idx + 1);
MS_EXCEPTION_IF_NULL(node);
KernelWithIndex kernel_with_index = VisitKernel(node, 0);
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second); return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second);
} }


DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) { DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
if (!anf_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (input_idx + 1 >= cnode->inputs().size()) {
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
}
auto node = cnode->input(input_idx + 1);
MS_EXCEPTION_IF_NULL(node);
KernelWithIndex kernel_with_index = VisitKernel(node, 0);
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second); return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second);
} }




+ 2
- 0
mindspore/ccsrc/session/anf_runtime_algorithm.h View File

@@ -89,6 +89,8 @@ class AnfRuntimeAlgorithm {
static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx); static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx);
// get input format select of anf node // get input format select of anf node
static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx); static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx);
// get prev node output width output index
static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx);
// get output format from prev node,input_index is the input index of current node related to prev node // get output format from prev node,input_index is the input index of current node related to prev node
static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx); static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
// get output shapes inferred by ME from input nodes. // get output shapes inferred by ME from input nodes.


Loading…
Cancel
Save