From: @chujinjin Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -1383,5 +1383,14 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const | |||
| RunInfer(func_graph, inputs); | |||
| return CompileGraphImpl(func_graph); | |||
| } | |||
| // Looks up the kernel runtime registered for kAscendDevice / device_id_ and asks it to | |||
| // synchronize its stream. A null runtime raises through MS_EXCEPTION_IF_NULL. | |||
| // NOTE(review): a failed synchronization is only written to the error log; the function is | |||
| // void, so the caller receives no failure indication. | |||
| void AscendSession::SyncStream() { | |||
|   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); | |||
|   MS_EXCEPTION_IF_NULL(runtime_instance); | |||
|   const auto synced = runtime_instance->SyncStream(); | |||
|   if (synced) { | |||
|     return; | |||
|   } | |||
|   MS_LOG(ERROR) << "Sync stream error!"; | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -48,6 +48,7 @@ class AscendSession : public SessionBasic { | |||
| void Init(uint32_t device_id) override; | |||
| // get graph id of final graph | |||
| GraphId GetFinalRunGraph() const override { return final_graph_id_; } | |||
| void SyncStream() override; | |||
| protected: | |||
| GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||
| @@ -508,6 +508,15 @@ void GPUSession::CleanValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &ke | |||
| MS_EXCEPTION_IF_NULL(runtime_instance); | |||
| runtime_instance->CleanValueNodeDeviceAddr(kernel_graph.get()); | |||
| } | |||
| // Looks up the single kernel runtime registered for kGPUDevice / device_id_ and asks it to | |||
| // synchronize its stream. A null runtime raises through MS_EXCEPTION_IF_NULL. | |||
| // NOTE(review): a failed synchronization is only written to the error log; the function is | |||
| // void, so the caller receives no failure indication. | |||
| void GPUSession::SyncStream() { | |||
|   auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); | |||
|   MS_EXCEPTION_IF_NULL(runtime_instance); | |||
|   const auto synced = runtime_instance->SyncStream(); | |||
|   if (synced) { | |||
|     return; | |||
|   } | |||
|   MS_LOG(ERROR) << "Sync stream error!"; | |||
| } | |||
| } // namespace gpu | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -32,6 +32,7 @@ class GPUSession : public SessionBasic { | |||
| GPUSession() = default; | |||
| ~GPUSession() override = default; | |||
| void Init(uint32_t device_id) override; | |||
| void SyncStream() override; | |||
| protected: | |||
| GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||
| @@ -66,9 +66,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| } | |||
| virtual void Init(uint32_t device_id) { device_id_ = device_id; } | |||
| void InitExecutor(const std::string &device_name, uint32_t device_id); | |||
| virtual void SyncStream() {} | |||
| virtual ~SessionBasic() { summary_callback_ = nullptr; } | |||
| GraphId CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs); | |||
| @@ -2113,6 +2113,13 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c | |||
| PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args); | |||
| } | |||
| // Synchronizes the stream of the current session by delegating to session->SyncStream(). | |||
| // Raises NotExistsError when no session exists (session is null). | |||
| void PynativeExecutor::Sync() { | |||
|   if (session != nullptr) { | |||
|     session->SyncStream(); | |||
|     return; | |||
|   } | |||
|   MS_EXCEPTION(NotExistsError) << "No session has been created!"; | |||
| } | |||
| REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { | |||
| (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_") | |||
| .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.") | |||
| @@ -2121,6 +2128,7 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { | |||
| .def("check_graph", &PynativeExecutor::CheckGraph, "pynative check a grad graph.") | |||
| .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.") | |||
| .def("clear", &PynativeExecutor::Clear, "pynative clear status.") | |||
| .def("sync", &PynativeExecutor::Sync, "pynative sync stream.") | |||
| .def("__call__", &PynativeExecutor::Run, py::arg("args"), py::arg("phase") = py::str(""), | |||
| "Executor run function.") | |||
| .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false), | |||
| @@ -96,6 +96,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { | |||
| void Clean(); | |||
| // Destruct call | |||
| void ClearRes(); | |||
| // Sync stream | |||
| void Sync(); | |||
| private: | |||
| PynativeExecutor() = default; | |||
| @@ -314,6 +314,9 @@ class _PynativeExecutor: | |||
| def clear(self, flag=""): | |||
| self._executor.clear(flag) | |||
| def sync(self): | |||
| self._executor.sync() | |||
| def set_grad_flag(self, flag): | |||
| self._executor.set_grad_flag(flag) | |||
| @@ -67,6 +67,7 @@ def connect_network_with_dataset(network, dataset_helper): | |||
| >>> net = Net() | |||
| >>> net_with_get_next = connect_network_with_dataset(net, dataset_helper) | |||
| """ | |||
| class _DataWrapper(nn.Cell): | |||
| """ | |||
| Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the | |||
| @@ -163,16 +164,20 @@ class DatasetHelper: | |||
| if context.get_context("enable_ge"): | |||
| iterclass = _DatasetIterGE | |||
| else: | |||
| if context.get_context("device_target") == "Ascend": | |||
| iterclass = _DatasetIterMSLoopSink | |||
| elif context.get_context("device_target") == "GPU": | |||
| ms_role = os.getenv("MS_ROLE") | |||
| if ms_role in ("MS_PSERVER", "MS_SCHED"): | |||
| iterclass = _DatasetIterPSLite | |||
| else: | |||
| if context.get_context("mode") == context.GRAPH_MODE: | |||
| if context.get_context("device_target") == "Ascend": | |||
| iterclass = _DatasetIterMSLoopSink | |||
| elif context.get_context("device_target") == "CPU": | |||
| raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.") | |||
| elif context.get_context("device_target") == "GPU": | |||
| ms_role = os.getenv("MS_ROLE") | |||
| if ms_role in ("MS_PSERVER", "MS_SCHED"): | |||
| iterclass = _DatasetIterPSLite | |||
| else: | |||
| iterclass = _DatasetIterMSLoopSink | |||
| elif context.get_context("device_target") == "CPU": | |||
| raise RuntimeError( | |||
| "Currently dataset sink mode is not supported when the device target is CPU.") | |||
| else: | |||
| iterclass = _DatasetIterPyNative | |||
| self.iter = iterclass(dataset, sink_size, epoch_num) | |||
| else: | |||
| iterclass = _DatasetIterNormal | |||
| @@ -281,6 +286,20 @@ class _DatasetIterGE(_DatasetIter): | |||
| self.op = op | |||
| class _DatasetIterPyNative(_DatasetIter): | |||
| """Iter for MS(enable_loop_sink=False).""" | |||
| def __init__(self, dataset, sink_size, epoch_num): | |||
| super().__init__(dataset, sink_size, epoch_num) | |||
| if sink_size > 0: | |||
| self.sink_count = sink_size | |||
| else: | |||
| self.sink_count = dataset.get_dataset_size() | |||
| def op(): | |||
| return tuple() | |||
| self.op = op | |||
| class _DatasetIterMSLoopSink(_DatasetIter): | |||
| """Iter for context (device_target=Ascend)""" | |||
| @@ -329,6 +348,7 @@ class _DatasetIterPSLite(_DatasetIter): | |||
| def op(): | |||
| return _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num=1) | |||
| self.op = op | |||
| @@ -23,7 +23,7 @@ from mindspore import log as logger | |||
| from ..common.tensor import Tensor | |||
| from ..nn.metrics import get_metrics | |||
| from .._checkparam import check_input_data, check_output_data, Validator | |||
| from .callback import _InternalCallbackParam, RunContext, _CallbackManager | |||
| from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback | |||
| from .. import context | |||
| from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | |||
| _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check | |||
| @@ -35,6 +35,7 @@ from ..context import ParallelMode | |||
| from ..parallel._cost_model_context import _set_multi_subgraphs | |||
| from .dataset_helper import DatasetHelper, connect_network_with_dataset | |||
| from . import amp | |||
| from ..common.api import _pynative_exec | |||
| def _transfer_tensor_to_tuple(inputs): | |||
| @@ -47,6 +48,11 @@ def _transfer_tensor_to_tuple(inputs): | |||
| return inputs | |||
| class _StepSync(Callback): | |||
| def step_end(self, run_context): | |||
| _pynative_exec.sync() | |||
| class Model: | |||
| """ | |||
| High-Level API for Training or Testing. | |||
| @@ -365,6 +371,9 @@ class Model: | |||
| cb_params.device_number = self._device_number | |||
| cb_params.train_dataset = train_dataset | |||
| cb_params.list_callback = self._transform_callbacks(callbacks) | |||
| if context.get_context("mode") == context.PYNATIVE_MODE: | |||
| cb_params.list_callback.insert(0, _StepSync()) | |||
| callbacks = cb_params.list_callback | |||
| cb_params.train_dataset_element = None | |||
| cb_params.network = self._network | |||
| if _is_role_pserver() or _is_role_sched(): | |||
| @@ -374,8 +383,8 @@ class Model: | |||
| with _CallbackManager(callbacks) as list_callback: | |||
| if not dataset_sink_mode: | |||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||
| elif context.get_context("device_target") == "CPU" or context.get_context("mode") == context.PYNATIVE_MODE: | |||
| logger.warning("The CPU or PyNative mode cannot support dataset sink mode currently." | |||
| elif context.get_context("device_target") == "CPU": | |||
| logger.warning("The CPU cannot support dataset sink mode currently." | |||
| "So the training process will be performed with dataset not sink.") | |||
| self._train_process(epoch, train_dataset, list_callback, cb_params) | |||
| else: | |||
| @@ -417,7 +426,7 @@ class Model: | |||
| run_context = RunContext(cb_params) | |||
| list_callback.begin(run_context) | |||
| is_graph = (context.get_context("mode") == context.GRAPH_MODE) | |||
| # used to stop training for early stop, such as stopAtTime or stopAtStep | |||
| should_stop = False | |||
| dataset_helper = None | |||
| @@ -441,7 +450,10 @@ class Model: | |||
| cb_params.train_dataset_element = inputs | |||
| list_callback.step_begin(run_context) | |||
| outputs = self._train_network(*inputs) | |||
| cb_params.cur_step_num += dataset_helper.sink_size() | |||
| if is_graph: | |||
| cb_params.cur_step_num += dataset_helper.sink_size() | |||
| else: | |||
| cb_params.cur_step_num += 1 | |||
| cb_params.net_outputs = outputs | |||
| list_callback.step_end(run_context) | |||