浏览代码

timemonitor

tags/v1.2.0-rc1
wanyiming 4 年前
父节点
当前提交
2edcfc2658
共有 4 个文件被更改,包括 5 次插入5 次删除
  1. +1
    -1
      mindspore/nn/layer/basic.py
  2. +1
    -1
      mindspore/ops/operations/array_ops.py
  3. +1
    -1
      mindspore/ops/operations/nn_ops.py
  4. +2
    -2
      model_zoo/research/audio/deepspeech2/train.py

+ 1
- 1
mindspore/nn/layer/basic.py 查看文件

@@ -709,7 +709,7 @@ class ResizeBilinear(Cell):
ValueError: If `size` is a list or tuple whose length is not equal to 2.

Supported Platforms:
``Ascend``
``Ascend`` ``CPU``

Examples:
>>> tensor = Tensor([[[[1, 2, 3, 4], [5, 6, 7, 8]]]], mindspore.float32)


+ 1
- 1
mindspore/ops/operations/array_ops.py 查看文件

@@ -3297,7 +3297,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
ValueError: If length of `size` is not equal to 2.

Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> input_tensor = Tensor(np.array([[[[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]]]), mindspore.float32)


+ 1
- 1
mindspore/ops/operations/nn_ops.py 查看文件

@@ -3148,7 +3148,7 @@ class ResizeBilinear(PrimitiveWithInfer):
ValueError: If length of shape of `input` is not equal to 4.

Supported Platforms:
``Ascend``
``Ascend`` ``CPU``

Examples:
>>> tensor = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.float32)


+ 2
- 2
model_zoo/research/audio/deepspeech2/train.py 查看文件

@@ -21,7 +21,7 @@ import argparse
from mindspore import context, Tensor, ParameterTuple
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Adam
from mindspore.nn import TrainOneStepCell
@@ -89,7 +89,7 @@ if __name__ == '__main__':
print('Successfully loading the pre-trained model')
model = Model(train_net)
callback_list = [LossMonitor()]
callback_list = [TimeMonitor(steps_size), LossMonitor()]
if args.is_distributed:
config.CheckpointConfig.ckpt_file_name_prefix = config.CheckpointConfig.ckpt_file_name_prefix + str(get_rank())


正在加载...
取消
保存