Browse Source

fix minddata bugs

tags/v1.5.0-rc1
YangLuo 4 years ago
parent
commit
a00a2fb346
4 changed files with 12 additions and 2 deletions
  1. +3
    -0
      mindspore/ccsrc/minddata/dataset/api/execute.cc
  2. +1
    -1
      mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.cc
  3. +6
    -0
      mindspore/dataset/callback/ds_callback.py
  4. +2
    -1
      mindspore/dataset/engine/__init__.py

+ 3
- 0
mindspore/ccsrc/minddata/dataset/api/execute.cc View File

@@ -226,6 +226,7 @@ Execute::~Execute() {
Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor *output) {
// Validate input tensor
CHECK_FAIL_RETURN_UNEXPECTED(input.DataSize() > 0, "Input Tensor has no data.");
CHECK_FAIL_RETURN_UNEXPECTED(output != nullptr, "Output Tensor can not be nullptr.");
CHECK_FAIL_RETURN_UNEXPECTED(ValidateDevice(), "Device Type should be 'Ascend310' or 'CPU'.");

// Parse TensorTransform transforms_ into TensorOperation ops_
@@ -311,6 +312,8 @@ Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor
Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::vector<MSTensor> *output_tensor_list) {
// Validate input tensor
CHECK_FAIL_RETURN_UNEXPECTED(!input_tensor_list.empty(), "Input Tensor is not valid.");
CHECK_FAIL_RETURN_UNEXPECTED(output_tensor_list != nullptr, "Output Tensor can not be nullptr.");
output_tensor_list->clear();
for (auto &tensor : input_tensor_list) {
CHECK_FAIL_RETURN_UNEXPECTED(tensor.DataSize() > 0, "Input Tensor has no data.");
}


+ 1
- 1
mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.cc View File

@@ -55,7 +55,7 @@ SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::string &model_path
Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
if (!model_status_.IsOk()) {
return model_status_;
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, model_status_.GetErrDescription());
}

if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {


+ 6
- 0
mindspore/dataset/callback/ds_callback.py View File

@@ -238,6 +238,12 @@ class WaitedDSCallback(Callback, DSCallback):
return c_cb

def end(self, run_context):
"""
Internal method, release the wait if training is ended.

Args:
run_context: Include some information of the model.
"""
self.epoch_end(run_context)
self.step_end(run_context)
self.training_ended = True

+ 2
- 1
mindspore/dataset/engine/__init__.py View File

@@ -30,6 +30,7 @@ from .graphdata import GraphData, SamplingStrategy, OutputFormat
from .iterators import *
from .samplers import *
from .serializer_deserializer import compare, deserialize, serialize, show
from ..utils import imshow_det_bbox

__all__ = ["CelebADataset", "Cifar100Dataset", "Cifar10Dataset", "CLUEDataset", "CocoDataset", "CSVDataset",
"GeneratorDataset", "GraphData", "ImageFolderDataset", "ManifestDataset", "MindDataset", "MnistDataset",
@@ -37,4 +38,4 @@ __all__ = ["CelebADataset", "Cifar100Dataset", "Cifar10Dataset", "CLUEDataset",
"DistributedSampler", "PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler",
"WeightedRandomSampler", "SubsetSampler",
"DatasetCache", "DSCallback", "Schema", "WaitedDSCallback", "compare", "deserialize",
"serialize", "show", "zip"]
"imshow_det_bbox", "serialize", "show", "zip"]

Loading…
Cancel
Save