Browse Source

Ascend MindRT bugs fix

tags/v1.6.0
hwjiaorui 4 years ago
parent
commit
5ef3c3ad7f
3 changed files with 14 additions and 7 deletions
  1. +6
    -5
      mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
  2. +7
    -2
      mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
  3. +1
    -0
      mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc

+ 6
- 5
mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc View File

@@ -438,16 +438,17 @@ bool AscendDeviceAddress::SyncDeviceToDevice(const ShapeVector &shape, size_t si
if (type_id_ > kMonadTypeBegin && type_id_ < kMonadTypeEnd) {
return true;
}
BindDevice();

if (size_ < size) {
MS_LOG(ERROR) << "src size is greater than det size, src size is: " << size << ", dst size is: " << size_;
return false;
}
if (format_ != format || type_id_ != type) {
MS_LOG(ERROR) << "format or type is different, src(format:" << format << ", type_id:" << TypeIdLabel(type)
<< "), dst(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_);
return false;
}
if (size_ < size) {
MS_LOG(ERROR) << "src size is greater than det size, src size is: " << size << ", dst size is: " << size_;
return false;
}
BindDevice();
auto ret_rt_memcpy = aclrtMemcpy(ptr_, size, src_ptr, size, ACL_MEMCPY_DEVICE_TO_DEVICE);
if (ret_rt_memcpy != RT_ERROR_NONE) {
MS_LOG(ERROR) << "SyncDeviceToDevice failed, rtMemcpy mem size [" << size << "], ret [" << ret_rt_memcpy << "]";


+ 7
- 2
mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc View File

@@ -362,10 +362,15 @@ void SetCastAndWeightFormat(const CNodePtr &kernel_node) {
void SetWeightFormat(const AnfNodePtr &real_input_node, std::vector<string> output_format, const CNodePtr &kernel_node,
size_t input_index, bool force_fresh = false) {
MS_EXCEPTION_IF_NULL(real_input_node);
if (real_input_node->isa<CNode>() || (AnfAlgo::OutputAddrExist(real_input_node, 0) &&
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown)) {
if (real_input_node->isa<CNode>()) {
return;
}
if (AnfAlgo::OutputAddrExist(real_input_node, 0)) {
auto output_addr = AnfAlgo::GetOutputAddr(real_input_node, 0);
if (output_addr->GetPtr() != nullptr) {
return;
}
}
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool disable_convert = real_input_node->isa<Parameter>() || real_input_node->isa<ValueNode>();


+ 1
- 0
mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc View File

@@ -624,6 +624,7 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node,
std::string error_info = "Sync data error.";
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
}
host_tensor_address = device_tensor;
} else {
AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
host_tensor_address->SetNodeIndex(backend_node, 0);


Loading…
Cancel
Save