| @@ -228,7 +228,7 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *grap | |||
| continue; | |||
| } | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); | |||
| if (device_address->ptr_) { | |||
| mem_manager_->FreeMemFromMemPool(device_address); | |||
| } | |||
| @@ -289,7 +289,7 @@ bool GPUKernelRuntime::AddMemSwapTask(const AnfNodePtr &kernel) { | |||
| for (auto &mem_swap_info : mem_swap_info_list) { | |||
| auto &kernel_exec_info = mem_swap_manager_->SearchKernelExecutionInfo(mem_swap_info.kernel_); | |||
| const HostAddress &host_address = kernel_exec_info.host_addrs_[mem_swap_info.output_idx_]; | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_); | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_, false); | |||
| if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) { | |||
| mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address); | |||
| @@ -379,7 +379,8 @@ bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &k | |||
| MS_EXCEPTION_IF_NULL(kernel_inputs); | |||
| MS_EXCEPTION_IF_NULL(mem_swap_manager_); | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { | |||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); | |||
| // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. | |||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| if (mem_swap_manager_->trigger_swap()) { | |||
| while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) { | |||
| @@ -437,7 +438,7 @@ bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::Kern | |||
| } | |||
| auto output_sizes = kernel_mod.GetOutputSizeList(); | |||
| for (size_t i = 0; i < output_sizes.size(); ++i) { | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i])) { | |||
| return false; | |||
| @@ -495,7 +496,7 @@ void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfN | |||
| std::vector<size_t> size_list; | |||
| DeviceAddressPtrList addr_list; | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { | |||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); | |||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| if (device_address->ptr_ == nullptr) { | |||
| is_need_alloc_memory = true; | |||
| @@ -520,7 +521,7 @@ void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::Anf | |||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||
| auto output_sizes = kernel_mod->GetOutputSizeList(); | |||
| for (size_t i = 0; i < output_sizes.size(); ++i) { | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| if (device_address->ptr_ == nullptr) { | |||
| is_need_alloc_memory = true; | |||
| @@ -578,7 +579,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, | |||
| MS_LOG(EXCEPTION) << "Check dynamic reference count failed."; | |||
| } | |||
| if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { | |||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); | |||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); | |||
| mem_manager_->FreeMemFromMemPool(device_address); | |||
| device_address->set_status(DeviceAddressStatus::kInDevice); | |||
| } | |||
| @@ -590,7 +591,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, | |||
| continue; | |||
| } | |||
| if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); | |||
| mem_manager_->FreeMemFromMemPool(device_address); | |||
| device_address->set_status(DeviceAddressStatus::kInDevice); | |||
| } | |||
| @@ -228,7 +228,8 @@ KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t | |||
| << AnfAlgo::GetInputTensorNum(kernel); | |||
| } | |||
| auto input_node = kernel->input(input_idx + 1); | |||
| auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); | |||
| // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. | |||
| auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); | |||
| if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) { | |||
| MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple"; | |||
| } | |||
| @@ -269,7 +270,8 @@ void MemReuseUtil::SetKernelDefInputs() { | |||
| if (ref_ptr != nullptr) { | |||
| // set the inputs of this kernel_def | |||
| auto input_node = AnfAlgo::GetInputNode(kernel, i); | |||
| auto input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); | |||
| // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. | |||
| auto input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); | |||
| if (IsPrimitive(input.first, prim::kPrimMakeTuple)) { | |||
| MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple"; | |||
| } | |||
| @@ -544,9 +544,10 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &an | |||
| } | |||
| // get output device addr of anf_node | |||
| const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx) { | |||
| const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx, | |||
| bool visit_nop_node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (opt::IsNopNode(node)) { | |||
| if (opt::IsNopNode(node) && visit_nop_node) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->inputs().size() == 2) { | |||
| @@ -565,9 +566,10 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, | |||
| return addr; | |||
| } | |||
| DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx) { | |||
| DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, | |||
| bool visit_nop_node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (opt::IsNopNode(node)) { | |||
| if (opt::IsNopNode(node) && visit_nop_node) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->inputs().size() == 2) { | |||
| @@ -598,14 +600,16 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_ | |||
| return kernel_info->OutputAddrExist(output_idx); | |||
| } | |||
| const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) { | |||
| const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, | |||
| bool visit_nop_node) { | |||
| KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); | |||
| return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second); | |||
| return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node); | |||
| } | |||
| DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) { | |||
| DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, | |||
| bool visit_nop_node) { | |||
| KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); | |||
| return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second); | |||
| return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node); | |||
| } | |||
| // set output device addr of anf_node | |||
| @@ -121,14 +121,16 @@ class AnfRuntimeAlgorithm { | |||
| // get output select data type from prev node,input_index is the input index of current node related to prev node | |||
| static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx); | |||
| // get output device addr of anf_node | |||
| static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx); | |||
| static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); | |||
| // get mutable output device addr of anf_node | |||
| static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx); | |||
| static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); | |||
| // check whether output addr is exist or not | |||
| static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx); | |||
| // get address from prev node,input_index is the input index of current node related to prev node | |||
| static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx); | |||
| static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx); | |||
| static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx, | |||
| bool visit_nop_node = true); | |||
| static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, | |||
| bool visit_nop_node = true); | |||
| // set output device addr of anf_node | |||
| static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); | |||
| // set workspace device addr of anf_node | |||
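Because visit_nop_node defaults to true in all four declarations, every existing call site keeps its old behavior without source changes; only the GPU runtime and memory-reuse paths above opt out explicitly. For example:

    auto forwarded = AnfAlgo::GetOutputAddr(node, 0);        // old behavior, forwards past nop nodes
    auto own_slot = AnfAlgo::GetOutputAddr(node, 0, false);  // new opt-out, stops at the node itself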
| @@ -31,6 +31,49 @@ class NetFlatten(nn.Cell): | |||
| return self.flatten(x) | |||
| class NetAllFlatten(nn.Cell): | |||
| def __init__(self): | |||
| super(NetAllFlatten, self).__init__() | |||
| self.flatten = P.Flatten() | |||
| def construct(self, x): | |||
| loop_count = 4 | |||
| while loop_count > 0: | |||
| x = self.flatten(x) | |||
| loop_count = loop_count - 1 | |||
| return x | |||
| class NetFirstFlatten(nn.Cell): | |||
| def __init__(self): | |||
| super(NetFirstFlatten, self).__init__() | |||
| self.flatten = P.Flatten() | |||
| self.relu = P.ReLU() | |||
| def construct(self, x): | |||
| loop_count = 4 | |||
| while loop_count > 0: | |||
| x = self.flatten(x) | |||
| loop_count = loop_count - 1 | |||
| x = self.relu(x) | |||
| return x | |||
| class NetLastFlatten(nn.Cell): | |||
| def __init__(self): | |||
| super(NetLastFlatten, self).__init__() | |||
| self.flatten = P.Flatten() | |||
| self.relu = P.ReLU() | |||
| def construct(self, x): | |||
| loop_count = 4 | |||
| x = self.relu(x) | |||
| while loop_count > 0: | |||
| x = self.flatten(x) | |||
| loop_count = loop_count - 1 | |||
| return x | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| @@ -46,3 +89,55 @@ def test_flatten(): | |||
| flatten = NetFlatten() | |||
| output = flatten(x) | |||
| assert (output.asnumpy() == expect).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_all_flatten(): | |||
| x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32)) | |||
| expect = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| flatten = NetAllFlatten() | |||
| output = flatten(x) | |||
| assert (output.asnumpy() == expect).all() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| flatten = NetAllFlatten() | |||
| output = flatten(x) | |||
| assert (output.asnumpy() == expect).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_first_flatten(): | |||
| x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32)) | |||
| expect = np.array([[0, 0.3, 3.6], [0.4, 0.5, 0]]).astype(np.float32) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| flatten = NetFirstFlatten() | |||
| output = flatten(x) | |||
| assert (output.asnumpy() == expect).all() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| flatten = NetFirstFlatten() | |||
| output = flatten(x) | |||
| assert (output.asnumpy() == expect).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_last_flatten(): | |||
| x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32)) | |||
| expect = np.array([[0, 0.3, 3.6], [0.4, 0.5, 0]]).astype(np.float32) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| flatten = NetLastFlatten() | |||
| output = flatten(x) | |||
| assert (output.asnumpy() == expect).all() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| flatten = NetLastFlatten() | |||
| output = flatten(x) | |||
| assert (output.asnumpy() == expect).all() | |||
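The new tests target the failure mode named in the comments above: Flatten is presumably lowered to a nop kernel on GPU, so NetAllFlatten builds a graph consisting entirely of nop nodes, while NetFirstFlatten and NetLastFlatten place a real kernel (ReLU) after and before the nop chain. Each case runs in both PyNative and graph mode to cover both executors.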