Merge pull request !1741 from wenkai/wk1_cpu_summary_0531tags/v0.5.0-beta
| @@ -22,10 +22,23 @@ namespace device { | |||
| namespace cpu { | |||
| bool CPUDeviceAddress::SyncDeviceToHost(const std::vector<int> & /*shape*/, size_t size, TypeId type, | |||
| void *host_ptr) const { | |||
| if (type == kNumberTypeFloat16) { | |||
| MS_EXCEPTION_IF_NULL(ptr_); | |||
| if (host_ptr == ptr_) { | |||
| MS_LOG(DEBUG) << "host_ptr is equal to ptr_, request ignored."; | |||
| return true; | |||
| } | |||
| if (type == type_id_) { | |||
| (void)memcpy_s(host_ptr, size, ptr_, size); | |||
| } else if (type == kNumberTypeFloat16) { | |||
| FloatToHalf(host_ptr, ptr_, size / 2); | |||
| } else if (type == kNumberTypeFloat64) { | |||
| FloatToDouble(host_ptr, ptr_, size / sizeof(double)); | |||
| } else { | |||
| MS_LOG(ERROR) << "Types not match. Device type: " << TypeIdLabel(type_id_) << ", host type: " << TypeIdLabel(type) | |||
| << "."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -27,6 +27,7 @@ | |||
| #include "utils/config_manager.h" | |||
| #include "common/utils.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "session/session_basic.h" | |||
| #include "operator/ops.h" | |||
| namespace mindspore { | |||
| @@ -234,9 +235,18 @@ void CPUKernelRuntime::AddRuntimeAddress(DeviceAddress *address, std::vector<ker | |||
| input_list->push_back(input); | |||
| } | |||
| void CPUKernelRuntime::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { | |||
| resource_manager_.IncreaseSummaryRefCount(summary_outputs); | |||
| } | |||
| void CPUKernelRuntime::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { | |||
| resource_manager_.DecreaseSummaryRefCount(summary_outputs); | |||
| } | |||
| bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| resource_manager_.ResetAddressRefCount(kernel_graph); | |||
| resource_manager_.IncreaseAddressRefCount(kernel_graph); | |||
| auto kernels = kernel_graph->execution_order(); | |||
| for (const auto &kernel : kernels) { | |||
| std::vector<kernel::AddressPtr> kernel_inputs; | |||
| @@ -22,6 +22,7 @@ | |||
| #include <unordered_map> | |||
| #include "device/kernel_runtime.h" | |||
| #include "session/kernel_graph.h" | |||
| #include "session/session_basic.h" | |||
| #include "device/cpu/cpu_resource_manager.h" | |||
| #include "utils/any.h" | |||
| namespace mindspore { | |||
| @@ -37,6 +38,8 @@ class CPUKernelRuntime : public KernelRuntime { | |||
| void AssignKernelAddress(session::KernelGraph *kernel_graph); | |||
| void BindInputOutput(const session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *outputs); | |||
| void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); | |||
| void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); | |||
| protected: | |||
| bool SyncStream() override { return true; }; | |||
| @@ -76,7 +76,47 @@ void CPUResourceManager::MemFree(void *ptr) { | |||
| } | |||
| } | |||
| void CPUResourceManager::ResetAddressRefCount(const session::KernelGraph *graph) { | |||
| void CPUResourceManager::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { | |||
| if (!dynamic_malloc_) { | |||
| return; | |||
| } | |||
| if (summary_outputs.empty()) { | |||
| return; | |||
| } | |||
| for (auto &output_item : summary_outputs) { | |||
| auto node = output_item.second.first; | |||
| size_t index = IntToSize(output_item.second.second); | |||
| auto address = AnfAlgo::GetMutableOutputAddr(node, index); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| address->ref_count_++; | |||
| } | |||
| } | |||
| void CPUResourceManager::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { | |||
| if (!dynamic_malloc_) { | |||
| return; | |||
| } | |||
| if (summary_outputs.empty()) { | |||
| return; | |||
| } | |||
| for (auto &output_item : summary_outputs) { | |||
| auto node = output_item.second.first; | |||
| size_t index = IntToSize(output_item.second.second); | |||
| auto address = AnfAlgo::GetMutableOutputAddr(node, index); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| address->ref_count_--; | |||
| if (address->ref_count_ == 0 && address->ptr_ != nullptr) { | |||
| MemFree(address->ptr_); | |||
| address->ptr_ = nullptr; | |||
| } | |||
| } | |||
| } | |||
| void CPUResourceManager::IncreaseAddressRefCount(const session::KernelGraph *graph) { | |||
| if (!dynamic_malloc_) { | |||
| return; | |||
| } | |||
| @@ -19,6 +19,7 @@ | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include "session/kernel_graph.h" | |||
| #include "session/session_basic.h" | |||
| #include "device/device_address.h" | |||
| #include "device/cpu/cpu_simple_mem_plan.h" | |||
| namespace mindspore { | |||
| @@ -31,10 +32,12 @@ class CPUResourceManager { | |||
| void MemPlan(const session::KernelGraph *graph); | |||
| void MemMalloc(const session::KernelGraph *graph); | |||
| void ResetAddressRefCount(const session::KernelGraph *graph); | |||
| void IncreaseAddressRefCount(const session::KernelGraph *graph); | |||
| void DecreaseAddressRefCount(const AnfNodePtr &kernel); | |||
| void *MemMalloc(size_t mem_size); | |||
| void MemFree(void *ptr); | |||
| void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); | |||
| void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); | |||
| private: | |||
| void MemFree(); | |||
| @@ -68,11 +68,25 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten | |||
| predictmodel::StepConvertWeight(inputs); | |||
| auto execution_order = kernel_graph->execution_order(); | |||
| Reorder(&execution_order); | |||
| bool enable_summary = summary_callback_ != nullptr; | |||
| kernel_graph->set_execution_order(execution_order); | |||
| NamedSummaryOutputs summary_outputs; | |||
| if (enable_summary) { | |||
| GetSummaryNodes(kernel_graph.get(), &summary_outputs); | |||
| runtime_.IncreaseSummaryRefCount(summary_outputs); | |||
| } | |||
| bool ret = runtime_.Run(kernel_graph.get()); | |||
| if (!ret) { | |||
| MS_LOG(EXCEPTION) << "Run graph failed"; | |||
| } | |||
| if (enable_summary) { | |||
| Summary(kernel_graph.get()); | |||
| runtime_.DecreaseSummaryRefCount(summary_outputs); | |||
| } | |||
| MS_LOG(INFO) << "Run graph end"; | |||
| } | |||
| @@ -745,8 +745,7 @@ void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { | |||
| (void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list)); | |||
| } | |||
| void SessionBasic::GetSummaryNodes(const KernelGraph *graph, | |||
| std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) { | |||
| void SessionBasic::GetSummaryNodes(const KernelGraph *graph, NamedSummaryOutputs *summary) { | |||
| MS_LOG(DEBUG) << "Update summary Start"; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(summary); | |||
| @@ -780,7 +779,7 @@ void SessionBasic::Summary(KernelGraph *graph) { | |||
| return; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| std::unordered_map<std::string, std::pair<AnfNodePtr, int>> summary_outputs; | |||
| NamedSummaryOutputs summary_outputs; | |||
| GetSummaryNodes(graph, &summary_outputs); | |||
| // no summary node exists | |||
| if (summary_outputs.empty()) { | |||
| @@ -130,6 +130,7 @@ class SessionBasic { | |||
| }; | |||
| using SessionPtr = std::shared_ptr<session::SessionBasic>; | |||
| using NamedSummaryOutputs = std::unordered_map<std::string, std::pair<AnfNodePtr, int>>; | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H | |||
| @@ -0,0 +1,79 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Summary cpu st.""" | |||
| import os | |||
| import tempfile | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from tests.summary_utils import SummaryReader | |||
| from mindspore.train.summary.summary_record import SummaryRecord | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||
class SummaryNet(nn.Cell):
    """Cell that records an image, tensor, histogram and scalar summary of its input."""

    def __init__(self):
        super().__init__()
        # One operator per summary kind exercised by the test.
        self.histogram_summary = P.HistogramSummary()
        self.tensor_summary = P.TensorSummary()
        self.image_summary = P.ImageSummary()
        self.scalar_summary = P.ScalarSummary()

    def construct(self, image_tensor):
        """Record all four summaries and return the first element of the input."""
        self.image_summary("image", image_tensor)
        self.tensor_summary("tensor", image_tensor)
        self.histogram_summary("histogram", image_tensor)
        first_element = image_tensor[0][0][0][0]
        self.scalar_summary("scalar", first_element)
        return first_element
def train_summary_record(test_writer, steps):
    """Run SummaryNet for `steps` iterations, recording summaries after each one.

    Args:
        test_writer: Summary writer whose `record(step)` is called once per step.
        steps (int): Number of iterations to run.

    Returns:
        dict, maps each step index to the network output as a numpy array.
    """
    network = SummaryNet()
    outputs_by_step = {}
    for step in range(steps):
        # A 1x1x1x1 float32 tensor holding the step index.
        step_input = Tensor(np.array([[[[step]]]]).astype(np.float32))
        result = network(step_input)
        test_writer.record(step)
        outputs_by_step[step] = result.asnumpy()
    return outputs_by_step
class TestCpuSummary:
    """Summary recording tests for the CPU backend."""

    @pytest.mark.level0
    @pytest.mark.platform_x86_cpu_training
    @pytest.mark.env_onecard
    def test_summary_step2_summary_record1(self):
        """Record two steps and check that every event carries all four summary tags."""
        expected_tags = {'tensor', 'histogram', 'scalar', 'image'}
        steps = 2
        with tempfile.TemporaryDirectory() as tmp_dir:
            with SummaryRecord(tmp_dir) as test_writer:
                train_summary_record(test_writer, steps=steps)
                file_name = os.path.realpath(test_writer.full_file_name)
            # The event file is read back once the writer context has been left.
            with SummaryReader(file_name) as summary_reader:
                for _ in range(steps):
                    event = summary_reader.read_event()
                    tags = {value.tag for value in event.summary.value}
                    assert tags == expected_tags
| @@ -22,22 +22,44 @@ _HEADER_CRC_SIZE = 4 | |||
| _DATA_CRC_SIZE = 4 | |||
class _EndOfSummaryFileException(Exception):
    """Indicates the summary file is exhausted."""


class SummaryReader:
    """
    Basic summary read function.

    Intended to be used as a context manager: the file is opened on entry and
    closed on exit.

    Args:
        canonical_file_path (str): The canonical summary file path.
        ignore_version_event (bool): Whether to ignore the version event at the beginning of the summary file.
    """

    def __init__(self, canonical_file_path, ignore_version_event=True):
        self._file_path = canonical_file_path
        self._ignore_version_event = ignore_version_event
        # Opened lazily in __enter__; None until then.
        self._file_handler = None

    def __enter__(self):
        self._file_handler = open(self._file_path, "rb")
        if self._ignore_version_event:
            try:
                # Consume the leading version event so callers only see summary events.
                self.read_event()
            except Exception:
                # __exit__ is not invoked when __enter__ raises, so release the handle here.
                self._file_handler.close()
                raise
        return self

    def __exit__(self, *unused_args):
        if self._file_handler is not None:
            self._file_handler.close()
        # Never suppress an exception raised inside the with-block.
        return False

    def read_event(self):
        """
        Read next event.

        Returns:
            The event parsed by `summary_pb2.Event.FromString` from the next record.
        """
        file_handler = self._file_handler
        # Record layout as read here: length header, header crc, payload, payload crc.
        header = file_handler.read(_HEADER_SIZE)
        data_len = struct.unpack('Q', header)[0]
        # Ignore crc check.
        file_handler.read(_HEADER_CRC_SIZE)

        event_str = file_handler.read(data_len)
        # Ignore crc check.
        file_handler.read(_DATA_CRC_SIZE)
        summary_event = summary_pb2.Event.FromString(event_str)

        return summary_event
| @@ -22,7 +22,7 @@ import numpy as np | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.train.summary._summary_adapter import _calc_histogram_bins | |||
| from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data | |||
| from .summary_reader import SummaryReader | |||
| from tests.summary_utils import SummaryReader | |||
CUR_DIR = os.getcwd()
# The second component must be relative: os.path.join discards every earlier
# component when a later one is absolute, which would drop CUR_DIR entirely.
SUMMARY_DIR = os.path.join(CUR_DIR, "test_temp_summary_event_file/")
| @@ -57,9 +57,9 @@ def test_histogram_summary(): | |||
| test_writer.record(step=1) | |||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | |||
| reader = SummaryReader(file_name) | |||
| event = reader.read_event() | |||
| assert event.summary.value[0].histogram.count == 6 | |||
| with SummaryReader(file_name) as reader: | |||
| event = reader.read_event() | |||
| assert event.summary.value[0].histogram.count == 6 | |||
| def test_histogram_multi_summary(): | |||
| @@ -79,10 +79,10 @@ def test_histogram_multi_summary(): | |||
| test_writer.record(step=i) | |||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | |||
| reader = SummaryReader(file_name) | |||
| for _ in range(num_step): | |||
| event = reader.read_event() | |||
| assert event.summary.value[0].histogram.count == size | |||
| with SummaryReader(file_name) as reader: | |||
| for _ in range(num_step): | |||
| event = reader.read_event() | |||
| assert event.summary.value[0].histogram.count == size | |||
| def test_histogram_summary_scalar_tensor(): | |||
| @@ -94,9 +94,9 @@ def test_histogram_summary_scalar_tensor(): | |||
| test_writer.record(step=1) | |||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | |||
| reader = SummaryReader(file_name) | |||
| event = reader.read_event() | |||
| assert event.summary.value[0].histogram.count == 1 | |||
| with SummaryReader(file_name) as reader: | |||
| event = reader.read_event() | |||
| assert event.summary.value[0].histogram.count == 1 | |||
| def test_histogram_summary_empty_tensor(): | |||
| @@ -108,9 +108,9 @@ def test_histogram_summary_empty_tensor(): | |||
| test_writer.record(step=1) | |||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | |||
| reader = SummaryReader(file_name) | |||
| event = reader.read_event() | |||
| assert event.summary.value[0].histogram.count == 0 | |||
| with SummaryReader(file_name) as reader: | |||
| event = reader.read_event() | |||
| assert event.summary.value[0].histogram.count == 0 | |||
| def test_histogram_summary_same_value(): | |||
| @@ -125,11 +125,11 @@ def test_histogram_summary_same_value(): | |||
| test_writer.record(step=1) | |||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | |||
| reader = SummaryReader(file_name) | |||
| event = reader.read_event() | |||
| LOG.debug(event) | |||
| with SummaryReader(file_name) as reader: | |||
| event = reader.read_event() | |||
| LOG.debug(event) | |||
| assert len(event.summary.value[0].histogram.buckets) == _calc_histogram_bins(dim1 * dim2) | |||
| assert len(event.summary.value[0].histogram.buckets) == _calc_histogram_bins(dim1 * dim2) | |||
| def test_histogram_summary_high_dims(): | |||
| @@ -145,11 +145,11 @@ def test_histogram_summary_high_dims(): | |||
| test_writer.record(step=1) | |||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | |||
| reader = SummaryReader(file_name) | |||
| event = reader.read_event() | |||
| LOG.debug(event) | |||
| with SummaryReader(file_name) as reader: | |||
| event = reader.read_event() | |||
| LOG.debug(event) | |||
| assert event.summary.value[0].histogram.count == tensor_data.size | |||
| assert event.summary.value[0].histogram.count == tensor_data.size | |||
| def test_histogram_summary_nan_inf(): | |||
| @@ -169,11 +169,11 @@ def test_histogram_summary_nan_inf(): | |||
| test_writer.record(step=1) | |||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | |||
| reader = SummaryReader(file_name) | |||
| event = reader.read_event() | |||
| LOG.debug(event) | |||
| with SummaryReader(file_name) as reader: | |||
| event = reader.read_event() | |||
| LOG.debug(event) | |||
| assert event.summary.value[0].histogram.nan_count == 1 | |||
| assert event.summary.value[0].histogram.nan_count == 1 | |||
| def test_histogram_summary_all_nan_inf(): | |||
| @@ -185,11 +185,11 @@ def test_histogram_summary_all_nan_inf(): | |||
| test_writer.record(step=1) | |||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | |||
| reader = SummaryReader(file_name) | |||
| event = reader.read_event() | |||
| LOG.debug(event) | |||
| histogram = event.summary.value[0].histogram | |||
| assert histogram.nan_count == 3 | |||
| assert histogram.pos_inf_count == 1 | |||
| assert histogram.neg_inf_count == 1 | |||
| with SummaryReader(file_name) as reader: | |||
| event = reader.read_event() | |||
| LOG.debug(event) | |||
| histogram = event.summary.value[0].histogram | |||
| assert histogram.nan_count == 3 | |||
| assert histogram.pos_inf_count == 1 | |||
| assert histogram.neg_inf_count == 1 | |||