Browse Source

!3330 [AutoParallel]Support dataset in GPU

Merge pull request !3330 from lichen/autoparallel_support_dataset_in_gpu
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
08bafed565
2 changed files with 8 additions and 2 deletions
  1. +4
    -2
      mindspore/ccsrc/frontend/parallel/context.cc
  2. +4
    -0
      mindspore/train/model.py

+ 4
- 2
mindspore/ccsrc/frontend/parallel/context.cc View File

@@ -44,7 +44,10 @@ std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
return inst_context_;
}

ParallelContext::ParallelContext() { Reset(); }
ParallelContext::ParallelContext() {
communication_backend_ = HCCL_BACKEND;
Reset();
}

void ParallelContext::Reset() {
mirror_mean_ = false;
@@ -53,7 +56,6 @@ void ParallelContext::Reset() {
loss_repeated_mean_ = true;
device_num_ = 1;
global_rank_ = 0;
communication_backend_ = HCCL_BACKEND;
device_num_is_set_ = false;
global_rank_is_set_ = false;
parallel_mode_ = STAND_ALONE;


+ 4
- 0
mindspore/train/model.py View File

@@ -30,6 +30,8 @@ from ..nn.metrics import Loss
from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from .parallel_utils import ParallelMode
from ._utils import _to_full_tensor
from ..parallel._utils import _need_to_full
from ..common import dtype as mstype
from .dataset_helper import DatasetHelper
from . import amp
@@ -418,6 +420,8 @@ class Model:

# for data sink dataset_helper only iter once, other wise iter epoch_size times.
for inputs in dataset_helper:
if _need_to_full():
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
list_callback.step_begin(run_context)
outputs = self._train_network(*inputs)
cb_params.cur_step_num += dataset_helper.sink_size()


Loading…
Cancel
Save