Merge pull request !28983 from caifubi/master-pynative-add-bprop-flag
@@ -2725,6 +2725,8 @@ void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &g
   }
   // Get bprop graph of top cell
   auto bprop_graph = GetBpropGraph(grad, cell, w_args, p_args, size, args);
+  MS_EXCEPTION_IF_NULL(bprop_graph);
+  bprop_graph->set_is_bprop(true);
   resource->set_func_graph(bprop_graph);
   auto manager = resource->manager();
   MS_EXCEPTION_IF_NULL(manager);
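The two added lines null-check the bprop graph and tag it with set_is_bprop(true); the backend later reads this flag (see the MindRTBackend hunks below) to decide whether a segment keeps PyNative execution. A minimal sketch of the flag-tagging pattern, using an illustrative stand-in class rather than MindSpore's real FuncGraph:

    // Sketch only: FuncGraphSketch stands in for MindSpore's FuncGraph. The
    // pattern is a plain boolean set by the producer (GradNetInner) and read
    // later by the consumer (the backend's CompileGraph).
    #include <memory>

    class FuncGraphSketch {
     public:
      void set_is_bprop(bool flag) { is_bprop_ = flag; }
      bool is_bprop() const { return is_bprop_; }

     private:
      bool is_bprop_{false};  // ordinary forward graphs stay untagged
    };

    int main() {
      auto bprop_graph = std::make_shared<FuncGraphSketch>();
      bprop_graph->set_is_bprop(true);  // mirrors the addition in GradNetInner
      return bprop_graph->is_bprop() ? 0 : 1;
    }

The flag is the simplest channel between the two stages, since the producer and consumer live in different parts of the compilation pipeline and share only the graph object.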
@@ -448,12 +448,7 @@ bool AscendDeviceAddress::SyncDeviceToDeviceWithSameFormatType(const ShapeVector
     return false;
   }
   BindDevice();
-  auto ret_rt_memcpy = aclrtMemcpy(ptr_, size, src_ptr, size, ACL_MEMCPY_DEVICE_TO_DEVICE);
-  if (ret_rt_memcpy != RT_ERROR_NONE) {
-    MS_LOG(ERROR) << "SyncDeviceToDevice failed, rtMemcpy mem size [" << size << "], ret [" << ret_rt_memcpy << "]";
-    return false;
-  }
-  return true;
+  return AsyncDeviceToDevice(shape, size, type, src_ptr, format);
 }
 
 bool AscendDeviceAddress::SyncDeviceToDeviceWithDiffFormatType(const DeviceSync *src_device_addr) const {
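The blocking aclrtMemcpy is replaced by a delegation to AsyncDeviceToDevice, whose body is not part of this diff. A hedged sketch of what the asynchronous variant typically looks like with the CANN runtime (aclrtMemcpyAsync enqueued on a stream); the helper name and error handling below are illustrative, not MindSpore code:

    #include <cstddef>
    #include "acl/acl.h"

    // Illustrative helper: enqueue a device-to-device copy on a stream
    // instead of blocking the host thread until the copy completes.
    bool AsyncCopyDeviceToDevice(void *dst, const void *src, size_t size, aclrtStream stream) {
      auto ret = aclrtMemcpyAsync(dst, size, src, size, ACL_MEMCPY_DEVICE_TO_DEVICE, stream);
      if (ret != ACL_SUCCESS) {
        return false;
      }
      // Completion is deferred: a consumer must synchronize the stream (or
      // record/wait an event, as done elsewhere in this PR) before reading dst.
      return true;
    }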
@@ -345,7 +345,7 @@ void UpdateRefCountForGraphOutput(const std::vector<KernelWithIndex> &output_wit
 GraphCompilerInfo::~GraphCompilerInfo() { GraphScheduler::GetInstance().Clear(name_, graphs_); }
 
 GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs,
-                                    const DeviceContext *device_context) {
+                                    const DeviceContext *device_context, bool run_in_pynative) {
   MS_EXCEPTION_IF_NULL(session_);
   MS_EXCEPTION_IF_NULL(segment);
   MS_LOG(INFO) << "Status record: start compile graph.";
@@ -372,7 +372,17 @@ GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNod
   session_->SetInputNodeUsage(graph, manager);
   graph->SetOptimizerFlag();
-  auto graph_id = CompileGraphImpl(graph, device_context);
+  GraphId graph_id;
+  if (run_in_pynative) {
+    MS_EXCEPTION_IF_NULL(session_);
+    // GraphKernel does not support PyNative mode yet, so calling GetInstance here
+    // prints a warning log to remind users who enable GraphKernel in PyNative mode.
+    (void)graphkernel::GraphKernelFlags::GetInstance();
+    session_->InitAllBucket(graph, device_context);
+    graph_id = graph->graph_id();
+  } else {
+    graph_id = CompileGraphImpl(graph, device_context);
+  }
   session_->DumpGraphs({graph});
@@ -435,14 +445,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
   MS_EXCEPTION_IF_NULL(device_context);
   const auto &ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
-  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
-    // graphkernel not support pynative mode now, so when users open graphkernel
-    // in pynative mode should print a warning log to reminder users by using GetInstance func.
-    graphkernel::GraphKernelFlags::GetInstance();
-    MS_EXCEPTION_IF_NULL(session_);
-    session_->InitAllBucket(graph, device_context);
-    return graph->graph_id();
-  }
 #ifdef ENABLE_DUMP_IR
   bool save_graphs = ms_context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
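Taken together, these two hunks move the PyNative early-return out of CompileGraphImpl and behind the explicit run_in_pynative parameter, so the decision is made by the caller instead of a global execution-mode query buried in the implementation. A simplified sketch of the resulting control flow, with stand-in types for KernelGraph:

    #include <cstdint>

    using GraphId = uint32_t;
    struct GraphSketch { GraphId id{0}; };

    // Stand-in for the heavy whole-graph pipeline in CompileGraphImpl.
    GraphId CompileFullGraph(GraphSketch *g) { return g->id; }

    GraphId CompileGraphSketch(GraphSketch *g, bool run_in_pynative) {
      if (run_in_pynative) {
        // PyNative path: kernels launch per-op later; compile time only does
        // lightweight setup (bucket init, GraphKernel warning in the real code).
        return g->id;
      }
      // Graph-mode path: full backend compilation.
      return CompileFullGraph(g);
    }

    int main() {
      GraphSketch g{7};
      return CompileGraphSketch(&g, true) == 7 ? 0 : 1;
    }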
@@ -98,7 +98,7 @@ class GraphCompiler {
   // Construct kernel graph from anf nodes list and compile kernel graph in Graph mode,
   // the detailed implementation of compiling graph is in 'CompileGraphImpl'.
   GraphId CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs,
-                       const DeviceContext *device_context);
+                       const DeviceContext *device_context, bool run_in_pynative = false);
   // Construct kernel graph from function graph and compile kernel graph in Graph mode,
   // the detailed implementation of compiling graph is in 'CompileGraphImpl'.
@@ -31,6 +31,7 @@
 #include "runtime/hardware/ascend/ascend_graph_optimization.h"
 #include "backend/kernel_compiler/ascend_kernel_mod.h"
 #include "backend/kernel_compiler/aicpu/aicpu_kernel_load.h"
+#include "backend/kernel_compiler/tbe/tbe_kernel_compile.h"
 #include "runtime/device/ascend/ascend_bucket.h"
 #include "common/util/error_manager/error_manager.h"
 #include "runtime/device/ascend/ascend_memory_adapter.h"
@@ -263,6 +264,9 @@ void AscendDeviceContext::Initialize() {
   compute_stream_ = runtime_instance_->compute_stream();
   communication_stream_ = runtime_instance_->communication_stream();
+  // Initialize tbe using HCCL rank_id
+  kernel::ascend::TbeKernelCompileManager::GetInstance().TbeInitialize();
   initialized_ = true;
   MS_LOG(INFO) << "Status record: Initialize success.";
 }
@@ -279,6 +283,7 @@ void AscendDeviceContext::Destroy() {
     return;
   }
   MS_LOG(INFO) << "Status record: Destroy start...";
+  graph_event_.clear();
   rank_id_ = 0;
   if (runtime_instance_) {
     // TODO(lzlang): Destroy runtime instance after fully support MindRT, otherwise runtime will be destructed
@@ -550,6 +555,8 @@ bool AscendDeviceContext::ExecuteGraph(const KernelGraphPtr &graph) const {
   const uint64_t kUSecondInSecond = 1000000;
   bool ret = false;
   if (graph->is_executing_sink()) {
+    InsertEventBeforeRunTask(graph);
 #if defined(_WIN32) || defined(_WIN64)
     auto start_time = std::chrono::steady_clock::now();
 #else
@@ -870,6 +877,23 @@ bool AscendDeviceContext::LaunchAtomicClean(const CNodePtr &node, const std::vec
   return kernel_mod->Launch(atomic_inputs, {}, {}, GetKernelStream(atomic_node));
 }
 
+void AscendDeviceContext::InsertEventBeforeRunTask(const KernelGraphPtr &graph) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  if (!graph->is_executing_sink() || graph->is_dynamic_shape()) {
+    return;
+  }
+  MS_LOG(DEBUG) << "Insert event between PyNative and Graph";
+  MS_EXCEPTION_IF_NULL(runtime_instance_);
+  auto model_stream = runtime_instance_->GetModelStream(graph->graph_id());
+  auto compute_event = runtime_instance_->CreateDeviceEvent();
+  MS_EXCEPTION_IF_NULL(compute_event);
+  compute_event->set_wait_stream(model_stream);
+  compute_event->set_record_stream(compute_stream_);
+  compute_event->RecordEvent();
+  compute_event->WaitEvent();
+  graph_event_[graph->graph_id()] = compute_event;
+}
+
 MS_REGISTER_DEVICE(kAscendDevice, AscendDeviceContext);
 }  // namespace ascend
 }  // namespace device
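InsertEventBeforeRunTask bridges two streams: the event is recorded on the compute stream (where preceding PyNative kernels were launched) and waited on by the model stream that launches the sunk graph, so the graph cannot start before pending PyNative work finishes, without blocking the host. A sketch of the pattern with a simplified stand-in for the DeviceEvent interface:

    #include <cstdint>
    #include <map>
    #include <memory>

    struct StreamTag {};  // stand-in for an aclrtStream-like handle

    class DeviceEventSketch {
     public:
      void set_record_stream(StreamTag *s) { record_stream_ = s; }
      void set_wait_stream(StreamTag *s) { wait_stream_ = s; }
      void RecordEvent() { /* enqueue event-record on record_stream_ */ }
      void WaitEvent() { /* enqueue event-wait on wait_stream_ */ }

     private:
      StreamTag *record_stream_{nullptr};
      StreamTag *wait_stream_{nullptr};
    };

    int main() {
      StreamTag compute_stream, model_stream;
      auto event = std::make_shared<DeviceEventSketch>();
      event->set_record_stream(&compute_stream);  // PyNative kernels run here
      event->set_wait_stream(&model_stream);      // the sunk graph launches here
      event->RecordEvent();
      event->WaitEvent();
      // The PR's graph_event_ map keeps the shared_ptr alive per graph id so
      // the event outlives the enqueued wait; mirrored here with a plain map.
      std::map<uint32_t, std::shared_ptr<DeviceEventSketch>> graph_events;
      graph_events[0] = event;
      return 0;
    }

This also explains the graph_event_.clear() added to Destroy and the new graph_event_ member below: the events must live until the device context is torn down.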
@@ -141,6 +141,7 @@ class AscendDeviceContext : public DeviceContext {
   bool PySyncRuning() const;
   bool MemoryCopyAsync(const CNodePtr &node, const vector<AddressPtr> &inputs, const vector<AddressPtr> &outputs) const;
   void GenKernelEvents(const NotNull<KernelGraphPtr> &root_graph) const;
+  void InsertEventBeforeRunTask(const KernelGraphPtr &graph) const;
   void ReportErrorMessage() const;
   void ReportWarningMessage() const;
@@ -166,6 +167,8 @@ class AscendDeviceContext : public DeviceContext {
   // node_atomics_ will be cleaned up in CompileGraph.
   mutable std::map<CNodePtr, std::vector<CNodePtr>> node_atomics_persistent_cache_;
   mutable std::set<CNodePtr> nop_op_to_memcpy_;
+  // Event for multi-stream
+  mutable std::map<uint32_t, std::shared_ptr<DeviceEvent>> graph_event_;
   // Some NOP nodes have be hide in execution order, it doesn't have output device address, this function creates
   // output device address for these nodes, and the output device address is the same with input device address.
   void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph) const;
@@ -20,6 +20,7 @@
 #include <string>
 #include <vector>
 #include <memory>
+#include <map>
 #include "runtime/device/device_address.h"
 #include "runtime/device/bucket.h"
 #include "runtime/hardware/collective/collective_communication_lib.h"
@@ -517,12 +517,12 @@ bool MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) {
   // Foreach the segments to compile graph.
   for (const auto &segment : new_segments) {
-    CompileGraph(segment, contain_multi_target);
+    CompileGraph(segment, contain_multi_target, func_graph->is_bprop());
   }
   return true;
 }
 
-void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target) {
+void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target, bool run_in_pynative) {
   MS_EXCEPTION_IF_NULL(segment);
   // Compile the normal nodes, which doesn't contain the cut node.
   if (segment->nodes_.size() == 0) {
@@ -548,13 +548,14 @@ void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_mu
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   // There will be more than one kernel graph in heterogeneous scenario in a ms function of PyNative Mode.
-  if (contain_multi_target && ms_execution_mode_ == kPynativeMode) {
+  if ((contain_multi_target || !run_in_pynative) && ms_execution_mode_ == kPynativeMode) {
     real_execution_mode_ = kGraphMode;
     context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
+    MS_LOG(INFO) << "PyNative graph Compile and Run in GRAPH_MODE";
   }
 
   // Compile graph.
-  auto graph_id = graph_compiler_->CompileGraph(segment, outputs, device_context);
+  auto graph_id = graph_compiler_->CompileGraph(segment, outputs, device_context, run_in_pynative);
 
   if (ms_execution_mode_ != real_execution_mode_) {
     context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, ms_execution_mode_);
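The widened condition makes any segment that mixes device targets, or that is not flagged as a bprop graph, fall back to whole-graph compilation even under a PyNative front end; only flagged bprop segments keep per-op PyNative execution. The predicate, isolated as a sketch with compile-time checks of the interesting cases:

    // run_in_pynative comes from func_graph->is_bprop() at the call site above.
    constexpr bool FallBackToGraphMode(bool contain_multi_target, bool run_in_pynative,
                                       bool frontend_is_pynative) {
      return (contain_multi_target || !run_in_pynative) && frontend_is_pynative;
    }

    static_assert(FallBackToGraphMode(false, false, true),
                  "plain forward segment under PyNative falls back to graph mode");
    static_assert(!FallBackToGraphMode(false, true, true),
                  "flagged bprop segment keeps PyNative execution");
    static_assert(FallBackToGraphMode(true, true, true),
                  "multi-target segment falls back even when flagged");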
@@ -905,6 +906,8 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
   const auto &graph_compiler_info = *(graph_iter->second);
   const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;
+  SyncLazyTasks();
+  // Transform args to input tensors.
   // Input tensors of the graph.
   std::vector<std::vector<tensor::TensorPtr>> input_tensors;
@@ -134,7 +134,7 @@ class MindRTBackend : public Backend {
   bool CompileGraph(const FuncGraphPtr &func_graph);
 
   // Compile the kernel graph by the segment which is from the function graph partition.
-  void CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target);
+  void CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target, bool run_in_pynative);
 
   // CreateKernel, Transform and Schedule have not been finished when LazyBuild is enabled in PyNative mode.
   void CompileSingleOpGraph(const KernelGraphPtr &graph, const DeviceContext *device_context,
@@ -261,3 +261,31 @@ def test_pynative_ms_function():
     out_b = grad(net_b, params_b)(input_data)
     assert np.allclose(out_a[0][0].asnumpy(), out_b[0][0].asnumpy(), 0.0001, 0.0001)
     assert np.allclose(out_a[1][0].asnumpy(), out_b[1][0].asnumpy(), 0.0001, 0.0001)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_pynative_ms_function_mix_execute():
+    """
+    Feature: PyNative ms_function.
+    Description: Mixed execution of PyNative and ms_function.
+    Expectation: The calculation result is correct.
+    """
+    class Net(nn.Cell):
+        @ms_function
+        def test_ms_function(self, x, y):
+            return x * y
+
+        def construct(self, x, y):
+            z = x * y
+            return self.test_ms_function(z, x)
+
+    net = Net()
+    a = Tensor(2)
+    b = Tensor(2)
+    output = net(a, b)
+    assert output == 8
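For the inputs above, construct first computes z = a * b = 4 eagerly in PyNative, then test_ms_function(z, a) = 4 * 2 = 8 runs through the compiled-graph path, so the assertion exercises one round trip across the PyNative/graph boundary that this PR synchronizes.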