From: @wangzhe128 Reviewed-by: @hangangqiang,@ddwsky Signed-off-by: @hangangqiangtags/v1.1.0
| @@ -99,9 +99,8 @@ int TensorListReserve::InferShape(std::vector<lite::Tensor *> inputs_, std::vect | |||
| auto input0 = inputs_.front(); | |||
| MS_ASSERT(input0 != nullptr); | |||
| auto ele_shape_type = input0->data_type(); | |||
| if (ele_shape_type != kNumberTypeInt) { | |||
| MS_LOG(ERROR) << "ele_shape_tensor.data_type():" << ele_shape_type | |||
| << " must be \"kNumberTypeInt\":" << kNumberTypeInt; | |||
| if (ele_shape_type != kNumberTypeInt && ele_shape_type != kNumberTypeInt32) { | |||
| MS_LOG(ERROR) << "ele_shape_tensor.data_type():" << ele_shape_type << " is not int"; | |||
| return RET_ERROR; | |||
| } | |||
| if (input0->data_c() == nullptr) { | |||
| @@ -113,8 +112,8 @@ int TensorListReserve::InferShape(std::vector<lite::Tensor *> inputs_, std::vect | |||
| auto input1 = inputs_[1]; | |||
| MS_ASSERT(input1 != nullptr); | |||
| auto num_ele_type = input1->data_type(); | |||
| if (num_ele_type != kNumberTypeInt) { | |||
| MS_LOG(ERROR) << "num_ele_tensor.data_type():" << num_ele_type << " must be \"kNumberTypeInt\":" << kNumberTypeInt; | |||
| if (num_ele_type != kNumberTypeInt && ele_shape_type != kNumberTypeInt32) { | |||
| MS_LOG(ERROR) << "num_ele_tensor.data_type():" << num_ele_type << " is not int"; | |||
| return RET_ERROR; | |||
| } | |||
| if (input1->ElementsNum() != 1) { | |||
| @@ -97,9 +97,8 @@ int TensorListSetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect | |||
| MS_ASSERT(input0 != nullptr); | |||
| auto get_index = inputs_[1]; | |||
| MS_ASSERT(get_index != nullptr); | |||
| if (get_index->data_type() != kNumberTypeInt) { | |||
| MS_LOG(ERROR) << "inputs_[1]->data_type():" << get_index->data_type() | |||
| << " must be equal to \"kNumberTypeInt\":" << kNumberTypeInt; | |||
| if (get_index->data_type() != kNumberTypeInt && get_index->data_type() != kNumberTypeInt32) { | |||
| MS_LOG(ERROR) << "inputs_[1]->data_type():" << get_index->data_type() << " is not int"; | |||
| return RET_ERROR; | |||
| } | |||
| if (get_index->ElementsNum() != 1) { | |||
| @@ -228,9 +228,8 @@ bool TensorList::IsCompatibleShape(const Tensor *src) { | |||
| if (static_cast<size_t>(src->ElementsNum()) != this->element_shape_.size()) { | |||
| return false; | |||
| } | |||
| if (src->data_type() != kNumberTypeInt) { | |||
| MS_LOG(ERROR) << "src tensor data_type:" << src->data_type() | |||
| << " must be equal to \"kNumberTypeInt\":" << kNumberTypeInt; | |||
| if (src->data_type() != kNumberTypeInt && src->data_type() != kNumberTypeInt32) { | |||
| MS_LOG(ERROR) << "src tensor data_type:" << src->data_type() << " is not int"; | |||
| return false; | |||
| } | |||
| auto src_ptr = reinterpret_cast<int *>(src->data_c()); | |||
| @@ -76,6 +76,80 @@ std::string GetOriginInputName(const tensorflow::NodeDef &node, | |||
| } | |||
| } // namespace | |||
| STATUS TFModelParser::ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, | |||
| const ParamValueLitePtr ¶m_value) { | |||
| MS_ASSERT(param_value != nullptr); | |||
| auto variant_size = tensor_proto.variant_val_size(); | |||
| if (variant_size != 1) { | |||
| MS_LOG(ERROR) << "only support variant_val_size == 1 now"; | |||
| return RET_ERROR; | |||
| } | |||
| auto &variant = tensor_proto.variant_val(0); | |||
| if (variant.type_name() != "tensorflow::TensorList") { | |||
| MS_LOG(ERROR) << "Only TensorList type is supported now"; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| auto descriptor = variant.GetMetadata().descriptor; | |||
| auto reflection = variant.GetMetadata().reflection; | |||
| if (descriptor == nullptr || reflection == nullptr) { | |||
| MS_LOG(ERROR) << "descriptor or reflection is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto field_descriptor = descriptor->field(1); | |||
| if (field_descriptor == nullptr) { | |||
| MS_LOG(ERROR) << "field_descriptor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto type = field_descriptor->type(); | |||
| if (type != google::protobuf::FieldDescriptor::TYPE_BYTES) { | |||
| MS_LOG(ERROR) << "metadata type is not TYPE_BYTES"; | |||
| return RET_ERROR; | |||
| } | |||
| auto str = reflection->GetString(variant, field_descriptor); | |||
| std::string_view str_view(str); | |||
| uint64_t scratch; | |||
| if (!TensorFlowUtils::DecodeInt64(&str_view, &scratch)) { | |||
| return RET_ERROR; | |||
| } | |||
| size_t num_invalid_tensors = static_cast<size_t>(scratch); | |||
| for (size_t i = 0; i < num_invalid_tensors; ++i) { | |||
| if (!TensorFlowUtils::DecodeInt64(&str_view, &scratch)) { | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| if (!TensorFlowUtils::DecodeInt64(&str_view, &scratch)) { | |||
| return RET_ERROR; | |||
| } | |||
| size_t element_dtype = static_cast<size_t>(scratch); | |||
| if (!TensorFlowUtils::DecodeInt64(&str_view, &scratch)) { | |||
| return RET_ERROR; | |||
| } | |||
| std::string element_shape_str = std::string(str_view.data(), str_view.size()); | |||
| tensorflow::TensorShapeProto element_shape_proto; | |||
| element_shape_proto.ParseFromString(element_shape_str); | |||
| auto dim_size = element_shape_proto.dim_size(); | |||
| // we encode element_dtype,shape.size,shape[i]... into data | |||
| auto tensor_data = new (std::nothrow) int[dim_size + 2]; | |||
| if (tensor_data == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_data is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| tensor_data[0] = TensorFlowUtils::GetTFDataType(tensorflow::DataType(element_dtype)); | |||
| tensor_data[1] = element_shape_proto.dim_size(); | |||
| for (int i = 0; i < dim_size; ++i) { | |||
| auto dim = element_shape_proto.dim(i).size(); | |||
| if (dim > static_cast<int64_t>(INT32_MAX) || dim < static_cast<int64_t>(INT32_MIN)) { | |||
| MS_LOG(ERROR) << "int64 data " << dim << " too big to fit into int32"; | |||
| delete[] tensor_data; | |||
| return RET_ERROR; | |||
| } else { | |||
| tensor_data[i + 2] = static_cast<int>(dim); | |||
| } | |||
| } | |||
| param_value->SetTensorData(tensor_data, (dim_size + 2) * sizeof(int)); | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, | |||
| const ParameterPtr ¶meter, std::vector<int64_t> *shape_vector) { | |||
| MS_ASSERT(parameter != nullptr); | |||
| @@ -143,6 +217,11 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value | |||
| } | |||
| tensor_size = shape_size * sizeof(int); | |||
| param_value->SetTensorData(tensor_data, tensor_size); | |||
| } else if (type == kObjectTypeTensorType) { | |||
| auto status = ConvertConstVariant(tensor_proto, param_value); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport dataType: " << type; | |||
| return RET_ERROR; | |||
| @@ -28,6 +28,7 @@ | |||
| #include "securec/include/securec.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/converter/model_parser.h" | |||
| #include "mindspore/lite/src/param_value_lite.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -43,6 +44,7 @@ class TFModelParser : public ModelParser { | |||
| const QuantType &quantType = QuantType_QUANT_NONE) override; | |||
| private: | |||
| STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, const ParamValueLitePtr ¶m_value); | |||
| STATUS ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, const ParameterPtr ¶meter, | |||
| std::vector<int64_t> *shape_vector); | |||
| STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter, | |||
| @@ -16,6 +16,7 @@ | |||
| #include "tools/converter/parser/tf/tf_util.h" | |||
| #include <string> | |||
| #include <string_view> | |||
| #include <unordered_map> | |||
| #include "src/common/log_adapter.h" | |||
| #include "schema/inner/model_generated.h" | |||
| @@ -27,7 +28,7 @@ static const std::unordered_map<int, mindspore::TypeId> TF_TYPE_MAP = { | |||
| {tensorflow::DT_UINT8, mindspore::kNumberTypeUInt8}, | |||
| {tensorflow::DT_INT16, mindspore::kNumberTypeInt16}, | |||
| {tensorflow::DT_UINT16, mindspore::kNumberTypeUInt16}, | |||
| {tensorflow::DT_INT32, mindspore::kNumberTypeInt32}, | |||
| {tensorflow::DT_INT32, mindspore::kNumberTypeInt}, | |||
| {tensorflow::DT_INT64, mindspore::kNumberTypeInt64}, | |||
| {tensorflow::DT_HALF, mindspore::kNumberTypeFloat16}, | |||
| {tensorflow::DT_FLOAT, mindspore::kNumberTypeFloat32}, | |||
| @@ -65,6 +66,7 @@ TypeId TensorFlowUtils::ParseAttrDataType(const tensorflow::NodeDef &node_def, c | |||
| } | |||
| return GetTFDataType(attr_value.type()); | |||
| } | |||
| schema::Format TensorFlowUtils::ParseNodeFormat(const tensorflow::NodeDef &node_def) { | |||
| tensorflow::AttrValue attr_value; | |||
| if (!FindAttrValue(node_def, "data_format", &attr_value)) { | |||
| @@ -78,5 +80,37 @@ schema::Format TensorFlowUtils::ParseNodeFormat(const tensorflow::NodeDef &node_ | |||
| } | |||
| return schema::Format_NUM_OF_FORMAT; | |||
| } | |||
| bool TensorFlowUtils::DecodeInt64(std::string_view *str_view, uint64_t *value) { | |||
| if (str_view == nullptr || value == nullptr) { | |||
| *value = 0; | |||
| MS_LOG(ERROR) << "str_view or value is nullptr"; | |||
| return false; | |||
| } | |||
| auto data = str_view->data(); | |||
| const auto end = data + str_view->size(); | |||
| const char *next = nullptr; | |||
| uint64_t result = 0; | |||
| for (uint32_t shift = 0; shift <= 63 && data < end; shift += 7) { | |||
| uint64_t byte = *(reinterpret_cast<const unsigned char *>(data)); | |||
| data++; | |||
| if (byte & 128) { | |||
| result |= ((byte & 127) << shift); | |||
| } else { | |||
| result |= (byte << shift); | |||
| *value = result; | |||
| next = reinterpret_cast<const char *>(data); | |||
| break; | |||
| } | |||
| } | |||
| if (next == nullptr) { | |||
| return false; | |||
| } else { | |||
| *str_view = std::string_view(next, end - next); | |||
| return true; | |||
| } | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -18,6 +18,7 @@ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_UTIL_H | |||
| #include <string> | |||
| #include <string_view> | |||
| #include "proto/node_def.pb.h" | |||
| #include "ir/dtype/type_id.h" | |||
| #include "include/errorcode.h" | |||
| @@ -32,6 +33,7 @@ class TensorFlowUtils { | |||
| tensorflow::AttrValue *attr_value); | |||
| static TypeId ParseAttrDataType(const tensorflow::NodeDef &node_def, const std::string &attr_name); | |||
| static schema::Format ParseNodeFormat(const tensorflow::NodeDef &node_def); | |||
| static bool DecodeInt64(std::string_view *str_view, uint64_t *value); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -119,11 +119,6 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l | |||
| MS_LOG(ERROR) << "input is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto tensor = std::make_unique<lite::Tensor>(); | |||
| if (tensor == nullptr) { | |||
| MS_LOG(ERROR) << "new input tensor failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (utils::isa<ValueNodePtr>(cnode->input(i))) { | |||
| MS_LOG(ERROR) << "input is value node"; | |||
| @@ -149,23 +144,47 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l | |||
| MS_LOG(ERROR) << "ParamValueLite of abstract is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| tensor->set_shape(param_value_lite->tensor_shape()); | |||
| tensor->set_data_type(param_value_lite->tensor_type()); | |||
| tensor->set_format(schema::Format(param_value_lite->format())); | |||
| std::unique_ptr<lite::Tensor> tensor = nullptr; | |||
| if (param_value_lite->tensor_type() != kObjectTypeTensorType) { | |||
| tensor = std::make_unique<lite::Tensor>(); | |||
| } else { | |||
| tensor = std::make_unique<lite::TensorList>(); | |||
| } | |||
| if (tensor == nullptr) { | |||
| MS_LOG(ERROR) << "new input tensor failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (param_value_lite->tensor_type() != kObjectTypeTensorType) { | |||
| tensor->set_shape(param_value_lite->tensor_shape()); | |||
| tensor->set_data_type(param_value_lite->tensor_type()); | |||
| tensor->set_format(schema::Format(param_value_lite->format())); | |||
| } | |||
| if (utils::isa<ParameterPtr>(input)) { | |||
| auto parameter = input->cast<ParameterPtr>(); | |||
| if (parameter->has_default()) { | |||
| auto param_value = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param()); | |||
| auto ret = tensor->MallocData(); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "Malloc tensor data failed"; | |||
| return RET_ERROR; | |||
| } | |||
| ret = memcpy_s(tensor->MutableData(), tensor->Size(), param_value->tensor_addr(), param_value->tensor_size()); | |||
| if (tensor->Size() != 0 && ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy error: " << ret; | |||
| return RET_ERROR; | |||
| if (param_value_lite->tensor_type() != kObjectTypeTensorType) { | |||
| auto ret = tensor->MallocData(); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "Malloc tensor data failed"; | |||
| return RET_ERROR; | |||
| } | |||
| ret = memcpy_s(tensor->MutableData(), tensor->Size(), param_value->tensor_addr(), param_value->tensor_size()); | |||
| if (tensor->Size() != 0 && ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy error: " << ret; | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| int *data = reinterpret_cast<int *>(param_value->tensor_addr()); | |||
| auto tensor_list = dynamic_cast<lite::TensorList *>(tensor.get()); | |||
| tensor_list->set_tensors_data_type(TypeId(data[0])); | |||
| std::vector<int> shape; | |||
| for (int j = 0; j < data[1]; ++j) { | |||
| shape.push_back(data[2 + j]); | |||
| } | |||
| tensor_list->set_element_shape(shape); | |||
| } | |||
| } | |||
| } | |||
| @@ -181,13 +200,35 @@ STATUS InferShapePass::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector< | |||
| MS_LOG(ERROR) << "abstract is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| size_t num_outputs = 1; | |||
| std::vector<TypeId> types; | |||
| if (utils::isa<abstract::AbstractTuple>(abstract)) { | |||
| auto abstract_tuple = abstract->cast<abstract::AbstractTuplePtr>(); | |||
| num_outputs = abstract_tuple->size(); | |||
| auto elements = abstract_tuple->elements(); | |||
| for (auto &element : elements) { | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(element)) { | |||
| MS_LOG(ERROR) << "abstract is not AbstractTensor"; | |||
| return RET_ERROR; | |||
| } | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(element); | |||
| auto typePtr = abstract_tensor->element()->GetTypeTrack(); | |||
| types.push_back(typePtr->type_id()); | |||
| } | |||
| } else { | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) { | |||
| MS_LOG(ERROR) << "abstract is not AbstractTensor"; | |||
| return RET_ERROR; | |||
| } | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract); | |||
| auto typePtr = abstract_tensor->element()->GetTypeTrack(); | |||
| types.push_back(typePtr->type_id()); | |||
| } | |||
| for (size_t i = 0; i < num_outputs; ++i) { | |||
| auto output_tensor = std::make_unique<lite::Tensor>(); | |||
| for (auto &type : types) { | |||
| std::unique_ptr<lite::Tensor> output_tensor = nullptr; | |||
| if (type == kObjectTypeTensorType) { | |||
| output_tensor = std::make_unique<lite::TensorList>(); | |||
| } else { | |||
| output_tensor = std::make_unique<lite::Tensor>(); | |||
| } | |||
| if (output_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "new output tensor failed"; | |||
| return RET_ERROR; | |||
| @@ -22,6 +22,7 @@ | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "mindspore/lite/src/tensor.h" | |||
| #include "mindspore/lite/src/tensorlist.h" | |||
| #include "mindspore/lite/include/errorcode.h" | |||
| using mindspore::lite::STATUS; | |||
| using mindspore::lite::converter::FmkType; | |||