Merge pull request !1741 from wenkai/wk1_cpu_summary_0531tags/v0.5.0-beta
| @@ -22,10 +22,23 @@ namespace device { | |||||
| namespace cpu { | namespace cpu { | ||||
| bool CPUDeviceAddress::SyncDeviceToHost(const std::vector<int> & /*shape*/, size_t size, TypeId type, | bool CPUDeviceAddress::SyncDeviceToHost(const std::vector<int> & /*shape*/, size_t size, TypeId type, | ||||
| void *host_ptr) const { | void *host_ptr) const { | ||||
| if (type == kNumberTypeFloat16) { | |||||
| MS_EXCEPTION_IF_NULL(ptr_); | |||||
| if (host_ptr == ptr_) { | |||||
| MS_LOG(DEBUG) << "host_ptr is equal to ptr_, request ignored."; | |||||
| return true; | |||||
| } | |||||
| if (type == type_id_) { | |||||
| (void)memcpy_s(host_ptr, size, ptr_, size); | |||||
| } else if (type == kNumberTypeFloat16) { | |||||
| FloatToHalf(host_ptr, ptr_, size / 2); | FloatToHalf(host_ptr, ptr_, size / 2); | ||||
| } else if (type == kNumberTypeFloat64) { | } else if (type == kNumberTypeFloat64) { | ||||
| FloatToDouble(host_ptr, ptr_, size / sizeof(double)); | FloatToDouble(host_ptr, ptr_, size / sizeof(double)); | ||||
| } else { | |||||
| MS_LOG(ERROR) << "Types not match. Device type: " << TypeIdLabel(type_id_) << ", host type: " << TypeIdLabel(type) | |||||
| << "."; | |||||
| return false; | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -27,6 +27,7 @@ | |||||
| #include "utils/config_manager.h" | #include "utils/config_manager.h" | ||||
| #include "common/utils.h" | #include "common/utils.h" | ||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| #include "session/session_basic.h" | |||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -234,9 +235,18 @@ void CPUKernelRuntime::AddRuntimeAddress(DeviceAddress *address, std::vector<ker | |||||
| input_list->push_back(input); | input_list->push_back(input); | ||||
| } | } | ||||
| void CPUKernelRuntime::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { | |||||
| resource_manager_.IncreaseSummaryRefCount(summary_outputs); | |||||
| } | |||||
| void CPUKernelRuntime::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { | |||||
| resource_manager_.DecreaseSummaryRefCount(summary_outputs); | |||||
| } | |||||
| bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph) { | bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| resource_manager_.ResetAddressRefCount(kernel_graph); | |||||
| resource_manager_.IncreaseAddressRefCount(kernel_graph); | |||||
| auto kernels = kernel_graph->execution_order(); | auto kernels = kernel_graph->execution_order(); | ||||
| for (const auto &kernel : kernels) { | for (const auto &kernel : kernels) { | ||||
| std::vector<kernel::AddressPtr> kernel_inputs; | std::vector<kernel::AddressPtr> kernel_inputs; | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "device/kernel_runtime.h" | #include "device/kernel_runtime.h" | ||||
| #include "session/kernel_graph.h" | #include "session/kernel_graph.h" | ||||
| #include "session/session_basic.h" | |||||
| #include "device/cpu/cpu_resource_manager.h" | #include "device/cpu/cpu_resource_manager.h" | ||||
| #include "utils/any.h" | #include "utils/any.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -37,6 +38,8 @@ class CPUKernelRuntime : public KernelRuntime { | |||||
| void AssignKernelAddress(session::KernelGraph *kernel_graph); | void AssignKernelAddress(session::KernelGraph *kernel_graph); | ||||
| void BindInputOutput(const session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs, | void BindInputOutput(const session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs, | ||||
| VectorRef *outputs); | VectorRef *outputs); | ||||
| void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); | |||||
| void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); | |||||
| protected: | protected: | ||||
| bool SyncStream() override { return true; }; | bool SyncStream() override { return true; }; | ||||
| @@ -76,7 +76,47 @@ void CPUResourceManager::MemFree(void *ptr) { | |||||
| } | } | ||||
| } | } | ||||
| void CPUResourceManager::ResetAddressRefCount(const session::KernelGraph *graph) { | |||||
| void CPUResourceManager::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { | |||||
| if (!dynamic_malloc_) { | |||||
| return; | |||||
| } | |||||
| if (summary_outputs.empty()) { | |||||
| return; | |||||
| } | |||||
| for (auto &output_item : summary_outputs) { | |||||
| auto node = output_item.second.first; | |||||
| size_t index = IntToSize(output_item.second.second); | |||||
| auto address = AnfAlgo::GetMutableOutputAddr(node, index); | |||||
| MS_EXCEPTION_IF_NULL(address); | |||||
| address->ref_count_++; | |||||
| } | |||||
| } | |||||
| void CPUResourceManager::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { | |||||
| if (!dynamic_malloc_) { | |||||
| return; | |||||
| } | |||||
| if (summary_outputs.empty()) { | |||||
| return; | |||||
| } | |||||
| for (auto &output_item : summary_outputs) { | |||||
| auto node = output_item.second.first; | |||||
| size_t index = IntToSize(output_item.second.second); | |||||
| auto address = AnfAlgo::GetMutableOutputAddr(node, index); | |||||
| MS_EXCEPTION_IF_NULL(address); | |||||
| address->ref_count_--; | |||||
| if (address->ref_count_ == 0 && address->ptr_ != nullptr) { | |||||
| MemFree(address->ptr_); | |||||
| address->ptr_ = nullptr; | |||||
| } | |||||
| } | |||||
| } | |||||
| void CPUResourceManager::IncreaseAddressRefCount(const session::KernelGraph *graph) { | |||||
| if (!dynamic_malloc_) { | if (!dynamic_malloc_) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "session/kernel_graph.h" | #include "session/kernel_graph.h" | ||||
| #include "session/session_basic.h" | |||||
| #include "device/device_address.h" | #include "device/device_address.h" | ||||
| #include "device/cpu/cpu_simple_mem_plan.h" | #include "device/cpu/cpu_simple_mem_plan.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -31,10 +32,12 @@ class CPUResourceManager { | |||||
| void MemPlan(const session::KernelGraph *graph); | void MemPlan(const session::KernelGraph *graph); | ||||
| void MemMalloc(const session::KernelGraph *graph); | void MemMalloc(const session::KernelGraph *graph); | ||||
| void ResetAddressRefCount(const session::KernelGraph *graph); | |||||
| void IncreaseAddressRefCount(const session::KernelGraph *graph); | |||||
| void DecreaseAddressRefCount(const AnfNodePtr &kernel); | void DecreaseAddressRefCount(const AnfNodePtr &kernel); | ||||
| void *MemMalloc(size_t mem_size); | void *MemMalloc(size_t mem_size); | ||||
| void MemFree(void *ptr); | void MemFree(void *ptr); | ||||
| void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); | |||||
| void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); | |||||
| private: | private: | ||||
| void MemFree(); | void MemFree(); | ||||
| @@ -68,11 +68,25 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten | |||||
| predictmodel::StepConvertWeight(inputs); | predictmodel::StepConvertWeight(inputs); | ||||
| auto execution_order = kernel_graph->execution_order(); | auto execution_order = kernel_graph->execution_order(); | ||||
| Reorder(&execution_order); | Reorder(&execution_order); | ||||
| bool enable_summary = summary_callback_ != nullptr; | |||||
| kernel_graph->set_execution_order(execution_order); | kernel_graph->set_execution_order(execution_order); | ||||
| NamedSummaryOutputs summary_outputs; | |||||
| if (enable_summary) { | |||||
| GetSummaryNodes(kernel_graph.get(), &summary_outputs); | |||||
| runtime_.IncreaseSummaryRefCount(summary_outputs); | |||||
| } | |||||
| bool ret = runtime_.Run(kernel_graph.get()); | bool ret = runtime_.Run(kernel_graph.get()); | ||||
| if (!ret) { | if (!ret) { | ||||
| MS_LOG(EXCEPTION) << "Run graph failed"; | MS_LOG(EXCEPTION) << "Run graph failed"; | ||||
| } | } | ||||
| if (enable_summary) { | |||||
| Summary(kernel_graph.get()); | |||||
| runtime_.DecreaseSummaryRefCount(summary_outputs); | |||||
| } | |||||
| MS_LOG(INFO) << "Run graph end"; | MS_LOG(INFO) << "Run graph end"; | ||||
| } | } | ||||
| @@ -745,8 +745,7 @@ void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { | |||||
| (void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list)); | (void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list)); | ||||
| } | } | ||||
| void SessionBasic::GetSummaryNodes(const KernelGraph *graph, | |||||
| std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) { | |||||
| void SessionBasic::GetSummaryNodes(const KernelGraph *graph, NamedSummaryOutputs *summary) { | |||||
| MS_LOG(DEBUG) << "Update summary Start"; | MS_LOG(DEBUG) << "Update summary Start"; | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_EXCEPTION_IF_NULL(summary); | MS_EXCEPTION_IF_NULL(summary); | ||||
| @@ -780,7 +779,7 @@ void SessionBasic::Summary(KernelGraph *graph) { | |||||
| return; | return; | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| std::unordered_map<std::string, std::pair<AnfNodePtr, int>> summary_outputs; | |||||
| NamedSummaryOutputs summary_outputs; | |||||
| GetSummaryNodes(graph, &summary_outputs); | GetSummaryNodes(graph, &summary_outputs); | ||||
| // do not exist summary node | // do not exist summary node | ||||
| if (summary_outputs.empty()) { | if (summary_outputs.empty()) { | ||||
| @@ -130,6 +130,7 @@ class SessionBasic { | |||||
| }; | }; | ||||
| using SessionPtr = std::shared_ptr<session::SessionBasic>; | using SessionPtr = std::shared_ptr<session::SessionBasic>; | ||||
| using NamedSummaryOutputs = std::unordered_map<std::string, std::pair<AnfNodePtr, int>>; | |||||
| } // namespace session | } // namespace session | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H | #endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H | ||||
| @@ -0,0 +1,79 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Summary cpu st.""" | |||||
| import os | |||||
| import tempfile | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| from tests.summary_utils import SummaryReader | |||||
| from mindspore.train.summary.summary_record import SummaryRecord | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||||
| class SummaryNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.scalar_summary = P.ScalarSummary() | |||||
| self.image_summary = P.ImageSummary() | |||||
| self.tensor_summary = P.TensorSummary() | |||||
| self.histogram_summary = P.HistogramSummary() | |||||
| def construct(self, image_tensor): | |||||
| self.image_summary("image", image_tensor) | |||||
| self.tensor_summary("tensor", image_tensor) | |||||
| self.histogram_summary("histogram", image_tensor) | |||||
| scalar = image_tensor[0][0][0][0] | |||||
| self.scalar_summary("scalar", scalar) | |||||
| return scalar | |||||
| def train_summary_record(test_writer, steps): | |||||
| """Train and record summary.""" | |||||
| net = SummaryNet() | |||||
| out_me_dict = {} | |||||
| for i in range(0, steps): | |||||
| image_tensor = Tensor(np.array([[[[i]]]]).astype(np.float32)) | |||||
| out_put = net(image_tensor) | |||||
| test_writer.record(i) | |||||
| out_me_dict[i] = out_put.asnumpy() | |||||
| return out_me_dict | |||||
| class TestCpuSummary: | |||||
| """Test cpu summary.""" | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_summary_step2_summary_record1(self): | |||||
| """Test record 10 step summary.""" | |||||
| with tempfile.TemporaryDirectory() as tmp_dir: | |||||
| steps = 2 | |||||
| with SummaryRecord(tmp_dir) as test_writer: | |||||
| train_summary_record(test_writer, steps=steps) | |||||
| file_name = os.path.realpath(test_writer.full_file_name) | |||||
| with SummaryReader(file_name) as summary_writer: | |||||
| for _ in range(steps): | |||||
| event = summary_writer.read_event() | |||||
| tags = set(value.tag for value in event.summary.value) | |||||
| assert tags == {'tensor', 'histogram', 'scalar', 'image'} | |||||
| @@ -22,22 +22,44 @@ _HEADER_CRC_SIZE = 4 | |||||
| _DATA_CRC_SIZE = 4 | _DATA_CRC_SIZE = 4 | ||||
| class _EndOfSummaryFileException(Exception): | |||||
| """Indicates the summary file is exhausted.""" | |||||
| class SummaryReader: | class SummaryReader: | ||||
| """Read events from summary file.""" | |||||
| """ | |||||
| Basic summary read function. | |||||
| Args: | |||||
| canonical_file_path (str): The canonical summary file path. | |||||
| ignore_version_event (bool): Whether ignore the version event at the beginning of summary file. | |||||
| """ | |||||
| def __init__(self, canonical_file_path, ignore_version_event=True): | |||||
| self._file_path = canonical_file_path | |||||
| self._ignore_version_event = ignore_version_event | |||||
| def __init__(self, file_name): | |||||
| self._file_name = file_name | |||||
| self._file_handler = open(self._file_name, "rb") | |||||
| # skip version event | |||||
| self.read_event() | |||||
| def __enter__(self): | |||||
| self._file_handler = open(self._file_path, "rb") | |||||
| if self._ignore_version_event: | |||||
| self.read_event() | |||||
| return self | |||||
| def __exit__(self, *unused_args): | |||||
| self._file_handler.close() | |||||
| return False | |||||
| def read_event(self): | def read_event(self): | ||||
| """Read next event.""" | """Read next event.""" | ||||
| file_handler = self._file_handler | file_handler = self._file_handler | ||||
| header = file_handler.read(_HEADER_SIZE) | header = file_handler.read(_HEADER_SIZE) | ||||
| data_len = struct.unpack('Q', header)[0] | data_len = struct.unpack('Q', header)[0] | ||||
| # Ignore crc check. | |||||
| file_handler.read(_HEADER_CRC_SIZE) | file_handler.read(_HEADER_CRC_SIZE) | ||||
| event_str = file_handler.read(data_len) | event_str = file_handler.read(data_len) | ||||
| # Ignore crc check. | |||||
| file_handler.read(_DATA_CRC_SIZE) | file_handler.read(_DATA_CRC_SIZE) | ||||
| summary_event = summary_pb2.Event.FromString(event_str) | summary_event = summary_pb2.Event.FromString(event_str) | ||||
| return summary_event | return summary_event | ||||
| @@ -22,7 +22,7 @@ import numpy as np | |||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.train.summary._summary_adapter import _calc_histogram_bins | from mindspore.train.summary._summary_adapter import _calc_histogram_bins | ||||
| from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data | from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data | ||||
| from .summary_reader import SummaryReader | |||||
| from tests.summary_utils import SummaryReader | |||||
| CUR_DIR = os.getcwd() | CUR_DIR = os.getcwd() | ||||
| SUMMARY_DIR = os.path.join(CUR_DIR, "/test_temp_summary_event_file/") | SUMMARY_DIR = os.path.join(CUR_DIR, "/test_temp_summary_event_file/") | ||||
| @@ -57,9 +57,9 @@ def test_histogram_summary(): | |||||
| test_writer.record(step=1) | test_writer.record(step=1) | ||||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | file_name = os.path.join(tmp_dir, test_writer.event_file_name) | ||||
| reader = SummaryReader(file_name) | |||||
| event = reader.read_event() | |||||
| assert event.summary.value[0].histogram.count == 6 | |||||
| with SummaryReader(file_name) as reader: | |||||
| event = reader.read_event() | |||||
| assert event.summary.value[0].histogram.count == 6 | |||||
| def test_histogram_multi_summary(): | def test_histogram_multi_summary(): | ||||
| @@ -79,10 +79,10 @@ def test_histogram_multi_summary(): | |||||
| test_writer.record(step=i) | test_writer.record(step=i) | ||||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | file_name = os.path.join(tmp_dir, test_writer.event_file_name) | ||||
| reader = SummaryReader(file_name) | |||||
| for _ in range(num_step): | |||||
| event = reader.read_event() | |||||
| assert event.summary.value[0].histogram.count == size | |||||
| with SummaryReader(file_name) as reader: | |||||
| for _ in range(num_step): | |||||
| event = reader.read_event() | |||||
| assert event.summary.value[0].histogram.count == size | |||||
| def test_histogram_summary_scalar_tensor(): | def test_histogram_summary_scalar_tensor(): | ||||
| @@ -94,9 +94,9 @@ def test_histogram_summary_scalar_tensor(): | |||||
| test_writer.record(step=1) | test_writer.record(step=1) | ||||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | file_name = os.path.join(tmp_dir, test_writer.event_file_name) | ||||
| reader = SummaryReader(file_name) | |||||
| event = reader.read_event() | |||||
| assert event.summary.value[0].histogram.count == 1 | |||||
| with SummaryReader(file_name) as reader: | |||||
| event = reader.read_event() | |||||
| assert event.summary.value[0].histogram.count == 1 | |||||
| def test_histogram_summary_empty_tensor(): | def test_histogram_summary_empty_tensor(): | ||||
| @@ -108,9 +108,9 @@ def test_histogram_summary_empty_tensor(): | |||||
| test_writer.record(step=1) | test_writer.record(step=1) | ||||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | file_name = os.path.join(tmp_dir, test_writer.event_file_name) | ||||
| reader = SummaryReader(file_name) | |||||
| event = reader.read_event() | |||||
| assert event.summary.value[0].histogram.count == 0 | |||||
| with SummaryReader(file_name) as reader: | |||||
| event = reader.read_event() | |||||
| assert event.summary.value[0].histogram.count == 0 | |||||
| def test_histogram_summary_same_value(): | def test_histogram_summary_same_value(): | ||||
| @@ -125,11 +125,11 @@ def test_histogram_summary_same_value(): | |||||
| test_writer.record(step=1) | test_writer.record(step=1) | ||||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | file_name = os.path.join(tmp_dir, test_writer.event_file_name) | ||||
| reader = SummaryReader(file_name) | |||||
| event = reader.read_event() | |||||
| LOG.debug(event) | |||||
| with SummaryReader(file_name) as reader: | |||||
| event = reader.read_event() | |||||
| LOG.debug(event) | |||||
| assert len(event.summary.value[0].histogram.buckets) == _calc_histogram_bins(dim1 * dim2) | |||||
| assert len(event.summary.value[0].histogram.buckets) == _calc_histogram_bins(dim1 * dim2) | |||||
| def test_histogram_summary_high_dims(): | def test_histogram_summary_high_dims(): | ||||
| @@ -145,11 +145,11 @@ def test_histogram_summary_high_dims(): | |||||
| test_writer.record(step=1) | test_writer.record(step=1) | ||||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | file_name = os.path.join(tmp_dir, test_writer.event_file_name) | ||||
| reader = SummaryReader(file_name) | |||||
| event = reader.read_event() | |||||
| LOG.debug(event) | |||||
| with SummaryReader(file_name) as reader: | |||||
| event = reader.read_event() | |||||
| LOG.debug(event) | |||||
| assert event.summary.value[0].histogram.count == tensor_data.size | |||||
| assert event.summary.value[0].histogram.count == tensor_data.size | |||||
| def test_histogram_summary_nan_inf(): | def test_histogram_summary_nan_inf(): | ||||
| @@ -169,11 +169,11 @@ def test_histogram_summary_nan_inf(): | |||||
| test_writer.record(step=1) | test_writer.record(step=1) | ||||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | file_name = os.path.join(tmp_dir, test_writer.event_file_name) | ||||
| reader = SummaryReader(file_name) | |||||
| event = reader.read_event() | |||||
| LOG.debug(event) | |||||
| with SummaryReader(file_name) as reader: | |||||
| event = reader.read_event() | |||||
| LOG.debug(event) | |||||
| assert event.summary.value[0].histogram.nan_count == 1 | |||||
| assert event.summary.value[0].histogram.nan_count == 1 | |||||
| def test_histogram_summary_all_nan_inf(): | def test_histogram_summary_all_nan_inf(): | ||||
| @@ -185,11 +185,11 @@ def test_histogram_summary_all_nan_inf(): | |||||
| test_writer.record(step=1) | test_writer.record(step=1) | ||||
| file_name = os.path.join(tmp_dir, test_writer.event_file_name) | file_name = os.path.join(tmp_dir, test_writer.event_file_name) | ||||
| reader = SummaryReader(file_name) | |||||
| event = reader.read_event() | |||||
| LOG.debug(event) | |||||
| histogram = event.summary.value[0].histogram | |||||
| assert histogram.nan_count == 3 | |||||
| assert histogram.pos_inf_count == 1 | |||||
| assert histogram.neg_inf_count == 1 | |||||
| with SummaryReader(file_name) as reader: | |||||
| event = reader.read_event() | |||||
| LOG.debug(event) | |||||
| histogram = event.summary.value[0].histogram | |||||
| assert histogram.nan_count == 3 | |||||
| assert histogram.pos_inf_count == 1 | |||||
| assert histogram.neg_inf_count == 1 | |||||