From: @wangzhe128 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -22,3 +22,10 @@ int Fill(float *output, int size, float data) { | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int FillInt32(int *output, int size, int data) { | |||
| for (int i = 0; i < size; ++i) { | |||
| output[i] = data; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| @@ -35,6 +35,8 @@ typedef struct FillParameter { | |||
| extern "C" { | |||
| #endif | |||
| int Fill(float *output, int size, float data); | |||
| int FillInt32(int *output, int size, int data); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| @@ -56,26 +56,6 @@ PrimitiveC *FillCreator(const schema::Primitive *primitive) { return PrimitiveC: | |||
| Registry FillRegistry(schema::PrimitiveType_Fill, FillCreator); | |||
| #endif | |||
| template <typename T> | |||
| void CalShape(const T *data, const std::vector<Tensor *> &inputs, std::vector<int> *out_shape, int shape_size) { | |||
| int input_count = inputs[0]->ElementsNum(); | |||
| int index = 0; | |||
| int size = 1; | |||
| for (int i = 0; i < shape_size; i++) { | |||
| if (static_cast<int>(data[i]) == -1) { | |||
| index = i; | |||
| } else if (static_cast<int>(data[i]) == 0) { | |||
| size *= inputs[0]->shape().at(i); | |||
| } else { | |||
| size *= data[i]; | |||
| } | |||
| out_shape->push_back(data[i]); | |||
| } | |||
| if (static_cast<int>(data[index]) == -1) { | |||
| (*out_shape).at(index) = input_count / size; | |||
| } | |||
| } | |||
| int Fill::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| MS_ASSERT(this->primitive_ != nullptr); | |||
| auto input = inputs_.front(); | |||
| @@ -94,54 +74,23 @@ int Fill::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output | |||
| return RET_INFER_INVALID; | |||
| } | |||
| std::vector<int> out_shape; | |||
| std::vector<int> output_shape; | |||
| auto param_dims = GetDims(); | |||
| for (size_t i = 0; i < param_dims.size(); i++) { | |||
| output_shape.push_back(param_dims.at(i)); | |||
| } | |||
| if (inputs_.size() == kDoubleNum) { | |||
| auto shape_tensor = inputs_.at(1); | |||
| if (shape_tensor->IsConst()) { | |||
| if (shape_tensor->data_c() == nullptr || (shape_tensor->shape().size() == 1 && shape_tensor->shape()[0] == 0)) { | |||
| MS_LOG(DEBUG) << "reshape to a scalar."; | |||
| output->set_shape(out_shape); | |||
| return RET_OK; | |||
| } | |||
| } | |||
| if (shape_tensor->data_c() == nullptr) { | |||
| MS_LOG(INFO) << "Do infer shape in runtime."; | |||
| auto input_dims = inputs_.at(1); | |||
| MS_ASSERT(input_dims != nullptr); | |||
| if (input_dims->data_c() == nullptr) { | |||
| return RET_INFER_INVALID; | |||
| } | |||
| size_t shape_size = shape_tensor->ElementsNum(); | |||
| switch (shape_tensor->data_type()) { | |||
| case kNumberTypeInt8: { | |||
| auto data = reinterpret_cast<int8_t *>(shape_tensor->MutableData()); | |||
| CalShape<int8_t>(data, inputs_, &out_shape, shape_size); | |||
| } break; | |||
| case kNumberTypeInt32: { | |||
| auto data = reinterpret_cast<int32_t *>(shape_tensor->MutableData()); | |||
| CalShape<int32_t>(data, inputs_, &out_shape, shape_size); | |||
| } break; | |||
| case kNumberTypeInt64: { | |||
| auto data = reinterpret_cast<int64_t *>(shape_tensor->MutableData()); | |||
| CalShape<int64_t>(data, inputs_, &out_shape, shape_size); | |||
| } break; | |||
| case kNumberTypeFloat: { | |||
| auto data = reinterpret_cast<float *>(shape_tensor->MutableData()); | |||
| CalShape<float>(data, inputs_, &out_shape, shape_size); | |||
| } break; | |||
| case kNumberTypeUInt32: { | |||
| auto data = reinterpret_cast<uint32_t *>(shape_tensor->MutableData()); | |||
| CalShape<uint32_t>(data, inputs_, &out_shape, shape_size); | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type(); | |||
| return RET_INFER_ERR; | |||
| } | |||
| } | |||
| } else { | |||
| for (size_t i = 0; i < GetDims().size(); i++) { | |||
| out_shape.push_back(GetDims().at(i)); | |||
| } | |||
| int *dims_data = reinterpret_cast<int *>(input_dims->data_c()); | |||
| output_shape = std::vector<int>{dims_data, dims_data + input_dims->ElementsNum()}; | |||
| } | |||
| output->set_shape(out_shape); | |||
| output->set_shape(output_shape); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -116,6 +116,15 @@ int Transpose::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> o | |||
| MS_ASSERT(output != nullptr); | |||
| std::vector<int> perm = GetPerm(); | |||
| if (inputs_.size() == kDoubleNum) { | |||
| auto input_perm = inputs_.at(1); | |||
| MS_ASSERT(input_perm != nullptr); | |||
| if (input_perm->data_c() == nullptr) { | |||
| return RET_INFER_INVALID; | |||
| } | |||
| int *perm_data = reinterpret_cast<int *>(input_perm->data_c()); | |||
| perm = std::vector<int>{perm_data, perm_data + input_perm->ElementsNum()}; | |||
| } | |||
| std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1}; | |||
| std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2}; | |||
| std::vector<int> in_shape = input->shape(); | |||
| @@ -48,7 +48,15 @@ int FillCPUKernel::DoFill(int task_id) { | |||
| return RET_OK; | |||
| } | |||
| int offset = task_id * thread_sz_stride_; | |||
| int ret = Fill(out_ptr_ + offset, size, src_data_); | |||
| auto input_tensor = in_tensors_.at(0); | |||
| int ret = RET_OK; | |||
| if (input_tensor->data_type() == kNumberTypeFloat32 || input_tensor->data_type() == kNumberTypeFloat) { | |||
| ret = Fill(out_ptr_ + offset, size, src_data_); | |||
| } else if (input_tensor->data_type() == kNumberTypeInt32 || input_tensor->data_type() == kNumberTypeInt) { | |||
| ret = FillInt32(int32_out_ptr_ + offset, size, int32_src_data_); | |||
| } else { | |||
| return RET_ERROR; | |||
| } | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "FillRun error task_id[" << task_id << "] error_code[" << ret << "]"; | |||
| return ret; | |||
| @@ -67,11 +75,20 @@ int FillRun(void *cdata, int task_id) { | |||
| } | |||
| int FillCPUKernel::Run() { | |||
| auto fillData = in_tensors_.at(in_tensors_.size() - 1); | |||
| auto fill_input = in_tensors_.front(); | |||
| auto output = out_tensors_.front(); | |||
| auto fill_data = reinterpret_cast<float *>(fillData->MutableData()); | |||
| src_data_ = fill_data[0]; | |||
| out_ptr_ = reinterpret_cast<float *>(output->MutableData()); | |||
| if (fill_input->data_type() == kNumberTypeFloat32 || fill_input->data_type() == kNumberTypeFloat) { | |||
| auto fill_data = reinterpret_cast<float *>(fill_input->MutableData()); | |||
| src_data_ = fill_data[0]; | |||
| out_ptr_ = reinterpret_cast<float *>(output->MutableData()); | |||
| } else if (fill_input->data_type() == kNumberTypeInt32 || fill_input->data_type() == kNumberTypeInt) { | |||
| auto fill_data = reinterpret_cast<int *>(fill_input->MutableData()); | |||
| int32_src_data_ = fill_data[0]; | |||
| int32_out_ptr_ = reinterpret_cast<int *>(output->MutableData()); | |||
| } else { | |||
| MS_LOG(ERROR) << "unsupported fill data type " << fill_input->data_type(); | |||
| return RET_ERROR; | |||
| } | |||
| auto ret = ParallelLaunch(this->context_->thread_pool_, FillRun, this, thread_sz_count_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "FillRun error error_code[" << ret << "]"; | |||
| @@ -80,5 +97,6 @@ int FillCPUKernel::Run() { | |||
| return RET_OK; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Fill, LiteKernelCreator<FillCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Fill, LiteKernelCreator<FillCPUKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -44,6 +44,8 @@ class FillCPUKernel : public LiteKernel { | |||
| int data_size_; | |||
| float src_data_; | |||
| float *out_ptr_; | |||
| int int32_src_data_; | |||
| int *int32_out_ptr_; | |||
| int thread_count_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -138,10 +138,10 @@ int GruCPUKernel::Run() { | |||
| MS_ASSERT(output != nullptr); | |||
| auto input_ptr = reinterpret_cast<float *>(input->data_c()); | |||
| MS_ASSERT(input_ptr); | |||
| auto output_ptr = reinterpret_cast<float *>(output->MutableData()); | |||
| auto output_ptr = reinterpret_cast<float *>(output->data_c()); | |||
| MS_ASSERT(output_ptr); | |||
| auto output_hidden_state = out_tensors_[1]; | |||
| memcpy(output_hidden_state->MutableData(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float)); | |||
| memcpy(output_hidden_state->data_c(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float)); | |||
| int check_seq_len = gru_parm_->seq_len_; | |||
| if (in_tensors_.size() == 6) { | |||
| auto seq_len = reinterpret_cast<int *>(in_tensors_.at(5)->data_c()); | |||
| @@ -152,12 +152,12 @@ int GruCPUKernel::Run() { | |||
| check_seq_len = MSMIN(check_seq_len, MSMAX(0, seq_len[0])); | |||
| } | |||
| MS_ASSERT(weight_g_ptr_); | |||
| MS_ASSERT(weight_r_ptr_); | |||
| MS_ASSERT(bias_ptr_); | |||
| MS_ASSERT(gate_buffer_); | |||
| MS_ASSERT(weight_g_ptr_ != nullptr); | |||
| MS_ASSERT(weight_r_ptr_ != nullptr); | |||
| MS_ASSERT(bias_ptr_ != nullptr); | |||
| MS_ASSERT(gate_buffer_ != nullptr); | |||
| Gru(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_, | |||
| reinterpret_cast<float *>(output_hidden_state->MutableData()), gate_buffer_, check_seq_len, gru_parm_); | |||
| reinterpret_cast<float *>(output_hidden_state->data_c()), gate_buffer_, check_seq_len, gru_parm_); | |||
| return RET_OK; | |||
| } | |||
| @@ -39,7 +39,7 @@ int TransposeCPUKernel::Init() { | |||
| int TransposeCPUKernel::ReSize() { | |||
| TransposeParameter *param = reinterpret_cast<TransposeParameter *>(op_parameter_); | |||
| if (in_tensors_.at(kInputIndex)->shape().size() != static_cast<size_t>(param->num_axes_)) { | |||
| if (in_tensors_.at(kInputIndex)->shape().size() != static_cast<size_t>(param->num_axes_) && in_tensors_.size() != 2) { | |||
| return RET_OK; | |||
| } | |||
| auto &inTensor = in_tensors_.front(); | |||
| @@ -89,6 +89,20 @@ int TransposeCPUKernel::Run() { | |||
| MS_ASSERT(out_data_); | |||
| TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_); | |||
| if (in_tensors_.size() == 2) { | |||
| auto input_perm = in_tensors_.at(1); | |||
| MS_ASSERT(input_perm != nullptr); | |||
| MS_ASSERT(input_perm->data_c() != nullptr); | |||
| int *perm_data = reinterpret_cast<int *>(input_perm->data_c()); | |||
| auto perm = std::vector<int>{perm_data, perm_data + input_perm->ElementsNum()}; | |||
| for (int i = 0; i < input_perm->ElementsNum(); ++i) { | |||
| param->perm_[i] = perm[i]; | |||
| } | |||
| for (int i = input_perm->ElementsNum(); i <= 8; ++i) { | |||
| param->perm_[i] = 0; | |||
| } | |||
| param->num_axes_ = input_perm->ElementsNum(); | |||
| } | |||
| if (in_tensor->shape().size() != static_cast<size_t>(param->num_axes_)) { | |||
| memcpy(out_data_, in_data_, in_tensor->ElementsNum() * sizeof(float)); | |||
| return RET_OK; | |||
| @@ -162,7 +162,9 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap | |||
| inne_context_ptr->Init(); | |||
| const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(inne_context_ptr)); | |||
| } | |||
| const_fold_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>()); | |||
| auto update_conv2d_param_pass = std::make_shared<opt::UpdateConv2DParamPass>(); | |||
| update_conv2d_param_pass->SetFmkType(config->fmk); | |||
| const_fold_pm->AddPass(update_conv2d_param_pass); | |||
| fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>()); | |||
| convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>()); | |||
| if (config->fmk == lite::converter::FmkType_TFLITE) { | |||
| @@ -280,6 +280,7 @@ STATUS SingleSwitchPass::InsertPartialAndMergeAfterSwitch() { | |||
| second_partial_node_->outputIndex.push_back(graph_->allTensors.size() - 1); | |||
| } | |||
| auto origin_switch_outputs = switch_node_->outputIndex; | |||
| switch_node_->outputIndex.clear(); | |||
| for (size_t i = 3; i < switch_node_->inputIndex.size(); i++) { | |||
| auto &switch_in_tensor = graph_->allTensors.at(i); | |||
| @@ -338,7 +339,7 @@ STATUS SingleSwitchPass::InsertPartialAndMergeAfterSwitch() { | |||
| merge_node->inputIndex.insert(merge_node->inputIndex.end(), second_partial_node_->outputIndex.begin(), | |||
| second_partial_node_->outputIndex.end()); | |||
| } | |||
| merge_node->outputIndex = origin_switch_output_tensor_indices_; | |||
| merge_node->outputIndex = origin_switch_outputs; | |||
| graph_->nodes.push_back(std::move(merge_node)); | |||
| return RET_OK; | |||
| } | |||
| @@ -67,19 +67,23 @@ STATUS TFConvParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| attr->strideW = strides[1]; | |||
| auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(1)); | |||
| if (weight_node == nullptr) { | |||
| MS_LOG(ERROR) << "Find Conv2D input weights failed"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int64_t> kernels(4); | |||
| status = ParseKernels(*weight_node, attr->format, &kernels); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| if (weight_node != nullptr) { | |||
| std::vector<int64_t> kernels(4); | |||
| status = ParseKernels(*weight_node, attr->format, &kernels); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| attr->kernelH = kernels[0]; | |||
| attr->kernelW = kernels[1]; | |||
| attr->channelIn = kernels[2]; | |||
| attr->channelOut = kernels[3]; | |||
| } else { | |||
| attr->kernelH = -1; | |||
| attr->kernelW = -1; | |||
| attr->channelIn = -1; | |||
| attr->channelOut = -1; | |||
| MS_LOG(WARNING) << "parsing of kernelH/W channelIn/Out is delayed"; | |||
| } | |||
| attr->kernelH = kernels[0]; | |||
| attr->kernelW = kernels[1]; | |||
| attr->channelIn = kernels[2]; | |||
| attr->channelOut = kernels[3]; | |||
| status = ParsePadMode(tf_op, &attr->padMode); | |||
| if (status != RET_OK) { | |||
| @@ -42,20 +42,15 @@ STATUS TFFillParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Fill; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| inputs->emplace_back(tf_op.input(1)); | |||
| // parse dims | |||
| tensorflow::AttrValue attr_value; | |||
| auto dims_node = GetConstInputNode(tf_node_map, tf_op.input(0)); | |||
| MS_ASSERT(dims_node != nullptr); | |||
| if (dims_node != nullptr && TensorFlowUtils::FindAttrValue(*dims_node, "value", &attr_value)) { | |||
| if (dims_node != nullptr) { | |||
| if (!TensorFlowUtils::FindAttrValue(*dims_node, "value", &attr_value)) { | |||
| MS_LOG(ERROR) << "fill dims input not have value attr"; | |||
| return RET_ERROR; | |||
| } | |||
| if (attr_value.value_case() != tensorflow::AttrValue::kTensor) { | |||
| MS_LOG(ERROR) << "The attrValue of value should have tensor type, actual: " << attr_value.value_case() | |||
| << ", node: " << tf_op.name().c_str(); | |||
| @@ -66,32 +61,44 @@ STATUS TFFillParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| MS_LOG(ERROR) << "The dimsTensor dataType should be DT_INT32, actual : " << dims_tensor.dtype(); | |||
| return RET_ERROR; | |||
| } | |||
| const tensorflow::TensorShapeProto &dimsTensorShape = dims_tensor.tensor_shape(); | |||
| size_t shapeSize = 1; | |||
| for (int i = 0; i < dimsTensorShape.dim_size(); i++) { | |||
| shapeSize *= dimsTensorShape.dim(i).size(); | |||
| const tensorflow::TensorShapeProto &dims_tensor_shape = dims_tensor.tensor_shape(); | |||
| size_t shape_size = 1; | |||
| for (int i = 0; i < dims_tensor_shape.dim_size(); i++) { | |||
| shape_size *= dims_tensor_shape.dim(i).size(); | |||
| } | |||
| size_t size = dims_tensor.int_val().size(); | |||
| if (size > 0) { | |||
| for (size_t i = 0; i < shapeSize; i++) { | |||
| attr->dims.emplace_back(dims_tensor.int_val().Get(0)); | |||
| for (size_t i = 0; i < shape_size; i++) { | |||
| attr->dims.emplace_back(dims_tensor.int_val().Get(i)); | |||
| } | |||
| } else { | |||
| size = dims_tensor.tensor_content().length(); | |||
| if (size == shapeSize * sizeof(int32_t)) { | |||
| attr->dims.resize(shapeSize); | |||
| if (size > 0) { | |||
| if (size != shape_size * sizeof(int32_t)) { | |||
| MS_LOG(ERROR) << "tensor size mismatch"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->dims.resize(shape_size); | |||
| if (EOK != ::memcpy_s(attr->dims.data(), size, dims_tensor.tensor_content().data(), size)) { | |||
| MS_LOG(ERROR) << "Memcpy_s from dimsTensor to attr failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "Can not find weight data, node: " << dims_node->name().c_str(); | |||
| return RET_ERROR; | |||
| MS_LOG(DEBUG) << "empty dims"; | |||
| } | |||
| } | |||
| } else { | |||
| inputs->emplace_back(tf_op.input(0)); | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Fill; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfFillParser("Fill", new TFFillParser()); | |||
| @@ -15,6 +15,7 @@ | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| @@ -46,7 +46,7 @@ const NodeDef *TFNodeParser::GetConstInputNode(const std::map<string, const tens | |||
| node = tf_node_map.at(flatten_input_name); | |||
| } | |||
| if (node->op() != "Const") { | |||
| MS_LOG(ERROR) << "Attr node is not Const"; | |||
| MS_LOG(DEBUG) << "Attr node is not Const"; | |||
| return nullptr; | |||
| } | |||
| return node; | |||
| @@ -54,7 +54,7 @@ STATUS TFPoolParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| if (attr_value.s() == "VALID") { | |||
| attr->padMode = schema::PadMode_VALID; | |||
| } else if (attr_value.s() == "SAME") { | |||
| attr->padMode = schema::PadMode_VALID; | |||
| attr->padMode = schema::PadMode_SAME_UPPER; | |||
| } | |||
| } | |||
| @@ -0,0 +1,61 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_rsqrt_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFRsqrtParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF RsqrtParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "primitive is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::RsqrtT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Rsqrt; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfRsqrtParser("Rsqrt", new TFRsqrtParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RSQRT_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RSQRT_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFRsqrtParser : public TFNodeParser { | |||
| public: | |||
| TFRsqrtParser() = default; | |||
| ~TFRsqrtParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RSQRT_PARSER_H_ | |||
| @@ -41,28 +41,36 @@ STATUS TFTransposeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->conjugate = false; | |||
| auto status = AddOpInput(tf_op, 0, inputs); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| auto perm_node = GetConstInputNode(tf_node_map, tf_op.input(1)); | |||
| if (perm_node == nullptr) { | |||
| MS_LOG(ERROR) << "Find Transpose input perm failed"; | |||
| return RET_ERROR; | |||
| } | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(*perm_node, "value", &attr_value)) { | |||
| MS_LOG(ERROR) << "The value attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| auto tensor_proto = attr_value.tensor(); | |||
| if (tensor_proto.int_val_size() > 0) { | |||
| for (int i = 0; i < tensor_proto.int_val_size(); ++i) { | |||
| attr->perm.push_back(tensor_proto.int_val(i)); | |||
| status = AddOpInput(tf_op, 1, inputs); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| } else { | |||
| auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); | |||
| auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data()); | |||
| for (size_t i = 0; i < data_num; ++i) { | |||
| attr->perm.push_back(data[i]); | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(*perm_node, "value", &attr_value)) { | |||
| MS_LOG(ERROR) << "The value attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| auto tensor_proto = attr_value.tensor(); | |||
| if (tensor_proto.int_val_size() > 0) { | |||
| for (int i = 0; i < tensor_proto.int_val_size(); ++i) { | |||
| attr->perm.push_back(tensor_proto.int_val(i)); | |||
| } | |||
| } else { | |||
| auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); | |||
| auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data()); | |||
| for (size_t i = 0; i < data_num; ++i) { | |||
| attr->perm.push_back(data[i]); | |||
| } | |||
| } | |||
| } | |||
| @@ -75,7 +83,6 @@ STATUS TFTransposeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| } | |||
| *output_size = 1; | |||
| auto status = AddOpInput(tf_op, 0, inputs); | |||
| return status; | |||
| } | |||
| TFNodeRegistrar g_tfTransposeParser("Transpose", new TFTransposeParser()); | |||
| @@ -693,7 +693,7 @@ STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int3 | |||
| tensor->set_tensor_shape({filterC, filterK, filterH, filterW}); | |||
| } else if (type == kKHWC2CHWK) { | |||
| tensor->set_tensor_shape({filterC, filterH, filterW, filterK}); | |||
| } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) { | |||
| } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC || type == kHWCK2KHWC) { | |||
| tensor->set_tensor_shape({filterK, filterH, filterW, filterC}); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported transFilterType: " << type; | |||
| @@ -812,7 +812,8 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||
| } | |||
| } break; | |||
| case kHWCK2KCHW: | |||
| case kHWCK2CKHW: { | |||
| case kHWCK2CKHW: | |||
| case kHWCK2KHWC: { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| @@ -821,9 +822,12 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType | |||
| if (type == kHWCK2KCHW) { | |||
| p2Buff = | |||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| } else if (type == kHWCK2CKHW) { | |||
| p2Buff = | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| @@ -25,7 +25,6 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kWhileCommonInputsLength = 2; | |||
| constexpr size_t kWhileUniqInputsLength = 6; | |||
| constexpr size_t kCondNodesNum = 12; | |||
| constexpr size_t kCondCNodesNum = 4; | |||
| @@ -47,16 +46,11 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name, | |||
| : PatternProcessPass(name, multigraph) { | |||
| /* | |||
| * vars for while input | |||
| * common: | |||
| * 0:const0 1:init_state | |||
| * fw_while_inputs: | |||
| * 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias | |||
| * bw_while_inputs: | |||
| * 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias | |||
| */ | |||
| for (size_t i = 0; i < kWhileCommonInputsLength; ++i) { | |||
| common_vars_.emplace_back(std::make_shared<Var>()); | |||
| } | |||
| for (size_t i = 0; i < kWhileUniqInputsLength; ++i) { | |||
| fw_vars_.emplace_back(std::make_shared<Var>()); | |||
| bw_vars_.emplace_back(std::make_shared<Var>()); | |||
| @@ -64,17 +58,16 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name, | |||
| input_ = std::make_shared<Var>(); | |||
| input_length_ = std::make_shared<Var>(); | |||
| transpose_input_ = std::make_shared<Var>(); | |||
| fw_init_state_ = std::make_shared<Var>(); | |||
| bw_init_state_ = std::make_shared<Var>(); | |||
| } | |||
| const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const { | |||
| auto const1 = std::make_shared<CondVar>(IsParameterNode); | |||
| auto ele_shape = std::make_shared<CondVar>(IsParameterNode); | |||
| // forward | |||
| auto fw_max1 = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Reduce)), input_length_}); | |||
| auto fw_max2 = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), const1, fw_max1}); | |||
| auto fw_max2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), | |||
| std::make_shared<CondVar>(IsParameterNode), fw_max1}); | |||
| auto fw_shape = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), transpose_input_}); | |||
| @@ -84,32 +77,33 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const { | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), fw_stride, fw_max2}); | |||
| auto fw_reserve = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), ele_shape, | |||
| fw_stride}); | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), | |||
| std::make_shared<CondVar>(IsParameterNode), fw_stride}); | |||
| auto fw_from_tensor = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListFromTensor)), | |||
| transpose_input_, ele_shape}); | |||
| transpose_input_, std::make_shared<CondVar>(IsParameterNode)}); | |||
| auto is_fw_while = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_While)); | |||
| auto fw_while = VectorRef({is_fw_while, fw_vars_[0], fw_vars_[1], common_vars_[0], fw_stride, common_vars_[0], | |||
| fw_reserve, common_vars_[1], fw_min, fw_from_tensor, input_length_}); | |||
| auto fw_while = VectorRef({is_fw_while, fw_vars_[0], fw_vars_[1], std::make_shared<CondVar>(IsParameterNode), | |||
| fw_stride, std::make_shared<CondVar>(IsParameterNode), fw_reserve, fw_init_state_, fw_min, | |||
| fw_from_tensor, input_length_}); | |||
| fw_while.insert(fw_while.end(), fw_vars_.begin() + 2, fw_vars_.end()); | |||
| fw_while.emplace_back(common_vars_[1]); | |||
| fw_while.emplace_back(std::make_shared<Var>()); | |||
| auto fw_get_item = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)), | |||
| fw_while, std::make_shared<Var>()}); | |||
| auto fw_stack = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)), | |||
| fw_get_item, ele_shape}); | |||
| auto fw_out_trans = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), fw_stack}); | |||
| fw_get_item, std::make_shared<CondVar>(IsParameterNode)}); | |||
| auto fw_out_trans = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), | |||
| fw_stack, std::make_shared<Var>()}); | |||
| // backward | |||
| auto bw_reverse_seq = VectorRef( | |||
| {std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_ReverseSequence)), input_, input_length_}); | |||
| auto bw_max1 = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Reduce)), input_length_}); | |||
| auto bw_max2 = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), const1, bw_max1}); | |||
| auto bw_trans = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), bw_reverse_seq}); | |||
| auto bw_max2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), | |||
| std::make_shared<CondVar>(IsParameterNode), bw_max1}); | |||
| auto bw_trans = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), | |||
| bw_reverse_seq, std::make_shared<Var>()}); | |||
| auto bw_shape = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), bw_trans}); | |||
| auto bw_stride = | |||
| @@ -117,22 +111,23 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const { | |||
| auto bw_min = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), bw_stride, bw_max2}); | |||
| auto bw_reserve = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), ele_shape, | |||
| bw_stride}); | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), | |||
| std::make_shared<CondVar>(IsParameterNode), bw_stride}); | |||
| auto bw_from_tensor = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListFromTensor)), bw_trans, | |||
| ele_shape}); | |||
| std::make_shared<CondVar>(IsParameterNode)}); | |||
| auto is_bw_while = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_While)); | |||
| auto bw_while = VectorRef({is_bw_while, bw_vars_[0], bw_vars_[1], common_vars_[0], bw_stride, common_vars_[0], | |||
| bw_reserve, common_vars_[1], bw_min, bw_from_tensor, input_length_}); | |||
| auto bw_while = VectorRef({is_bw_while, bw_vars_[0], bw_vars_[1], std::make_shared<CondVar>(IsParameterNode), | |||
| bw_stride, std::make_shared<CondVar>(IsParameterNode), bw_reserve, bw_init_state_, bw_min, | |||
| bw_from_tensor, input_length_}); | |||
| bw_while.insert(bw_while.end(), bw_vars_.begin() + 2, bw_vars_.end()); | |||
| bw_while.emplace_back(common_vars_[1]); | |||
| bw_while.emplace_back(std::make_shared<Var>()); | |||
| auto bw_get_item = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)), | |||
| bw_while, std::make_shared<Var>()}); | |||
| auto bw_stack = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)), | |||
| bw_get_item, ele_shape}); | |||
| auto bw_out_trans = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), bw_stack}); | |||
| bw_get_item, std::make_shared<CondVar>(IsParameterNode)}); | |||
| auto bw_out_trans = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), | |||
| bw_stack, std::make_shared<Var>()}); | |||
| auto bw_reverse1 = | |||
| VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_ReverseSequence)), bw_out_trans, | |||
| input_length_}); | |||
| @@ -416,10 +411,12 @@ STATUS BiDirectionTfGruCellFusion::ConvertBiasData(const AnfNodePtr &gate_bias, | |||
| } | |||
// Builds a Stack primitive with axis 0 and wraps it in a new CNode of `func_graph` whose
// data inputs are the forward and backward initial states; the node is named
// "stack_hidden_" + base_name and returned.
//
// NOTE(review): this listing is a rendered diff, so superseded and replacement lines appear
// side by side. The lines that mention `hidden_state` (parameter, assert, input list,
// abstract) look like the pre-change version of the fw_init_state/bw_init_state lines that
// follow them -- confirm against the real file before relying on either variant.
| CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph, | |||
| const AnfNodePtr &hidden_state, | |||
| const AnfNodePtr &fw_init_state, | |||
| const AnfNodePtr &bw_init_state, | |||
| const std::string base_name) const { | |||
| MS_ASSERT(func_graph); | |||
| MS_ASSERT(hidden_state); | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(fw_init_state != nullptr); | |||
| MS_ASSERT(bw_init_state != nullptr); | |||
// Stack attribute: the two states are stacked along axis 0.
| auto stack_primitive = std::make_unique<schema::PrimitiveT>(); | |||
| std::unique_ptr<schema::StackT> attr = std::make_unique<schema::StackT>(); | |||
| attr->axis = 0; | |||
// Diff hunk boundary: any source lines between the two hunks are not shown in this listing.
| @@ -427,9 +424,9 @@ CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &f | |||
// The raw StackT pointer is stored in the primitive, and the raw PrimitiveT pointer is passed
// to PrimitiveC::Create; the result is wrapped in a shared_ptr without a null check.
// NOTE(review): presumably Create takes ownership of the released pointer -- verify.
| stack_primitive->value.value = attr.release(); | |||
| auto stack_cvalue = lite::PrimitiveC::Create(stack_primitive.release()); | |||
| auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(stack_cvalue)); | |||
// Input 0 is the primitive value node; inputs 1 and 2 are the states to stack.
| std::vector<AnfNodePtr> new_node_inputs = {value_node, hidden_state, hidden_state}; | |||
| std::vector<AnfNodePtr> new_node_inputs = {value_node, fw_init_state, bw_init_state}; | |||
| auto new_node = func_graph->NewCNode(new_node_inputs); | |||
// The new node's abstract is a clone of an input state's abstract; abstract() is dereferenced
// here without a null check.
// NOTE(review): confirm whether the stacked output should carry its own (stacked) shape.
| new_node->set_abstract(hidden_state->abstract()->Clone()); | |||
| new_node->set_abstract(fw_init_state->abstract()->Clone()); | |||
| new_node->set_fullname_with_scope("stack_hidden_" + base_name); | |||
| return new_node; | |||
| } | |||
| @@ -452,31 +449,33 @@ CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr | |||
| auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(gru_cvalue)); | |||
| auto fw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[2]]); | |||
| MS_ASSERT(fw_gate_kernel); | |||
| MS_ASSERT(fw_gate_kernel != nullptr); | |||
| auto fw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[3]]); | |||
| MS_ASSERT(fw_gate_bias); | |||
| MS_ASSERT(fw_gate_bias != nullptr); | |||
| auto fw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[4]]); | |||
| MS_ASSERT(fw_cand_kernel); | |||
| MS_ASSERT(fw_cand_kernel != nullptr); | |||
| auto fw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[5]]); | |||
| MS_ASSERT(fw_cand_bias); | |||
| MS_ASSERT(fw_cand_bias != nullptr); | |||
| auto bw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[2]]); | |||
| MS_ASSERT(bw_gate_kernel); | |||
| MS_ASSERT(bw_gate_kernel != nullptr); | |||
| auto bw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[3]]); | |||
| MS_ASSERT(bw_gate_bias); | |||
| MS_ASSERT(bw_gate_bias != nullptr); | |||
| auto bw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[4]]); | |||
| MS_ASSERT(bw_cand_kernel); | |||
| MS_ASSERT(bw_cand_kernel != nullptr); | |||
| auto bw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[5]]); | |||
| MS_ASSERT(bw_cand_bias); | |||
| MS_ASSERT(bw_cand_bias != nullptr); | |||
| auto hidden = utils::cast<AnfNodePtr>((*equiv)[common_vars_[1]]); | |||
| MS_ASSERT(hidden); | |||
| auto stacked_hidden = GetStackedHiddenState(func_graph, hidden, base_name); | |||
| auto fw_init_state = utils::cast<AnfNodePtr>((*equiv)[fw_init_state_]); | |||
| MS_ASSERT(fw_init_state != nullptr); | |||
| auto bw_init_state = utils::cast<AnfNodePtr>((*equiv)[bw_init_state_]); | |||
| MS_ASSERT(bw_init_state != nullptr); | |||
| auto stacked_hidden = GetStackedHiddenState(func_graph, fw_init_state, bw_init_state, base_name); | |||
| if (stacked_hidden == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto input_length = utils::cast<AnfNodePtr>((*equiv)[input_length_]); | |||
| MS_ASSERT(hidden); | |||
| MS_ASSERT(hidden != nullptr); | |||
| int input_size = 0; | |||
| int hidden_size = 0; | |||
| @@ -536,8 +535,8 @@ CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr | |||
| CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output, | |||
| const std::string base_name) const { | |||
| MS_ASSERT(func_graph); | |||
| MS_ASSERT(gru_output); | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(gru_output != nullptr); | |||
| auto split_primitive = std::make_unique<schema::PrimitiveT>(); | |||
| std::unique_ptr<schema::SplitT> split_attr = std::make_unique<schema::SplitT>(); | |||
| split_attr->numberSplit = 2; | |||
| @@ -603,8 +602,8 @@ CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func | |||
| const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node, | |||
| const EquivPtr &equiv) const { | |||
| MS_ASSERT(func_graph); | |||
| MS_ASSERT(concat_node); | |||
| MS_ASSERT(func_graph != nullptr); | |||
| MS_ASSERT(concat_node != nullptr); | |||
| MS_LOG(DEBUG) << "bidirection tf gru fusion pass"; | |||
| if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(concat_node) != lite::RET_OK) { | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||
| @@ -612,7 +611,7 @@ const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_gr | |||
| } | |||
| auto transpose_input = utils::cast<AnfNodePtr>((*equiv)[transpose_input_]); | |||
| MS_ASSERT(transpose_input); | |||
| MS_ASSERT(transpose_input != nullptr); | |||
| if (!utils::isa<CNodePtr>(transpose_input) || GetCNodeType(transpose_input) != schema::PrimitiveType_Transpose) { | |||
| return nullptr; | |||
| } | |||
| @@ -54,18 +54,19 @@ class BiDirectionTfGruCellFusion : public PatternProcessPass { | |||
| float *tensor_data) const; | |||
| void CopyFlattenMatData(const float *mat, const int R, const int C, const int r0, const int r1, const int c0, | |||
| const int c1, float *data, bool t = false) const; | |||
| CNodePtr GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &hidden_state, | |||
| const std::string base_name) const; | |||
| CNodePtr GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &fw_init_state, | |||
| const AnfNodePtr &bw_init_state, const std::string base_name) const; | |||
| CNodePtr GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output, | |||
| const std::string base_name) const; | |||
| private: | |||
| std::vector<VarPtr> common_vars_; | |||
| std::vector<VarPtr> fw_vars_; | |||
| std::vector<VarPtr> bw_vars_; | |||
| VarPtr input_; | |||
| VarPtr input_length_; | |||
| VarPtr transpose_input_; | |||
| VarPtr fw_init_state_; | |||
| VarPtr bw_init_state_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -53,7 +53,44 @@ bool UpdateConv2DParamPass::Run(const FuncGraphPtr &func_graph) { | |||
| primT->value.AsDepthwiseConv2D()->channelIn = weight->tensor_shape().at(0); | |||
| } | |||
| } | |||
| } else if (type == schema::PrimitiveType_Conv2D) { | |||
// Conv2D branch: reads the shape of the node's weight parameter and, for TF models whose
// primitive format is NHWC, copies dims 0..3 into kernelH / kernelW / channelIn / channelOut.
// Every failed guard below logs an error and moves on to the next node via `continue`.
| auto conv2d_cnode = node->cast<CNodePtr>(); | |||
| auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv2d_cnode->input(0)); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "Conv2D node has no primitiveC."; | |||
| continue; | |||
| } | |||
| auto primT = primitive_c->primitiveT(); | |||
| if (primT == nullptr) { | |||
| MS_LOG(ERROR) << "Conv2D node has no primitiveT."; | |||
| continue; | |||
| } | |||
// NOTE(review): the AsConv2D() result is dereferenced further down without a null check.
| auto conv2d_primt = primT->value.AsConv2D(); | |||
| auto weight_node = conv2d_cnode->input(lite::kAnfPopulaterInputNumTwo); | |||
| if (weight_node == nullptr) { | |||
| MS_LOG(ERROR) << "Conv2D weight node is nullptr."; | |||
| continue; | |||
| } | |||
| if (!weight_node->isa<Parameter>()) { | |||
| MS_LOG(ERROR) << "Conv2D weight node is not parameter."; | |||
| continue; | |||
| } | |||
| auto weight_param = weight_node->cast<ParameterPtr>(); | |||
// NOTE(review): this guard tests has_default(), but its log text repeats the
// "is not parameter" message of the previous guard.
| if (!weight_param->has_default()) { | |||
| MS_LOG(ERROR) << "Conv2D weight node is not parameter."; | |||
| continue; | |||
| } | |||
| auto default_param = weight_param->default_param(); | |||
// NOTE(review): the dynamic_pointer_cast result is used on the next line without a null
// check, and weight_shape is indexed at 0..3 below without checking that it has four entries.
| auto weight_tensor = std::dynamic_pointer_cast<ParamValueLite>(default_param); | |||
| auto weight_shape = weight_tensor->tensor_shape(); | |||
// Only TF + NHWC primitives are updated; all other combinations are left untouched.
// NOTE(review): the index-to-field mapping implies an (H, W, Cin, Cout) weight layout -- confirm.
| if (fmk_type == lite::converter::FmkType_TF && conv2d_primt->format == schema::Format_NHWC) { | |||
| conv2d_primt->kernelH = weight_shape[0]; | |||
| conv2d_primt->kernelW = weight_shape[1]; | |||
| conv2d_primt->channelIn = weight_shape[2]; | |||
| conv2d_primt->channelOut = weight_shape[3]; | |||
| } | |||
| } | |||
| if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "remove identity pass is failed."; | |||
| return false; | |||
| @@ -19,13 +19,19 @@ | |||
| #include "schema/inner/model_generated.h" | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| using mindspore::lite::converter::FmkType; | |||
| namespace mindspore::opt { | |||
// Pass registered under the name "update_conv2d_param_pass". Run() is only declared here;
// its definition lives in the corresponding source file.
| class UpdateConv2DParamPass : public Pass { | |||
| public: | |||
| UpdateConv2DParamPass() : Pass("update_conv2d_param_pass") {} | |||
| ~UpdateConv2DParamPass() override = default; | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
// Stores the framework type in fmk_type. The parameter shadows the member, hence `this->`.
| void SetFmkType(FmkType fmk_type) { this->fmk_type = fmk_type; } | |||
| private: | |||
// Framework type used by the pass; stays FmkType_ONNX unless SetFmkType() is called.
| FmkType fmk_type = lite::converter::FmkType_ONNX; | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_ | |||