Browse Source

!13225 Get a string path when the summary path is a list

From: @ouwenchang
Reviewed-by: @yelihua,@lixiaohui33
Signed-off-by: @lixiaohui33
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
4c9ad93e13
2 changed files with 15 additions and 4 deletions
  1. +13
    -3
      mindspore/train/callback/_summary_collector.py
  2. +2
    -1
      mindspore/train/summary/_summary_adapter.py

+ 13
- 3
mindspore/train/callback/_summary_collector.py View File

@@ -20,6 +20,7 @@ import json
from json.decoder import JSONDecodeError

from importlib import import_module
from collections.abc import Iterable

import numpy as np

@@ -842,12 +843,21 @@ class SummaryCollector(Callback):
dataset_file_set = (dataset_package.MindDataset, dataset_package.ManifestDataset)
dataset_files_set = (dataset_package.TFRecordDataset, dataset_package.TextFileDataset)

dataset_path = ''

if isinstance(output_dataset, dataset_file_set):
return output_dataset.dataset_file
dataset_path = output_dataset.dataset_file
if isinstance(output_dataset, dataset_dir_set):
return output_dataset.dataset_dir
dataset_path = output_dataset.dataset_dir
if isinstance(output_dataset, dataset_files_set):
return output_dataset.dataset_files[0]
dataset_path = output_dataset.dataset_files[0]

if dataset_path:
if isinstance(dataset_path, str):
return dataset_path
if isinstance(dataset_path, Iterable):
return list(dataset_path)[0]

return self._get_dataset_path(output_dataset.children[0])

@staticmethod


+ 2
- 1
mindspore/train/summary/_summary_adapter.py View File

@@ -22,6 +22,7 @@ from PIL import Image
from mindspore import log as logger
from mindspore import context
from mindspore.communication.management import get_rank
from mindspore.communication.management import GlobalComm

from ..._checkparam import Validator
from ..anf_ir_pb2 import DataType, ModelProto
@@ -57,7 +58,7 @@ def get_event_file_name(prefix, suffix, time_second):

device_num = context.get_auto_parallel_context('device_num')
device_id = context.get_context('device_id')
if device_num > 1:
if device_num > 1 or GlobalComm.WORLD_COMM_GROUP == 'nccl_world_group':
# Notice:
# In GPU distribute training scene, get_context('device_id') will not work,
# so we use get_rank instead of get_context.


Loading…
Cancel
Save