|
|
|
@@ -77,10 +77,8 @@ Status TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_n |
|
|
|
// Data prefetch only when PS mode enables cache. |
|
|
|
if (acltdtGetDatasetSize(acl_dataset) > 0) { |
|
|
|
acltdtDataItem *item0 = acltdtGetDataItem(acl_dataset, 0); |
|
|
|
std::string item_type = "unsupported"; |
|
|
|
if (acltdtGetDataTypeFromItem(item0) == ACL_INT32) { |
|
|
|
item_type = "int32"; |
|
|
|
} |
|
|
|
std::string item_type; |
|
|
|
RETURN_IF_NOT_OK(ParseType(acltdtGetDataTypeFromItem(item0), item_type)); |
|
|
|
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, acltdtGetDataAddrFromItem(item0), |
|
|
|
acltdtGetDataSizeFromItem(item0), item_type)) { |
|
|
|
RETURN_STATUS_UNEXPECTED("PrefetchData failed in when pre-processing sending data."); |
|
|
|
@@ -146,6 +144,15 @@ Status TdtPlugin::getTdtType(DataType d_type, aclDataType &datatype) { |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
Status TdtPlugin::ParseType(const aclDataType &acl_data_type, std::string &data_type) { |
|
|
|
auto type_iter = parse_map.find(acl_data_type); |
|
|
|
if (type_iter == parse_map.end()) { |
|
|
|
RETURN_STATUS_UNEXPECTED("Got unsupported acl datatype: " + std::to_string(acl_data_type)); |
|
|
|
} |
|
|
|
data_type = type_iter->second; |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
Status TdtPlugin::translate(acltdtTensorType tdt_type, const TensorRow &ts_row, acltdtDataset **output_acl_dataset) { |
|
|
|
auto acl_dataset = acltdtCreateDataset(); |
|
|
|
if (acl_dataset == nullptr) { |
|
|
|
|