Browse Source

!28517 [Dataset][docs][chinese] 修复dataset中文文档中WatiedCallback、deserialize、imshow_det_bbox检视问题

Merge pull request !28517 from xiefangqi/code_docs_fix_chinese_docs_problem_stage1
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
c2c47413e6
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 121 additions and 34 deletions
  1. +68
    -9
      docs/api/api_python/dataset/mindspore.dataset.WaitedDSCallback.rst
  2. +8
    -9
      docs/api/api_python/dataset/mindspore.dataset.deserialize.rst
  3. +45
    -16
      docs/api/api_python/dataset/mindspore.dataset.utils.imshow_det_bbox.rst

+ 68
- 9
docs/api/api_python/dataset/mindspore.dataset.WaitedDSCallback.rst View File

@@ -3,25 +3,84 @@ mindspore.dataset.WaitedDSCallback

.. py:class:: mindspore.dataset.WaitedDSCallback(step_size=1)

用于自定义与训练回调同步的数据集回调类的抽象基类
数据集自定义回调类的抽象基类,用于与训练回调类(`mindspore.callback <https://mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.train.html#mindspore.train.callback.Callback>`_)的同步

此类可用于自定义在step或epoch结束后执行的回调方法
可用于在每个step或epoch开始前执行自定义的回调方法,注意,第二个step或epoch开始时才会触发该调用
例如在自动数据增强中根据上一个epoch的loss值来更新增强算子参数配置。

用户可通过 `train_run_context` 获取模型相关信息。如 `network` 、 `train_network` 、 `epoch_num` 、 `batch_num` 、 `loss_fn` 、 `optimizer` 、 `parallel_mode` 、 `device_number` 、 `list_callback` 、 `cur_epoch_num` 、 `cur_step_num` 、 `dataset_sink_mode` 、 `net_outputs` 等,详见 `mindspore.callback <https://mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.train.html#mindspore.train.callback.Callback>`_ 。

用户可通过 `ds_run_context` 获取数据处理管道相关信息。包括 `cur_epoch_num` (当前epoch数)、 `cur_step_num_in_epoch` (当前epoch的step数)、 `cur_step_num` (当前step数)。

**参数:**

- **step_size** (int, optional) - 每个step包含的数据行数。step大小通常与batch大小相等(默认值为1)。
- **step_size** (int, optional) - 每个step包含的数据行数。通常step_size与batch_size一致,默认值:1

**样例:**

>>> import mindspore.nn as nn
>>> from mindspore.dataset import WaitedDSCallback
>>> from mindspore import context
>>> from mindspore.train import Model
>>> from mindspore.train.callback import Callback
>>>
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
>>>
>>> # 自定义用于数据处理管道同步数据的回调类
>>> class MyWaitedCallback(WaitedDSCallback):
... def __init__(self, events, step_size=1):
... super().__init__(step_size)
... self.events = events
...
... # epoch开始前数据处理管道要执行的回调函数
... def sync_epoch_begin(self, train_run_context, ds_run_context):
... event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
... self.events.append(event)
...
... # step开始前数据处理管道要执行的回调函数
... def sync_step_begin(self, train_run_context, ds_run_context):
... event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
... self.events.append(event)
>>>
>>> # 自定义用于网络训练时同步数据的回调类
>>> class MyMSCallback(Callback):
... def __init__(self, events):
... self.events = events
...
... # epoch结束网络训练要执行的回调函数
... def epoch_end(self, run_context):
... cb_params = run_context.original_args()
... event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
... self.events.append(event)
...
... # step结束网络训练要执行的回调函数
... def step_end(self, run_context):
... cb_params = run_context.original_args()
... event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
... self.events.append(event)
>>>
>>> # 自定义网络
>>> class Net(nn.Cell):
... def construct(self, x, y):
... return x
>>>
>>> # 声明一个网络训练与数据处理同步的数据
>>> events = []
>>>
>>> # 声明数据处理管道和网络训练的回调类
>>> my_cb1 = MyWaitedCallback(events, 1)
>>> my_cb2 = MyMSCallback(events)
>>> arr = [1, 2, 3, 4]
>>> # 构建数据处理管道
>>> data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
>>> # 将数据处理管道的回调类加入到map中
>>> data = data.map(operations=(lambda x: x), callbacks=my_cb1)
>>>
>>> net = Net()
>>> model = Model(net)
>>>
>>> my_cb = WaitedDSCallback(32)
>>> # dataset为任意数据集实例
>>> data = data.map(operations=AugOp(), callbacks=my_cb)
>>> data = data.batch(32)
>>> # 定义网络
>>> model.train(epochs, data, callbacks=[my_cb])
>>> # 将数据处理管道和网络训练的回调类加入到模型训练的回调列表中
>>> model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])

.. py:method:: begin(run_context)



+ 8
- 9
docs/api/api_python/dataset/mindspore.dataset.deserialize.rst View File

@@ -10,16 +10,16 @@ mindspore.dataset.deserialize

**参数:**

- **input_dict** (dict) - 包含序列化数据集图的Python字典
- **json_filepath** (str) - JSON文件的路径,用户可通过 `mindspore.dataset.serialize()` 接口生成。
- **input_dict** (dict) - 以Python字典存储的数据处理管道。默认值:None
- **json_filepath** (str) - 数据处理管道JSON文件的路径,该文件以通用JSON格式存储了数据处理管道信息,用户可通过 `mindspore.dataset.serialize()` 接口生成。默认值:None。

**返回:**

成功时,返回Dataset对象;失败时,则返回None。
当反序列化成功时,将返回Dataset对象;当无法被反序列化时,deserialize将会失败,且返回None。

