Browse Source

parse acl data type for ps cache

tags/v1.6.0
ms_yan 4 years ago
parent
commit
8136ea8816
2 changed files with 19 additions and 4 deletions
  1. +11
    -4
      mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc
  2. +8
    -0
      mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.h

+ 11
- 4
mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc View File

@@ -77,10 +77,8 @@ Status TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_n
// Data prefetch only when PS mode enables cache.
if (acltdtGetDatasetSize(acl_dataset) > 0) {
acltdtDataItem *item0 = acltdtGetDataItem(acl_dataset, 0);
std::string item_type = "unsupported";
if (acltdtGetDataTypeFromItem(item0) == ACL_INT32) {
item_type = "int32";
}
std::string item_type;
RETURN_IF_NOT_OK(ParseType(acltdtGetDataTypeFromItem(item0), item_type));
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, acltdtGetDataAddrFromItem(item0),
acltdtGetDataSizeFromItem(item0), item_type)) {
RETURN_STATUS_UNEXPECTED("PrefetchData failed in when pre-processing sending data.");
@@ -146,6 +144,15 @@ Status TdtPlugin::getTdtType(DataType d_type, aclDataType &datatype) {
return Status::OK();
}

Status TdtPlugin::ParseType(const aclDataType &acl_data_type, std::string &data_type) {
auto type_iter = parse_map.find(acl_data_type);
if (type_iter == parse_map.end()) {
RETURN_STATUS_UNEXPECTED("Got unsupported acl datatype: " + std::to_string(acl_data_type));
}
data_type = type_iter->second;
return Status::OK();
}

Status TdtPlugin::translate(acltdtTensorType tdt_type, const TensorRow &ts_row, acltdtDataset **output_acl_dataset) {
auto acl_dataset = acltdtCreateDataset();
if (acl_dataset == nullptr) {


+ 8
- 0
mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.h View File

@@ -19,6 +19,7 @@
#include <dlfcn.h>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>
@@ -55,11 +56,18 @@ class TdtPlugin {

Status getTdtType(DataType d_type, aclDataType &datatype);

Status ParseType(const aclDataType &acl_data_type, std::string &data_type);

Status translate(acltdtTensorType tdt_type, const TensorRow &ts_row, acltdtDataset **output_acl_dataset);

void ReportErrorMessage();

void *tdt_handle_ = nullptr;

std::map<aclDataType, std::string> parse_map = {
{ACL_INT8, "int8"}, {ACL_UINT8, "uint8"}, {ACL_INT16, "int16"}, {ACL_UINT16, "uint16"},
{ACL_INT32, "int32"}, {ACL_UINT32, "uint32"}, {ACL_INT64, "int64"}, {ACL_UINT64, "uint64"},
{ACL_FLOAT16, "float16"}, {ACL_FLOAT, "float32"}, {ACL_DOUBLE, "float64"}, {ACL_BOOL, "bool"}};
};
} // namespace dataset
} // namespace mindspore


Loading…
Cancel
Save