Browse Source

fix bug

tags/v0.6.0-beta
jinyaohui 5 years ago
parent
commit
861546921c
2 changed files with 10 additions and 7 deletions
  1. +9
    -6
      mindspore/ccsrc/utils/tensorprint_utils.cc
  2. +1
    -1
      mindspore/context.py

+ 9
- 6
mindspore/ccsrc/utils/tensorprint_utils.cc View File

@@ -214,10 +214,10 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {

bool SaveDataItem2File(const std::vector<tdt::DataItem> &items, const std::string &print_file_path, prntpb::Print print,
std::fstream *output) {
bool ret_end_sequence = false;
bool ret_end_thread = false;
for (auto &item : items) {
if (item.dataType_ == tdt::TDT_END_OF_SEQUENCE) {
ret_end_sequence = true;
ret_end_thread = true;
break;
}
prntpb::Print_Value *value = print.add_value();
@@ -225,14 +225,16 @@ bool SaveDataItem2File(const std::vector<tdt::DataItem> &items, const std::strin
MS_EXCEPTION_IF_NULL(str_data_ptr);
if (item.tensorShape_ == kShapeScalar || item.tensorShape_ == kShapeNone) {
if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) {
MS_LOG(EXCEPTION) << "Print op receive data length is invalid.";
MS_LOG(ERROR) << "Print op receive data length is invalid.";
ret_end_thread = true;
}
}

std::vector<int> tensor_shape;
size_t totaldims = 1;
if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) {
MS_LOG(EXCEPTION) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_;
MS_LOG(ERROR) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_;
ret_end_thread = true;
}

if (item.tensorType_ == "string") {
@@ -252,11 +254,12 @@ bool SaveDataItem2File(const std::vector<tdt::DataItem> &items, const std::strin
}

if (!print.SerializeToOstream(output)) {
MS_LOG(EXCEPTION) << "Save print file:" << print_file_path << " fail.";
MS_LOG(ERROR) << "Save print file:" << print_file_path << " fail.";
ret_end_thread = true;
}
print.Clear();
}
return ret_end_sequence;
return ret_end_thread;
}

void TensorPrint::operator()() {


+ 1
- 1
mindspore/context.py View File

@@ -564,7 +564,7 @@ def set_context(**kwargs):
max_device_memory (str): Sets the maximum memory available for device, currently only supported on GPU.
The format is "xxGB". Default: "1024GB".
print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to
a file by default,and turn off printing to the screen.
a file by default, and turn off printing to the screen.

Raises:
ValueError: If input key is not an attribute in context.


Loading…
Cancel
Save