**异常:**

**OSError:** 无法打开JSON文件
- **OSError:** - `json_filepath` 不为None且JSON文件解析失败时

**样例:**

@@ -28,9 +28,8 @@ mindspore.dataset.deserialize
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
>>> # 用例1:序列化/反序列化 JSON文件
>>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
>>> dataset = ds.engine.deserialize(json_filepath="/path/to/mnist_dataset_pipeline.json")
>>> ds.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
>>> dataset = ds.deserialize(json_filepath="/path/to/mnist_dataset_pipeline.json")
>>> # 用例2:序列化/反序列化 Python字典
>>> serialized_data = ds.engine.serialize(dataset)
>>> dataset = ds.engine.deserialize(input_dict=serialized_data)

>>> serialized_data = ds.serialize(dataset)
>>> dataset = ds.deserialize(input_dict=serialized_data)

+ 45
- 16
docs/api/api_python/dataset/mindspore.dataset.utils.imshow_det_bbox.rst View File

@@ -7,22 +7,51 @@

**参数:**

- **image** (ndarray): 待绘制的图像,shape为(C, H, W)或(H, W, C),通道顺序为RGB。
- **bboxes** (ndarray): 边界框(包含类别置信度),shape为(N, 4)或(N, 5),格式为(N,X,Y,W,H)
- **labels** (ndarray): 边界框的类别,shape为(N, 1)。
- **segm** (ndarray): 图像分割掩码,shape为(M, H, W),M表示类别总数(默认值None,不绘制掩码)
- **class_names** (list[str], dict): 类别索引到类别名的映射表(默认值None,仅显示类别索引)
- **score_threshold** (float): 绘制边界框的类别置信度阈值(默认值0,绘制所有边界框)
- **bbox_color** (tuple(int)): 指定绘制边界框时线条的颜色,顺序为BGR(默认值(0,255,0),表示'green')
- **text_color** (tuple(int)):指定类别文本的显示颜色,顺序为BGR(默认值(203, 192, 255),表示'pink')
- **mask_color** (tuple(int)):指定掩码的显示颜色,顺序为BGR(默认值(128, 0, 128),表示'purple')
- **thickness** (int): 指定边界框和类别文本的线条粗细(默认值2)
- **font_size** (int, float): 指定类别文本字体大小(默认值0.8)
- **show** (bool): 是否显示图像(默认值为True)
- **win_name** (str): 指定窗口名称(默认值"win")
- **wait_time** (int): 指定cv2.waitKey的时延,单位为ms,即图像显示的自动切换间隔(默认值2000,表示间隔为2000ms)
- **out_file** (str, optional): 输出图像的文件名,用于在绘制后将结果存储到本地(默认值None,不保存)
- **image** (numpy.ndarray) - 待绘制的图像,shape为(C, H, W)或(H, W, C),通道顺序为RGB。
- **bboxes** (numpy.ndarray) - 边界框(包含类别置信度),shape为(N, 4)或(N, 5),格式为(N,X,Y,W,H)
- **labels** (numpy.ndarray) - 边界框的类别,shape为(N, 1)。
- **segm** (numpy.ndarray) - 图像分割掩码,shape为(M, H, W),M表示类别总数,默认值:None,不绘制掩码
- **class_names** (list[str], dict) - 类别索引到类别名的映射表,默认值:None,仅显示类别索引
- **score_threshold** (float) - 绘制边界框的类别置信度阈值,默认值:0,绘制所有边界框
- **bbox_color** (tuple(int)) - 指定绘制边界框时线条的颜色,顺序为BGR,默认值:(0,255,0),表示绿色
- **text_color** (tuple(int)) - 指定类别文本的显示颜色,顺序为BGR,默认值:(203, 192, 255),表示粉色
- **mask_color** (tuple(int)) - 指定掩码的显示颜色,顺序为BGR,默认值:(128, 0, 128),表示紫色
- **thickness** (int) - 指定边界框和类别文本的线条粗细,默认值:2
- **font_size** (int, float) - 指定类别文本字体大小,默认值:0.8
- **show** (bool) - 是否显示图像,默认值:True
- **win_name** (str) - 指定窗口名称,默认值:"win"
- **wait_time** (int) - 指定cv2.waitKey的时延,单位为ms,即图像显示的自动切换间隔,默认值:2000,表示间隔为2000ms
- **out_file** (str, optional) - 输出图像的文件路径,用于在绘制后将结果存储到本地,默认值:None,不保存

**返回:**

ndarray,带边界框和类别置信度的图像。
numpy.ndarray,带边界框和类别置信度的图像。

**样例:**

>>> import numpy as np
>>> from mindspore.dataset.utils.browse_dataset import imshow_det_bbox
>>>
>>> # 读取VOC数据集.
>>> voc_dataset_dir = "/path/to/voc_dataset_directory"
>>> dataset = ds.VOCDataset(voc_dataset_dir, task="Detection", shuffle=False, decode=True, num_samples=5)
>>> dataset_iter = dataset.create_dict_iterator(output_numpy=True, num_epochs=1)
>>>
>>> # 调用imshow_det_bbox自动标注图像
>>> for index, data in enumerate(dataset_iter):
... image = data["image"]
... bbox = data["bbox"]
... label = data["label"]
... # draw image with bboxes
... imshow_det_bbox(image, bbox, label,
... class_names=['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
... 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
... 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'],
... win_name="my_window",
... wait_time=5000,
... show=True,
... out_file="voc_dataset_{}.jpg".format(str(index)))

**`imshow_det_bbox` 在VOC2012数据集的使用图示:**

.. image:: api_img/browse_dataset.png

Loading…
Cancel
Save