|
|
@@ -214,10 +214,10 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) { |
|
|
|
|
|
|
|
|
bool SaveDataItem2File(const std::vector<tdt::DataItem> &items, const std::string &print_file_path, prntpb::Print print, |
|
|
bool SaveDataItem2File(const std::vector<tdt::DataItem> &items, const std::string &print_file_path, prntpb::Print print, |
|
|
std::fstream *output) { |
|
|
std::fstream *output) { |
|
|
bool ret_end_sequence = false; |
|
|
|
|
|
|
|
|
bool ret_end_thread = false; |
|
|
for (auto &item : items) { |
|
|
for (auto &item : items) { |
|
|
if (item.dataType_ == tdt::TDT_END_OF_SEQUENCE) { |
|
|
if (item.dataType_ == tdt::TDT_END_OF_SEQUENCE) { |
|
|
ret_end_sequence = true; |
|
|
|
|
|
|
|
|
ret_end_thread = true; |
|
|
break; |
|
|
break; |
|
|
} |
|
|
} |
|
|
prntpb::Print_Value *value = print.add_value(); |
|
|
prntpb::Print_Value *value = print.add_value(); |
|
|
@@ -225,14 +225,16 @@ bool SaveDataItem2File(const std::vector<tdt::DataItem> &items, const std::strin |
|
|
MS_EXCEPTION_IF_NULL(str_data_ptr); |
|
|
MS_EXCEPTION_IF_NULL(str_data_ptr); |
|
|
if (item.tensorShape_ == kShapeScalar || item.tensorShape_ == kShapeNone) { |
|
|
if (item.tensorShape_ == kShapeScalar || item.tensorShape_ == kShapeNone) { |
|
|
if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) { |
|
|
if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) { |
|
|
MS_LOG(EXCEPTION) << "Print op receive data length is invalid."; |
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Print op receive data length is invalid."; |
|
|
|
|
|
ret_end_thread = true; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
std::vector<int> tensor_shape; |
|
|
std::vector<int> tensor_shape; |
|
|
size_t totaldims = 1; |
|
|
size_t totaldims = 1; |
|
|
if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) { |
|
|
if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) { |
|
|
MS_LOG(EXCEPTION) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_; |
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_; |
|
|
|
|
|
ret_end_thread = true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (item.tensorType_ == "string") { |
|
|
if (item.tensorType_ == "string") { |
|
|
@@ -252,11 +254,12 @@ bool SaveDataItem2File(const std::vector<tdt::DataItem> &items, const std::strin |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (!print.SerializeToOstream(output)) { |
|
|
if (!print.SerializeToOstream(output)) { |
|
|
MS_LOG(EXCEPTION) << "Save print file:" << print_file_path << " fail."; |
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Save print file:" << print_file_path << " fail."; |
|
|
|
|
|
ret_end_thread = true; |
|
|
} |
|
|
} |
|
|
print.Clear(); |
|
|
print.Clear(); |
|
|
} |
|
|
} |
|
|
return ret_end_sequence; |
|
|
|
|
|
|
|
|
return ret_end_thread; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void TensorPrint::operator()() { |
|
|
void TensorPrint::operator()() { |
|
|
|