| @@ -228,7 +228,7 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *grap | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); | |||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); | |||||
| if (device_address->ptr_) { | if (device_address->ptr_) { | ||||
| mem_manager_->FreeMemFromMemPool(device_address); | mem_manager_->FreeMemFromMemPool(device_address); | ||||
| } | } | ||||
| @@ -289,7 +289,7 @@ bool GPUKernelRuntime::AddMemSwapTask(const AnfNodePtr &kernel) { | |||||
| for (auto &mem_swap_info : mem_swap_info_list) { | for (auto &mem_swap_info : mem_swap_info_list) { | ||||
| auto &kernel_exec_info = mem_swap_manager_->SearchKernelExecutionInfo(mem_swap_info.kernel_); | auto &kernel_exec_info = mem_swap_manager_->SearchKernelExecutionInfo(mem_swap_info.kernel_); | ||||
| const HostAddress &host_address = kernel_exec_info.host_addrs_[mem_swap_info.output_idx_]; | const HostAddress &host_address = kernel_exec_info.host_addrs_[mem_swap_info.output_idx_]; | ||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_); | |||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_, false); | |||||
| if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) { | if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) { | ||||
| mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address); | mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address); | ||||
| @@ -379,7 +379,8 @@ bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &k | |||||
| MS_EXCEPTION_IF_NULL(kernel_inputs); | MS_EXCEPTION_IF_NULL(kernel_inputs); | ||||
| MS_EXCEPTION_IF_NULL(mem_swap_manager_); | MS_EXCEPTION_IF_NULL(mem_swap_manager_); | ||||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { | for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { | ||||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); | |||||
| // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. | |||||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); | |||||
| MS_EXCEPTION_IF_NULL(device_address); | MS_EXCEPTION_IF_NULL(device_address); | ||||
| if (mem_swap_manager_->trigger_swap()) { | if (mem_swap_manager_->trigger_swap()) { | ||||
| while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) { | while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) { | ||||
| @@ -437,7 +438,7 @@ bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::Kern | |||||
| } | } | ||||
| auto output_sizes = kernel_mod.GetOutputSizeList(); | auto output_sizes = kernel_mod.GetOutputSizeList(); | ||||
| for (size_t i = 0; i < output_sizes.size(); ++i) { | for (size_t i = 0; i < output_sizes.size(); ++i) { | ||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); | |||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); | |||||
| MS_EXCEPTION_IF_NULL(device_address); | MS_EXCEPTION_IF_NULL(device_address); | ||||
| if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i])) { | if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i])) { | ||||
| return false; | return false; | ||||
| @@ -495,7 +496,7 @@ void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfN | |||||
| std::vector<size_t> size_list; | std::vector<size_t> size_list; | ||||
| DeviceAddressPtrList addr_list; | DeviceAddressPtrList addr_list; | ||||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { | for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { | ||||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); | |||||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); | |||||
| MS_EXCEPTION_IF_NULL(device_address); | MS_EXCEPTION_IF_NULL(device_address); | ||||
| if (device_address->ptr_ == nullptr) { | if (device_address->ptr_ == nullptr) { | ||||
| is_need_alloc_memory = true; | is_need_alloc_memory = true; | ||||
| @@ -520,7 +521,7 @@ void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::Anf | |||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | MS_EXCEPTION_IF_NULL(kernel_mod); | ||||
| auto output_sizes = kernel_mod->GetOutputSizeList(); | auto output_sizes = kernel_mod->GetOutputSizeList(); | ||||
| for (size_t i = 0; i < output_sizes.size(); ++i) { | for (size_t i = 0; i < output_sizes.size(); ++i) { | ||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); | |||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); | |||||
| MS_EXCEPTION_IF_NULL(device_address); | MS_EXCEPTION_IF_NULL(device_address); | ||||
| if (device_address->ptr_ == nullptr) { | if (device_address->ptr_ == nullptr) { | ||||
| is_need_alloc_memory = true; | is_need_alloc_memory = true; | ||||
| @@ -578,7 +579,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, | |||||
| MS_LOG(EXCEPTION) << "Check dynamic reference count failed."; | MS_LOG(EXCEPTION) << "Check dynamic reference count failed."; | ||||
| } | } | ||||
| if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { | if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { | ||||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); | |||||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); | |||||
| mem_manager_->FreeMemFromMemPool(device_address); | mem_manager_->FreeMemFromMemPool(device_address); | ||||
| device_address->set_status(DeviceAddressStatus::kInDevice); | device_address->set_status(DeviceAddressStatus::kInDevice); | ||||
| } | } | ||||
| @@ -590,7 +591,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { | if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { | ||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); | |||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); | |||||
| mem_manager_->FreeMemFromMemPool(device_address); | mem_manager_->FreeMemFromMemPool(device_address); | ||||
| device_address->set_status(DeviceAddressStatus::kInDevice); | device_address->set_status(DeviceAddressStatus::kInDevice); | ||||
| } | } | ||||
| @@ -228,7 +228,8 @@ KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t | |||||
| << AnfAlgo::GetInputTensorNum(kernel); | << AnfAlgo::GetInputTensorNum(kernel); | ||||
| } | } | ||||
| auto input_node = kernel->input(input_idx + 1); | auto input_node = kernel->input(input_idx + 1); | ||||
| auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); | |||||
| // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. | |||||
| auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); | |||||
| if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) { | if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) { | ||||
| MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple"; | MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple"; | ||||
| } | } | ||||
| @@ -269,7 +270,8 @@ void MemReuseUtil::SetKernelDefInputs() { | |||||
| if (ref_ptr != nullptr) { | if (ref_ptr != nullptr) { | ||||
| // set the inputs of this kernel_def | // set the inputs of this kernel_def | ||||
| auto input_node = AnfAlgo::GetInputNode(kernel, i); | auto input_node = AnfAlgo::GetInputNode(kernel, i); | ||||
| auto input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); | |||||
| // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. | |||||
| auto input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); | |||||
| if (IsPrimitive(input.first, prim::kPrimMakeTuple)) { | if (IsPrimitive(input.first, prim::kPrimMakeTuple)) { | ||||
| MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple"; | MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple"; | ||||
| } | } | ||||
| @@ -544,9 +544,10 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &an | |||||
| } | } | ||||
| // get output device addr of anf_node | // get output device addr of anf_node | ||||
| const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx) { | |||||
| const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx, | |||||
| bool visit_nop_node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (opt::IsNopNode(node)) { | |||||
| if (opt::IsNopNode(node) && visit_nop_node) { | |||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| if (cnode->inputs().size() == 2) { | if (cnode->inputs().size() == 2) { | ||||
| @@ -565,9 +566,10 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, | |||||
| return addr; | return addr; | ||||
| } | } | ||||
| DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx) { | |||||
| DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, | |||||
| bool visit_nop_node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (opt::IsNopNode(node)) { | |||||
| if (opt::IsNopNode(node) && visit_nop_node) { | |||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| if (cnode->inputs().size() == 2) { | if (cnode->inputs().size() == 2) { | ||||
| @@ -598,14 +600,16 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_ | |||||
| return kernel_info->OutputAddrExist(output_idx); | return kernel_info->OutputAddrExist(output_idx); | ||||
| } | } | ||||
| const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) { | |||||
| const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, | |||||
| bool visit_nop_node) { | |||||
| KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); | KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); | ||||
| return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second); | |||||
| return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node); | |||||
| } | } | ||||
| DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) { | |||||
| DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, | |||||
| bool visit_nop_node) { | |||||
| KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); | KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); | ||||
| return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second); | |||||
| return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node); | |||||
| } | } | ||||
| // set output device addr of anf_node | // set output device addr of anf_node | ||||
| @@ -121,14 +121,16 @@ class AnfRuntimeAlgorithm { | |||||
| // get output select data type from prev node,input_index is the input index of current node related to prev node | // get output select data type from prev node,input_index is the input index of current node related to prev node | ||||
| static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx); | static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx); | ||||
| // get output device addr of anf_node | // get output device addr of anf_node | ||||
| static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx); | |||||
| static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); | |||||
| // get mutable output device addr of anf_node | // get mutable output device addr of anf_node | ||||
| static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx); | |||||
| static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); | |||||
| // check whether output addr is exist or not | // check whether output addr is exist or not | ||||
| static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx); | static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx); | ||||
| // get address from prev node,input_index is the input index of current node related to prev node | // get address from prev node,input_index is the input index of current node related to prev node | ||||
| static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx); | |||||
| static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx); | |||||
| static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx, | |||||
| bool visit_nop_node = true); | |||||
| static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, | |||||
| bool visit_nop_node = true); | |||||
| // set output device addr of anf_node | // set output device addr of anf_node | ||||
| static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); | static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); | ||||
| // set workspace device addr of anf_node | // set workspace device addr of anf_node | ||||
| @@ -31,6 +31,49 @@ class NetFlatten(nn.Cell): | |||||
| return self.flatten(x) | return self.flatten(x) | ||||
| class NetAllFlatten(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetAllFlatten, self).__init__() | |||||
| self.flatten = P.Flatten() | |||||
| def construct(self, x): | |||||
| loop_count = 4 | |||||
| while loop_count > 0: | |||||
| x = self.flatten(x) | |||||
| loop_count = loop_count - 1 | |||||
| return x | |||||
| class NetFirstFlatten(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetFirstFlatten, self).__init__() | |||||
| self.flatten = P.Flatten() | |||||
| self.relu = P.ReLU() | |||||
| def construct(self, x): | |||||
| loop_count = 4 | |||||
| while loop_count > 0: | |||||
| x = self.flatten(x) | |||||
| loop_count = loop_count - 1 | |||||
| x = self.relu(x) | |||||
| return x | |||||
| class NetLastFlatten(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetLastFlatten, self).__init__() | |||||
| self.flatten = P.Flatten() | |||||
| self.relu = P.ReLU() | |||||
| def construct(self, x): | |||||
| loop_count = 4 | |||||
| x = self.relu(x) | |||||
| while loop_count > 0: | |||||
| x = self.flatten(x) | |||||
| loop_count = loop_count - 1 | |||||
| return x | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_gpu_training | @pytest.mark.platform_x86_gpu_training | ||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
| @@ -46,3 +89,55 @@ def test_flatten(): | |||||
| flatten = NetFlatten() | flatten = NetFlatten() | ||||
| output = flatten(x) | output = flatten(x) | ||||
| assert (output.asnumpy() == expect).all() | assert (output.asnumpy() == expect).all() | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_all_flatten(): | |||||
| x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32)) | |||||
| expect = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32) | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||||
| flatten = NetAllFlatten() | |||||
| output = flatten(x) | |||||
| assert (output.asnumpy() == expect).all() | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| flatten = NetAllFlatten() | |||||
| output = flatten(x) | |||||
| assert (output.asnumpy() == expect).all() | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_first_flatten(): | |||||
| x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32)) | |||||
| expect = np.array([[0, 0.3, 3.6], [0.4, 0.5, 0]]).astype(np.float32) | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||||
| flatten = NetFirstFlatten() | |||||
| output = flatten(x) | |||||
| assert (output.asnumpy() == expect).all() | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| flatten = NetFirstFlatten() | |||||
| output = flatten(x) | |||||
| assert (output.asnumpy() == expect).all() | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_last_flatten(): | |||||
| x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32)) | |||||
| expect = np.array([[0, 0.3, 3.6], [0.4, 0.5, 0]]).astype(np.float32) | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||||
| flatten = NetLastFlatten() | |||||
| output = flatten(x) | |||||
| assert (output.asnumpy() == expect).all() | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| flatten = NetLastFlatten() | |||||
| output = flatten(x) | |||||
| assert (output.asnumpy() == expect).all() | |||||