@@ -16,14 +16,19 @@
#include "runtime/device/ascend/ascend_device_address.h"
#include <memory>
#include <vector>
#include <unordered_map>
#include <utility>
#include <set>
#include <algorithm>
#include "runtime/mem.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/kernel_runtime.h"
#include "runtime/device/convert_tensor_utils.h"
#include "ir/dtype/type.h"
#include "ir/tensor.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h"
#include "utils/utils.h"
#include "common/utils.h"
#include "common/trans.h"
@@ -34,6 +39,58 @@
#include "debug/tensor_load.h"
#endif

namespace {
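// Maps MindSpore type ids to the dtype strings used in the TBE kernel json.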
const std::unordered_map<mindspore::TypeId, std::string> type_id_name_map = {
  {mindspore::kNumberTypeBool, "bool"},       {mindspore::kNumberTypeInt8, "int8"},
  {mindspore::kNumberTypeInt16, "int16"},     {mindspore::kNumberTypeInt32, "int32"},
  {mindspore::kNumberTypeInt64, "int64"},     {mindspore::kNumberTypeFloat16, "float16"},
  {mindspore::kNumberTypeFloat32, "float32"}, {mindspore::kNumberTypeUInt8, "uint8"},
  {mindspore::kNumberTypeUInt16, "uint16"},   {mindspore::kNumberTypeUInt32, "uint32"},
  {mindspore::kNumberTypeUInt64, "uint64"}};
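// (dtype, device format) pairs for which device-to-host sync converts the
// format on device via the trans_data kernel instead of on the host.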
const std::set<std::pair<std::string, std::string>> use_trans_data = {
  std::make_pair("float16", mindspore::kOpFormat_NC1HWC0), std::make_pair("float32", mindspore::kOpFormat_NC1HWC0),
  std::make_pair("bool", mindspore::kOpFormat_NC1HWC0),    std::make_pair("float32", mindspore::kOpFormat_FRAC_Z),
  std::make_pair("float16", mindspore::kOpFormat_FRAC_Z),  std::make_pair("float16", mindspore::kOpFormat_FRAC_NZ),
  std::make_pair("float32", mindspore::kOpFormat_FRAC_NZ), std::make_pair("int32", mindspore::kOpFormat_FRAC_NZ),
  std::make_pair("float16", mindspore::kOpFormat_NHWC),    std::make_pair("float32", mindspore::kOpFormat_NHWC),
  std::make_pair("int8", mindspore::kOpFormat_NHWC),       std::make_pair("int16", mindspore::kOpFormat_NHWC),
  std::make_pair("int32", mindspore::kOpFormat_NHWC),      std::make_pair("int64", mindspore::kOpFormat_NHWC),
  std::make_pair("uint8", mindspore::kOpFormat_NHWC),      std::make_pair("uint16", mindspore::kOpFormat_NHWC),
  std::make_pair("uint32", mindspore::kOpFormat_NHWC),     std::make_pair("uint64", mindspore::kOpFormat_NHWC),
  std::make_pair("float16", mindspore::kOpFormat_HWCN),    std::make_pair("float32", mindspore::kOpFormat_HWCN),
  std::make_pair("int8", mindspore::kOpFormat_HWCN),       std::make_pair("int16", mindspore::kOpFormat_HWCN),
  std::make_pair("int32", mindspore::kOpFormat_HWCN),      std::make_pair("int64", mindspore::kOpFormat_HWCN),
  std::make_pair("uint8", mindspore::kOpFormat_HWCN),      std::make_pair("uint16", mindspore::kOpFormat_HWCN),
  std::make_pair("uint32", mindspore::kOpFormat_HWCN),     std::make_pair("uint64", mindspore::kOpFormat_HWCN)};
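// Json keys and constant values used when composing the trans_data kernel json.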
constexpr auto src_format = "src_format";
constexpr auto dst_format = "dst_format";
constexpr auto src = "src_0";
constexpr auto dst = "dst";
constexpr auto param_type_required = "required";
constexpr auto gen_model_single = "single";
constexpr auto trans_data = "trans_data";
constexpr auto platform_tbe = "TBE";
constexpr auto name = "name";
constexpr auto valid = "valid";
constexpr auto value = "value";
constexpr auto dtype = "dtype";
constexpr auto format_str = "format";
constexpr auto ori_format = "ori_format";
constexpr auto ori_shape = "ori_shape";
constexpr auto param_type = "param_type";
constexpr auto shape_str = "shape";
constexpr auto process_aicore = "aicore";
constexpr auto gen_model_str = "gen_model";
constexpr auto impl_path_str = "impl_path";
constexpr auto attrs_str = "attrs";
constexpr auto inputs_str = "inputs";
constexpr auto outputs_str = "outputs";
constexpr auto kernel_name_str = "kernel_name";
constexpr auto op_info_str = "op_info";
constexpr auto platform_str = "platform";
constexpr auto fractal_z = "FRACTAL_Z";
}  // namespace

namespace mindspore {
namespace device {
namespace ascend {
@@ -96,6 +153,102 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s
  return true;
}

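// Round input_size up to a multiple of kMemAlignSize; the extra 31 bytes of
// headroom appear intended to cover the runtime's tail-padding requirement
// before rounding.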
size_t GetCommonAlignSize(size_t input_size) {
  return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize;
}

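// Build the json "attrs" list for trans_data: the source format (FRAC_Z is
// spelled "FRACTAL_Z" for TBE) and NCHW as the destination format.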
nlohmann::json ConstructAttrs(const std::string &format) {
  nlohmann::json real_attr;
  nlohmann::json src_attr;
  nlohmann::json des_attr;
  src_attr[name] = src_format;
  src_attr[valid] = true;
  if (format == kOpFormat_FRAC_Z) {
    src_attr[value] = fractal_z;
  } else {
    src_attr[value] = format;
  }
  des_attr[name] = dst_format;
  des_attr[valid] = true;
  des_attr[value] = kOpFormat_NCHW;
  real_attr.push_back(src_attr);
  real_attr.push_back(des_attr);
  return real_attr;
}

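// Build the json "inputs" list: one trans_data input carrying the device shape
// and format, with the host (NCHW) shape recorded as ori_shape.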
nlohmann::json ConstructInputs(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape,
                               const std::string &format, mindspore::TypeId type) {
  nlohmann::json input;
  nlohmann::json input_json;
  nlohmann::json real_input;
  real_input[dtype] = type_id_name_map.at(type);
  if (format == kOpFormat_FRAC_Z) {
    real_input[format_str] = fractal_z;
  } else {
    real_input[format_str] = format;
  }
  real_input[name] = src;
  real_input[ori_format] = kOpFormat_NCHW;
  for (auto shape : output_shape) {
    real_input[ori_shape].push_back(shape);
  }
  real_input[param_type] = param_type_required;
  // obtain inputs shape
  for (auto shape : input_shape) {
    real_input[shape_str].push_back(shape);
  }
  real_input[valid] = true;
  input_json.push_back(real_input);
  input.push_back(input_json);
  return input;
}

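// Build the json "outputs" list: one NCHW output whose shape and ori_shape are
// both the host shape.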
nlohmann::json ConstructOutputs(const std::vector<size_t> &output_shape, mindspore::TypeId type) {
  nlohmann::json output;
  nlohmann::json output_json;
  nlohmann::json real_output;
  real_output[dtype] = type_id_name_map.at(type);
  real_output[format_str] = kOpFormat_NCHW;
  real_output[name] = dst;
  real_output[ori_format] = kOpFormat_NCHW;
  for (auto shape : output_shape) {
    real_output[ori_shape].push_back(shape);
  }
  real_output[param_type] = param_type_required;
  // obtain outputs shape
  for (auto shape : output_shape) {
    real_output[shape_str].push_back(shape);
  }
  real_output[valid] = true;
  output_json.push_back(real_output);
  output.push_back(output_json);
  return output;
}

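// Assemble the complete single-op kernel json for trans_data and derive a
// unique kernel_name from a hash of the op_info section.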
nlohmann::json ConstructTransDataKernelJson(const std::vector<size_t> &host_shape,
                                            const std::vector<size_t> &device_shape, const std::string &format,
                                            mindspore::TypeId type) {
  // generate kernel json
  nlohmann::json kernel_json;
  kernel_json[gen_model_str] = gen_model_single;
  kernel_json[impl_path_str] = "";
  // construct op_info
  nlohmann::json op_info;
  op_info[attrs_str] = ConstructAttrs(format);
  op_info[inputs_str] = ConstructInputs(device_shape, host_shape, format, type);
  op_info[kernel_name_str] = "";
  op_info[name] = trans_data;
  op_info[outputs_str] = ConstructOutputs(host_shape, type);
  kernel_json[op_info_str] = op_info;
  kernel_json[platform_str] = platform_tbe;
  std::string json_str = kernel_json[op_info_str].dump();
  size_t hash_id = std::hash<std::string>()(json_str);
  const std::string op_name = op_info[name];
  const std::string json_name = op_name + "_" + std::to_string(hash_id);
  kernel_json[op_info_str][kernel_name_str] = json_name;
  return kernel_json;
}

void AscendDeviceAddress::SyncStream() const {
  MS_LOG(INFO) << "Start!";
  auto ms_context = MsContext::GetInstance();
@@ -158,31 +311,186 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
  return sync_ok;
}

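// Launch the compiled trans_data kernel with this address as input and
// output_address_ptr as output; workspace buffers are allocated on HBM for the
// launch and freed after the stream synchronizes.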
|
|
|
|
void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, void *output_address_ptr, |
|
|
|
size_t output_size, const std::vector<size_t> &workspace_size_list) const { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_mod_ptr); |
|
|
|
auto input_address = std::make_shared<kernel::Address>(); |
|
|
|
MS_EXCEPTION_IF_NULL(input_address); |
|
|
|
input_address->addr = ptr_; |
|
|
|
input_address->size = size_; |
|
|
|
auto output_address = std::make_shared<kernel::Address>(); |
|
|
|
MS_EXCEPTION_IF_NULL(output_address); |
|
|
|
output_address->addr = output_address_ptr; |
|
|
|
output_address->size = output_size; |
|
|
|
AddressPtrList kernel_inputs = {input_address}; |
|
|
|
AddressPtrList kernel_outputs = {output_address}; |
|
|
|
AddressPtrList kernel_workspaces; |
|
|
|
std::vector<void *> workspaces_address_ptr(workspace_size_list.size(), nullptr); |
|
|
|
if (!workspace_size_list.empty()) { |
|
|
|
for (size_t i = 0; i < workspace_size_list.size(); ++i) { |
|
|
|
auto workspace_size = GetCommonAlignSize(workspace_size_list[i]); |
|
|
|
auto ret_malloc = rtMalloc(&workspaces_address_ptr[i], workspace_size, RT_MEMORY_HBM); |
|
|
|
if (ret_malloc != RT_ERROR_NONE) { |
|
|
|
MS_LOG(ERROR) << "Failed to rtMalloc memory"; |
|
|
|
} |
|
|
|
auto workspace_address = std::make_shared<kernel::Address>(); |
|
|
|
MS_EXCEPTION_IF_NULL(workspace_address); |
|
|
|
workspace_address->addr = workspaces_address_ptr[i]; |
|
|
|
workspace_address->size = workspace_size; |
|
|
|
kernel_workspaces.push_back(workspace_address); |
|
|
|
} |
|
|
|
} |
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(ms_context); |
|
|
|
auto device_id = ms_context->device_id(); |
|
|
|
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); |
|
|
|
MS_EXCEPTION_IF_NULL(runtime_instance); |
|
|
|
auto ret = |
|
|
|
runtime_instance->LaunchTaskBasedOnSingleKernel(kernel_mod_ptr, kernel_inputs, kernel_outputs, kernel_workspaces); |
|
|
|
if (!ret) { |
|
|
|
MS_LOG(ERROR) << "Launch kernel failed."; |
|
|
|
} |
|
|
|
SyncStream(); |
|
|
|
if (!workspace_size_list.empty()) { |
|
|
|
for (size_t i = 0; i < workspace_size_list.size(); ++i) { |
|
|
|
auto ret_free = rtFree(workspaces_address_ptr[i]); |
|
|
|
if (ret_free != RT_ERROR_NONE) { |
|
|
|
MS_LOG(ERROR) << "Failed to rtFree memory"; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
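// Compile the trans_data kernel described by kernel_json, skipping the build if
// an identical json_name was already compiled in this process, then create a
// KernelMod from the cached kernel pack.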
kernel::KernelModPtr AscendDeviceAddress::CompileTransDataAndObtainKernelMod(const nlohmann::json &kernel_json) const {
  static std::set<std::string> constructed_kernel;
  auto build_manager = std::make_shared<kernel::ParallelBuildManager>();
  MS_EXCEPTION_IF_NULL(build_manager);
  std::string processor = process_aicore;
  // get size
  std::vector<size_t> input_size_list;
  std::vector<size_t> output_size_list;
  (void)kernel::TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list);
  std::string json_name = kernel_json[op_info_str][kernel_name_str];
  // op build
  if (constructed_kernel.find(json_name) == constructed_kernel.end()) {
    auto task_id = build_manager->StartCompileOp(kernel_json);
    build_manager->SaveTaskInfo(task_id, nullptr, json_name, input_size_list, output_size_list);
  }
  while (!build_manager->IsAllTaskFinish()) {
    int task_id = -1;
    char *task_result = nullptr;
    char *pre_build_result = nullptr;
    auto ret = build_manager->WaitOne(&task_id, &task_result, &pre_build_result);
    if (!ret) {
      MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id;
    }
    if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) {
      MS_EXCEPTION(ArgumentError) << "task compile Failed, task id:" << task_id << ", cause:" << task_result;
    }
    (void)build_manager->TaskFinishProcess(task_id, false);
  }
  constructed_kernel.insert(json_name);
  // search cache
  auto cached_kernel_pack = TbeUtils::SearchCache(json_name, processor);
  MS_EXCEPTION_IF_NULL(cached_kernel_pack);
  auto kernel_mod_ptr =
    build_manager->GenKernelMod(json_name, processor, input_size_list, output_size_list, cached_kernel_pack);
  return kernel_mod_ptr;
}

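// Convert the device tensor to NCHW on device via trans_data and copy it to
// host_ptr; a dtype mismatch is resolved on host through a temporary buffer and
// trans::TransDataType.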
bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const std::vector<size_t> &host_shape,
                                                                           const std::vector<size_t> &device_shape,
                                                                           size_t size, mindspore::TypeId type,
                                                                           void *host_ptr) const {
  bool sync_ok = true;
  // construct trans data kernel json
  nlohmann::json kernel_json = ConstructTransDataKernelJson(host_shape, device_shape, format_, type_id_);
  MS_LOG(INFO) << "Construct trans_data kernel json: " << kernel_json.dump();
  auto kernel_mod_ptr = CompileTransDataAndObtainKernelMod(kernel_json);
  MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
  auto host_size = size;
  if (type_id_ != type) {
    auto device_dtype_size = trans::TypeIdSize(type_id_);
    if (device_dtype_size < 1) {
      MS_LOG(ERROR) << "Illegal dtype.";
    }
    auto shape_size = trans::ShapeSize(host_shape);
    auto size_tmp = device_dtype_size * shape_size;
    size = GetCommonAlignSize(size_tmp);
  }
  void *output_address_ptr = nullptr;
  auto ret_malloc = rtMalloc(&output_address_ptr, size, RT_MEMORY_HBM);
  if (ret_malloc != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "Failed to rtMalloc memory";
  }
  auto workspace_size_list = GetWorkspaceSizeList(kernel_json);
  // launch
  LaunchTransData(kernel_mod_ptr, output_address_ptr, size, workspace_size_list);
  if (type_id_ == type) {
    SyncMemory(host_ptr, output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST);
  } else {
    auto host = std::vector<uint8_t>(size);
    SyncMemory(host.data(), output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST);
    auto shape_size = trans::ShapeSize(host_shape);
    const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, host_size};
    sync_ok = trans::TransDataType(type_args, host_ptr);
    if (!sync_ok) {
      MS_LOG(ERROR) << "Trans format failed.";
      // Release the device output buffer before the early return to avoid a leak.
      (void)rtFree(output_address_ptr);
      return false;
    }
  }
  auto ret_free = rtFree(output_address_ptr);
  if (ret_free != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "Failed to rtFree memory";
  }
  return sync_ok;
}

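// Read the workspace sizes recorded in the cached kernel pack for this kernel.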
std::vector<size_t> AscendDeviceAddress::GetWorkspaceSizeList(const nlohmann::json &kernel_json) const {
  std::string json_name = kernel_json[op_info_str][kernel_name_str];
  std::string processor = process_aicore;
  auto cached_kernel_pack = TbeUtils::SearchCache(json_name, processor);
  MS_EXCEPTION_IF_NULL(cached_kernel_pack);
  auto kernel_json_info = cached_kernel_pack->kernel_json_info();
  return kernel_json_info.workspaces;
}

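// Derive the device shape from the host shape: FRAC_NZ and NDHWC translate
// directly, other formats are first padded to 4D (or replaced by the stored
// host_shape_) before translation.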
std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *host_shape) const {
  std::vector<size_t> device_shape;
  if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
    device_shape = trans::TransShapeToDevice(*host_shape, format_);
  } else {
    if (host_shape_.empty()) {
      *host_shape = trans::PaddingShapeTo4d(*host_shape);
    } else {
      host_shape->clear();
      (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(*host_shape), IntToSize);
    }
    device_shape = trans::TransShapeToDevice(*host_shape, format_);
  }
  return device_shape;
}

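// Sync a device tensor to host with format (and, if needed, dtype) conversion;
// (dtype, format) pairs listed in use_trans_data take the on-device trans_data
// path, everything else falls back to host-side trans:: conversion.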
bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size,
                                                           mindspore::TypeId type, void *host_ptr) const {
  MS_LOG(INFO) << "SyncDeviceToHostAndConvertFormat, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
               << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")";
  bool sync_ok = false;
  std::vector<size_t> host_shape;
  (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize);
  if (host_shape.empty()) {
    host_shape.emplace_back(1);
  }
  std::vector<size_t> device_shape = GetDeviceShape(&host_shape);
  if (type_id_name_map.find(type_id_) != type_id_name_map.end()) {
    std::pair<std::string, std::string> type_format = std::make_pair(type_id_name_map.at(type_id_), format_);
    if (use_trans_data.find(type_format) != use_trans_data.end()) {
      sync_ok = SyncDeviceToHostAndConvertFormatBasedOnTransData(host_shape, device_shape, size, type, host_ptr);
      return sync_ok;
    }
  }
  auto host_tmp = std::vector<uint8_t>(size_);
  SyncMemory(host_tmp.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST);
  if (type_id_ != type) {
    const trans::FormatArgs format_args{host_tmp.data(), size_,        kOpFormat_NCHW, format_,
                                        host_shape,      device_shape, type_id_};