Browse Source

sync data for cpu

tags/v0.5.0-beta
kswang 5 years ago
parent
commit
83ff625d52
3 changed files with 20 additions and 8 deletions
  1. +4
    -0
      mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc
  2. +8
    -6
      mindspore/ccsrc/session/gpu_session.cc
  3. +8
    -2
      mindspore/ccsrc/vm/transform.cc

+ 4
- 0
mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc View File

@@ -192,8 +192,12 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
if (item->isa<Parameter>()) {
auto address = AnfAlgo::GetMutableOutputAddr(item, 0);
auto tensor = inputs[input_idx];
auto tensor_address = tensor->device_address();
MS_EXCEPTION_IF_NULL(address);
MS_EXCEPTION_IF_NULL(tensor);
if (tensor_address != nullptr && tensor_address != address) {
(void)tensor->data_sync();
}
std::vector<int> data_shape = tensor->shape();
size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());
if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) {


+ 8
- 6
mindspore/ccsrc/session/gpu_session.cc View File

@@ -103,17 +103,19 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
auto pk_node = input_node->cast<ParameterPtr>();
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
auto tensor_address = tensor->device_address();
bool need_sync = false;
if (ms_context->enable_pynative_infer()) {
if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
if (tensor_address.get() == nullptr || tensor_address != device_address) {
need_sync = true;
}
} else {
if (tensor->is_dirty()) {
} else if (tensor->is_dirty()) {
need_sync = true;
} else if (tensor_address != device_address) {
if (tensor_address->DeviceType() == device_address->DeviceType()) {
AnfAlgo::SetOutputAddr(tensor_address, 0, pk_node.get());
} else {
need_sync = true;
} else if (tensor->device_address() != device_address) {
AnfAlgo::SetOutputAddr(tensor->device_address(), 0, pk_node.get());
need_sync = false;
}
}
if (need_sync) {


+ 8
- 2
mindspore/ccsrc/vm/transform.cc View File

@@ -76,9 +76,15 @@ std::string GetCNodeTarget(const AnfNodePtr &node) {
return default_target;
}
auto primitive = value->cast<PrimitivePtr>();
ValuePtr att_target = primitive->GetAttr("primitive_target");
auto att_target = primitive->GetAttr("primitive_target");
if (att_target != nullptr) {
std::string target = GetValue<std::string>(att_target);
if (!att_target->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
}
auto target = GetValue<std::string>(att_target);
if (kTargetSet.find(target) == kTargetSet.end()) {
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
}
return target;
}
return default_target;


Loading…
Cancel
Save