diff --git a/mindspore/core/ir/pattern_matcher.h b/mindspore/core/ir/pattern_matcher.h index 0eb31fc422..1077099496 100644 --- a/mindspore/core/ir/pattern_matcher.h +++ b/mindspore/core/ir/pattern_matcher.h @@ -615,7 +615,7 @@ class PConstant : public PBase > { return new_vnode; } // x is not nullptr - if (x->isa()) { + if (x->isa() || x->isa()) { if ((x->abstract() == nullptr) || !x->abstract()->isa()) { return nullptr; } @@ -650,8 +650,9 @@ class PConstant : public PBase > { ret = memcpy_s(data, mem_size, source_data, mem_size); } if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret << ", source size " << mem_size << "dest size" - << new_tensor_ptr->DataSize(); + MS_LOG(INFO) << "memcpy_s error, error no " << ret << ", source size " << mem_size << "dest size" + << new_tensor_ptr->DataSize(); + return nullptr; } auto new_vnode = NewValueNode(new_tensor_ptr); new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); @@ -735,46 +736,60 @@ class PConstant : public PBase > { auto tensor_1_abstract = vnode_1->abstract()->cast(); auto tensor_2_abstract = vnode_1->abstract()->cast(); - auto tensor_3_abstract = node_3->abstract()->cast(); - TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType(); TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType(); - TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); - if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) || - (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) { - return nullptr; - } - ShapeVector tensor_out_shape = tensor_3_abstract->shape()->shape(); - int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies()); - if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { - return nullptr; - } - if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { - return nullptr; + ShapeVector tensor_out_shape; + int data_out_size; + tensor::TensorPtr new_tensor_ptr; + + if ((tensor_1_abstract->shape()->shape() == tensor_2_abstract->shape()->shape()) && + (tensor_1_type_ptr->type_id() == tensor_2_type_ptr->type_id())) { + // If two constant nodes have the same shape, then create a new one with this shape + tensor_out_shape = tensor_1_abstract->shape()->shape(); + data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies()); + + new_tensor_ptr = std::make_shared(tensor_1_type_ptr->type_id(), tensor_out_shape); + } else { + // If two constant nodes have different shapes, then create a new one node with the shape of the 3rd node + auto tensor_3_abstract = node_3->abstract()->cast(); + + TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); + if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) || + (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) { + return nullptr; + } + tensor_out_shape = tensor_3_abstract->shape()->shape(); + data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies()); + if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { + return nullptr; + } + if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { + return nullptr; + } + new_tensor_ptr = std::make_shared(tensor_3_type_ptr->type_id(), tensor_out_shape); } - auto new_tensor_ptr = std::make_shared(tensor_3_type_ptr->type_id(), tensor_out_shape); - size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + size_t mem_size = GetTypeByte(new_tensor_ptr->Dtype()) * IntToSize(new_tensor_ptr->ElementsNum()); char *data = reinterpret_cast(new_tensor_ptr->data_c()); int ret = 0; void *data_out = nullptr; - if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) || - (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) { + if ((new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat32) || + (new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat)) { Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), tensor_ptr_2->DataSize(), &data_out, data_out_size); ret = memcpy_s(data, mem_size, data_out, mem_size); delete[] reinterpret_cast(data_out); } else { - if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) { + if (new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat64) { Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), tensor_ptr_2->DataSize(), &data_out, data_out_size); ret = memcpy_s(data, mem_size, data_out, mem_size); delete[] reinterpret_cast(data_out); } else { - if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) || - (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) { + if ((new_tensor_ptr->data_type() == TypeId::kNumberTypeInt32) || + (new_tensor_ptr->data_type() == TypeId::kNumberTypeInt)) { Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), tensor_ptr_2->DataSize(), &data_out, data_out_size); ret = memcpy_s(data, mem_size, data_out, mem_size);