Browse Source

fix activation grad bug

pull/15111/head
zhengjun10 4 years ago
parent
commit
a9e5529496
2 changed files with 7 additions and 6 deletions
  1. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc
  2. +6
    -5
      mindspore/lite/tools/anf_exporter/anf_exporter.cc

+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc View File

@@ -66,7 +66,7 @@ int ActivationGradCPUKernel::DoActivation(int task_id) {
// Sigmoid gets the input tensors in reverse order!
error_code = SigmoidGrad(input_addr + start, yt_addr + start, count, output_addr + start);
} else if (param_act_grad_->type_ == schema::ActivationType_TANH) {
error_code = TanhGrad(yt_addr + start, input_addr + start, count, output_addr + start);
error_code = TanhGrad(input_addr + start, yt_addr + start, count, output_addr + start);
} else if (param_act_grad_->type_ == schema::ActivationType_HSWISH) {
error_code = HSwishGrad(yt_addr + start, input_addr + start, count, output_addr + start);
} else if (param_act_grad_->type_ == schema::ActivationType_HSIGMOID) {


+ 6
- 5
mindspore/lite/tools/anf_exporter/anf_exporter.cc View File

@@ -854,11 +854,12 @@ int AnfExporter::ProcessValueSequence(const ValueNodePtr &value_node, std::uniqu
(*schema_tensor)->dims = {static_cast<int32_t>(shape.size())};
(*schema_tensor)->nodeType = NodeType_ValueNode;
(*schema_tensor)->data.resize(shape.size() * sizeof(int));
ret = memcpy_s((*schema_tensor)->data.data(), shape.size() * sizeof(int32_t), shape.data(),
shape.size() * sizeof(int32_t));
if (ret != RET_OK) {
MS_LOG(ERROR) << "memcpy_s data into schema_tensor failed.";
return RET_ERROR;
if (!shape.empty()) {
if (EOK != memcpy_s((*schema_tensor)->data.data(), shape.size() * sizeof(int32_t), shape.data(),
shape.size() * sizeof(int32_t))) {
MS_LOG(ERROR) << "memcpy_s data into schema_tensor failed.";
return RET_MEMORY_FAILED;
}
}
node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());


Loading…
Cancel
Save