diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 032f96b362..bc2a898d54 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -306,6 +306,7 @@ if (ENABLE_CONVERTER) tflite_parser_mid caffe_parser_mid onnx_parser_mid + tf_parser_mid graph_pass_mid fusion_mid quantizer_mid diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 292b0af7b2..8a7e365bba 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -61,6 +61,7 @@ add_subdirectory(../anf_exporter anf_exporter) add_subdirectory(parser/caffe) add_subdirectory(parser/tflite) add_subdirectory(parser/onnx) +add_subdirectory(parser/tf) add_subdirectory(legacy_optimizer) add_subdirectory(quantizer) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../core mindspore_core) @@ -111,6 +112,7 @@ endif () file(GLOB PROTO_FILE "" ${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto + ${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/*.proto ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto) ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE}) add_library(proto_mid OBJECT ${PROTO_SRCS}) @@ -138,6 +140,7 @@ add_dependencies(converter_lite fbs_inner_src) target_link_libraries(converter_lite PRIVATE tflite_parser_mid + tf_parser_mid caffe_parser_mid onnx_parser_mid anf_importer_mid diff --git a/mindspore/lite/tools/converter/parser/tf/CMakeLists.txt b/mindspore/lite/tools/converter/parser/tf/CMakeLists.txt new file mode 100644 index 0000000000..d1f991dc37 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/CMakeLists.txt @@ -0,0 +1,7 @@ +file(GLOB_RECURSE TF_SRC_LIST ${CMAKE_CURRENT_SOURCE_DIR}/*.cc) + +set_property(SOURCE ${TF_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) + +add_library(tf_parser_mid OBJECT ${TF_SRC_LIST}) + +add_dependencies(tf_parser_mid proto_mid) diff --git a/mindspore/lite/tools/converter/parser/tf/attr_value.proto b/mindspore/lite/tools/converter/parser/tf/attr_value.proto new file mode 100644 index 0000000000..249318ce60 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/attr_value.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "AttrValueProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensor.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing the value for an attr used to configure an Op. +// Comment indicates the corresponding attr type. Only the field matching the +// attr type may be filled. +message AttrValue { + // LINT.IfChange + message ListValue { + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated DataType type = 6 [packed = true]; // "list(type)" + repeated TensorShapeProto shape = 7; // "list(shape)" + repeated TensorProto tensor = 8; // "list(tensor)" + repeated NameAttrList func = 9; // "list(attr)" + } + // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) + + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + DataType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" + + // "func" represents a function. func.name is a function's name or + // a primitive op's name. func.attr.first is the name of an attr + // defined for that function. func.attr.second is the value for + // that attr in the instantiation. + NameAttrList func = 10; + + // This is a placeholder only used in anf_node_map defined inside a + // function. It indicates the attr value will be supplied when + // the function is instantiated. For example, let us suppose a + // node "N" in function "FN". "N" has an attr "A" with value + // placeholder = "foo". When FN is instantiated with attr "foo" + // set to "bar", the instantiated node N's attr A will have been + // given the value "bar". + string placeholder = 9; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NameAttrList { + string name = 1; + map attr = 2; +} diff --git a/mindspore/lite/tools/converter/parser/tf/function.proto b/mindspore/lite/tools/converter/parser/tf/function.proto new file mode 100644 index 0000000000..ccb12cb4f2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/function.proto @@ -0,0 +1,101 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "FunctionProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "node_def.proto"; +import "op_def.proto"; + +// A library is a set of named functions. +message FunctionDefLibrary { + repeated FunctionDef function = 1; + repeated GradientDef gradient = 2; +} + +// A function can be instantiated when the runtime can bind every attr +// with a value. When a GraphDef has a call to a function, it must +// have binding for every attr defined in the signature. +// +// TODO(zhifengc): +// * device spec, etc. +message FunctionDef { + // The definition of the function's name, arguments, return values, + // attrs etc. + OpDef signature = 1; + + // Attributes specific to this function definition. + map attr = 5; + + // NOTE: field id 2 deleted on Jan 11, 2016, GraphDef version 21. + + // In both of the following fields, there is the need to specify an + // output that is used as either the input to another node (in + // `node_def`) or as a return value of the function (in `ret`). + // Unlike the NodeDefs in GraphDef, we need to be able to specify a + // list in some cases (instead of just single outputs). Also, we + // need to be able to deal with lists of unknown length (so the + // output index may not be known at function definition time). So + // we use the following format instead: + // * "fun_in" where "fun_in" is the name of a function input arg in + // the `signature` field above. This represents that input, whether + // it is a single tensor or a list. + // * "fun_in:0" gives the first element of a function input arg (a + // non-list input is considered a list of length 1 for these + // purposes). + // * "node:out" where "node" is the name of a node in `node_def` and + // "out" is the name one of its op's output arguments (the name + // comes from the OpDef of the node's op). This represents that + // node's output, whether it is a single tensor or a list. + // Note: We enforce that an op's output arguments are never + // renamed in the backwards-compatibility test. + // * "node:out:0" gives the first element of a node output arg (a + // non-list output is considered a list of length 1 for these + // purposes). + // + // NOT CURRENTLY SUPPORTED (but may be in the future): + // * "node:out:-1" gives last element in a node output list + // * "node:out:1:" gives a list with all but the first element in a + // node output list + // * "node:out::-1" gives a list with all but the last element in a + // node output list + + // The body of the function. Unlike the NodeDefs in a GraphDef, attrs + // may have values of type `placeholder` and the `input` field uses + // the "output" format above. + + // By convention, "op" in node_def is resolved by consulting with a + // user-defined library first. If not resolved, "func" is assumed to + // be a builtin op. + repeated NodeDef node_def = 3; + + // A mapping from the output arg names from `signature` to the + // outputs from `node_def` that should be returned by the function. + map ret = 4; +} + +// GradientDef defines the gradient function of a function defined in +// a function library. +// +// A gradient function g (specified by gradient_func) for a function f +// (specified by function_name) must follow the following: +// +// The function 'f' must be a numerical function which takes N inputs +// and produces M outputs. Its gradient function 'g', which is a +// function taking N + M inputs and produces N outputs. +// +// I.e. if we have +// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), +// then, g is +// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, +// dL/dy1, dL/dy2, ..., dL/dy_M), +// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the +// loss function). dL/dx_i is the partial derivative of L with respect +// to x_i. +message GradientDef { + string function_name = 1; // The function name. + string gradient_func = 2; // The gradient function's name. +} diff --git a/mindspore/lite/tools/converter/parser/tf/graph.proto b/mindspore/lite/tools/converter/parser/tf/graph.proto new file mode 100644 index 0000000000..65119189ff --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/graph.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "GraphProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "node_def.proto"; +import "function.proto"; +import "versions.proto"; + +// Represents the graph of operations +message GraphDef { + repeated NodeDef node = 1; + + // Compatibility versions of the graph. See core/public/version.h for version + // history. The GraphDef version is distinct from the TensorFlow version, and + // each release of TensorFlow will support a range of GraphDef versions. + VersionDef versions = 4; + + // Deprecated single version field; use versions above instead. Since all + // GraphDef changes before "versions" was introduced were forward + // compatible, this field is entirely ignored. + int32 version = 3 [deprecated = true]; + + // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. + // + // "library" provides user-defined functions. + // + // Naming: + // * library.function.name are in a flat namespace. + // NOTE: We may need to change it to be hierarchical to support + // different orgs. E.g., + // { "/google/nn", { ... }}, + // { "/google/vision", { ... }} + // { "/org_foo/module_bar", { ... }} + // map named_lib; + // * If node[i].op is the name of one function in "library", + // node[i] is deemed as a function call. Otherwise, node[i].op + // must be a primitive operation supported by the runtime. + // + // + // Function call semantics: + // + // * The callee may start execution as soon as some of its inputs + // are ready. The caller may want to use Tuple() mechanism to + // ensure all inputs are ready in the same time. + // + // * The consumer of return values may start executing as soon as + // the return values the consumer depends on are ready. The + // consumer may want to use Tuple() mechanism to ensure the + // consumer does not start until all return values of the callee + // function are ready. + FunctionDefLibrary library = 2; +}; diff --git a/mindspore/lite/tools/converter/parser/tf/node_def.proto b/mindspore/lite/tools/converter/parser/tf/node_def.proto new file mode 100644 index 0000000000..e6d545ad5c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/node_def.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "NodeProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; + +message NodeDef { + // The name given to this operator. Used for naming inputs, + // logging, visualization, etc. Unique within a single GraphDef. + // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". + string name = 1; + + // The operation name. There may be custom parameters in attrs. + // Op names starting with an underscore are reserved for internal use. + string op = 2; + + // Each input is "node:src_output" with "node" being a string name and + // "src_output" indicating which output tensor to use from "node". If + // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs + // may optionally be followed by control inputs that have the format + // "^node". + repeated string input = 3; + + // A (possibly partial) specification for the device on which this + // node should be placed. + // The expected syntax for this string is as follows: + // + // DEVICE_SPEC ::= PARTIAL_SPEC + // + // PARTIAL_SPEC ::= ("/" CONSTRAINT) * + // CONSTRAINT ::= ("job:" JOB_NAME) + // | ("replica:" [1-9][0-9]*) + // | ("task:" [1-9][0-9]*) + // | ( ("gpu" | "cpu") ":" ([1-9][0-9]* | "*") ) + // + // Valid values for this string include: + // * "/job:worker/replica:0/task:1/gpu:3" (full specification) + // * "/job:worker/gpu:3" (partial specification) + // * "" (no specification) + // + // If the constraints do not resolve to a single device (or if this + // field is empty or not present), the runtime will attempt to + // choose a device automatically. + string device = 4; + + // Operation-specific graph-construction-time configuration. + // Note that this should include all attrs defined in the + // corresponding OpDef, including those with a value matching + // the default -- this allows the default to change and makes + // NodeDefs easier to interpret on their own. However, if + // an attr with a default is not specified in this list, the + // default will be used. + // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and + // one of the names from the corresponding OpDef's attr field). + // The values must have a type matching the corresponding OpDef + // attr's type field. + // TODO(josh11b): Add some examples here showing best practices. + map attr = 5; +}; diff --git a/mindspore/lite/tools/converter/parser/tf/op_def.proto b/mindspore/lite/tools/converter/parser/tf/op_def.proto new file mode 100644 index 0000000000..baf68eaad3 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/op_def.proto @@ -0,0 +1,157 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "OpDefProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "types.proto"; + +// Defines an operation. A NodeDef in a GraphDef specifies an Op by +// using the "op" field which should match the name of a OpDef. +message OpDef { + // Op names starting with an underscore are reserved for internal use. + // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". + string name = 1; + + // For describing inputs and outputs. + message ArgDef { + // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". + string name = 1; + + // Human readable description. + string description = 2; + + // Describes the type of one or more tensors that are accepted/produced + // by this input/output arg. The only legal combinations are: + // * For a single tensor: either the "type" field is set or the + // "type_attr" field is set to the name of an attr with type "type". + // * For a sequence of tensors with the same type: the "number_attr" + // field will be set to the name of an attr with type "int", and + // either the "type" or "type_attr" field will be set as for + // single tensors. + // * For a sequence of tensors, the "type_list_attr" field will be set + // to the name of an attr with type "list(type)". + DataType type = 3; + string type_attr = 4; // if specified, attr must have type "type" + string number_attr = 5; // if specified, attr must have type "int" + // If specified, attr must have type "list(type)", and none of + // type, type_attr, and number_attr may be specified. + string type_list_attr = 6; + + // For inputs: if true, the inputs are required to be refs. + // By default, inputs can be either refs or non-refs. + // For outputs: if true, outputs are refs, otherwise they are not. + bool is_ref = 16; + }; + + // Description of the input(s). + repeated ArgDef input_arg = 2; + + // Description of the output(s). + repeated ArgDef output_arg = 3; + + // Description of the graph-construction-time configuration of this + // Op. That is to say, this describes the attr fields that will + // be specified in the NodeDef. + message AttrDef { + // A descriptive name for the argument. May be used, e.g. by the + // Python client, as a keyword argument name, and so should match + // the regexp "[a-z][a-z0-9_]+". + string name = 1; + + // One of the type names from attr_value.proto ("string", "list(string)", + // "int", etc.). + string type = 2; + + // A reasonable default for this attribute if the user does not supply + // a value. If not specified, the user must supply a value. + AttrValue default_value = 3; + + // Human-readable description. + string description = 4; + + // TODO(josh11b): bool is_optional? + + // --- Constraints --- + // These constraints are only in effect if specified. Default is no + // constraints. + + // For type == "int", this is a minimum value. For "list(___)" + // types, this is the minimum length. + bool has_minimum = 5; + int64 minimum = 6; + + // The set of allowed values. Has type that is the "list" version + // of the "type" field above (uses the "list" field of AttrValue). + // If type == "type" or "list(type)" above, then the "type" field + // of "allowed_values.list" has the set of allowed DataTypes. + // If type == "string" or "list(string)", then the "s" field of + // "allowed_values.list" has the set of allowed strings. + AttrValue allowed_values = 7; + } + repeated AttrDef attr = 4; + + // Optional deprecation based on GraphDef versions. + OpDeprecation deprecation = 8; + + // One-line human-readable description of what the Op does. + string summary = 5; + + // Additional, longer human-readable description of what the Op does. + string description = 6; + + // ------------------------------------------------------------------------- + // Which optimizations this operation can participate in. + + // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) + bool is_commutative = 18; + + // If is_aggregate is true, then this operation accepts N >= 2 + // inputs and produces 1 output all of the same type. Should be + // associative and commutative, and produce output with the same + // shape as the input. The optimizer may replace an aggregate op + // taking input from multiple devices with a tree of aggregate ops + // that aggregate locally within each device (and possibly within + // groups of nearby devices) before communicating. + // TODO(josh11b): Implement that optimization. + bool is_aggregate = 16; // for things like add + + // Other optimizations go here, like + // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. + + // ------------------------------------------------------------------------- + // Optimization constraints. + + // By default Ops may be moved between devices. Stateful ops should + // either not be moved, or should only be moved if that state can also + // be moved (e.g. via some sort of save / restore). + // Stateful ops are guaranteed to never be optimized away by Common + // Subexpression Elimination (CSE). + bool is_stateful = 17; // for things like variables, queue + + // ------------------------------------------------------------------------- + // Non-standard options. + + // By default, all inputs to an Op must be initialized Tensors. Ops + // that may initialize tensors for the first time should set this + // field to true, to allow the Op to take an uninitialized Tensor as + // input. + bool allows_uninitialized_input = 19; // for Assign, etc. +}; + +// Information about version-dependent deprecation of an op +message OpDeprecation { + // First GraphDef version at which the op is disallowed. + int32 version = 1; + + // Explanation of why it was deprecated and what to use instead. + string explanation = 2; +}; + +// A collection of OpDefs +message OpList { + repeated OpDef op = 1; +}; diff --git a/mindspore/lite/tools/converter/parser/tf/resource_handle.proto b/mindspore/lite/tools/converter/parser/tf/resource_handle.proto new file mode 100644 index 0000000000..b1921337f5 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/resource_handle.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ResourceHandle"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Protocol buffer representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. +message ResourceHandleProto { + // Unique name for the device containing the resource. + string device = 1; + + // Container in which this resource is placed. + string container = 2; + + // Unique name of this resource. + string name = 3; + + // Hash code for the type of the resource. Is only valid in the same device + // and in the same execution. + uint64 hash_code = 4; + + // For debug-only, the name of the type pointed to by this handle, if + // available. + string maybe_type_name = 5; +}; diff --git a/mindspore/lite/tools/converter/parser/tf/tensor.proto b/mindspore/lite/tools/converter/parser/tf/tensor.proto new file mode 100644 index 0000000000..c792930234 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tensor.proto @@ -0,0 +1,88 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TensorProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "resource_handle.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing a tensor. +message TensorProto { + DataType dtype = 1; + + // Shape of the tensor. TODO(touts): sort out the 0-rank issues. + TensorShapeProto tensor_shape = 2; + + // Only one of the representations below is set, one of "tensor_contents" and + // the "xxx_val" attributes. We are not using oneof because as oneofs cannot + // contain repeated fields it would require another extra set of messages. + + // Version number. + // + // In version 0, if the "repeated xxx" representations contain only one + // element, that element is repeated to fill the shape. This makes it easy + // to represent a constant Tensor with a single value. + int32 version_number = 3; + + // Serialized raw tensor content from either Tensor::AsProtoTensorContent or + // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation + // can be used for all tensor types. The purpose of this representation is to + // reduce serialization overhead during RPC call by avoiding serialization of + // many repeated small items. + bytes tensor_content = 4; + + // Type specific representations that make it easy to create tensor protos in + // all languages. Only the representation corresponding to "dtype" can + // be set. The values hold the flattened representation of the tensor in + // row major order. + + // DT_HALF. Note that since protobuf has no int16 type, we'll have some + // pointless zero padding for each value here. + repeated int32 half_val = 13 [packed = true]; + + // DT_FLOAT. + repeated float float_val = 5 [packed = true]; + + // DT_DOUBLE. + repeated double double_val = 6 [packed = true]; + + // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + repeated int32 int_val = 7 [packed = true]; + + // DT_STRING + repeated bytes string_val = 8; + + // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real + // and imaginary parts of i-th single precision complex. + repeated float scomplex_val = 9 [packed = true]; + + // DT_INT64 + repeated int64 int64_val = 10 [packed = true]; + + // DT_BOOL + repeated bool bool_val = 11 [packed = true]; + + // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real + // and imaginary parts of i-th double precision complex. + repeated double dcomplex_val = 12 [packed = true]; + + // DT_RESOURCE + repeated ResourceHandleProto resource_handle_val = 14; + + // DT_VARIANT + repeated VariantTensorDataProto variant_val = 15; +}; + +// Protocol buffer representing the serialization format of DT_VARIANT tensors. +message VariantTensorDataProto { + // Name of the type of objects being serialized. + string type_name = 1; + // Portions of the object that are not Tensors. + bytes metadata = 2; + // Tensors contained within objects being serialized. + repeated TensorProto tensors = 3; +} diff --git a/mindspore/lite/tools/converter/parser/tf/tensor_shape.proto b/mindspore/lite/tools/converter/parser/tf/tensor_shape.proto new file mode 100644 index 0000000000..1ec3c5323c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tensor_shape.proto @@ -0,0 +1,45 @@ +// Protocol buffer representing the shape of tensors. + +syntax = "proto3"; +option cc_enable_arenas = true; +option java_outer_classname = "TensorShapeProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +package tensorflow; + +// Dimensions of a tensor. +message TensorShapeProto { + // One dimension of the tensor. + message Dim { + // Size of the tensor in that dimension. + // This value must be >= -1, but values of -1 are reserved for "unknown" + // shapes (values of -1 mean "unknown" dimension). Certain wrappers + // that work with TensorShapeProto may fail at runtime when deserializing + // a TensorShapeProto containing a dim value of -1. + int64 size = 1; + + // Optional name of the tensor dimension. + string name = 2; + }; + + // Dimensions of the tensor, such as {"input", 30}, {"output", 40} + // for a 30 x 40 2D tensor. If an entry has size -1, this + // corresponds to a dimension of unknown size. The names are + // optional. + // + // The order of entries in "dim" matters: It indicates the layout of the + // values in the tensor in-memory representation. + // + // The first entry in "dim" is the outermost dimension used to layout the + // values, the last entry is the innermost dimension. This matches the + // in-memory layout of RowMajor Eigen tensors. + // + // If "dim.size()" > 0, "unknown_rank" must be false. + repeated Dim dim = 2; + + // If true, the number of dimensions in the shape is unknown. + // + // If true, "dim.size()" must be 0. + bool unknown_rank = 3; +}; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_add_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_add_parser.cc new file mode 100644 index 0000000000..a04cdbec57 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_add_parser.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tf/tf_add_parser.h" +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFAddParser::Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr &tf_model, + PrimitiveC *primitiveC, int *output_size) { + auto attr = std::make_unique(); + attr->value.type = schema::PrimitiveType_Add; + primitiveC = PrimitiveC::Create(attr.release()); + MS_LOG(INFO) << "primitive name" << primitiveC->type_name(); + return RET_OK; +} +TFNodeRegistrar g_tfAddParser("Add", new TFAddParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_add_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_add_parser.h new file mode 100644 index 0000000000..00ffb2989a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_add_parser.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H + +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFAddParser : public TFNodeParser { + public: + TFAddParser() = default; + ~TFAddParser() override = default; + + STATUS Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr &tf_model, + PrimitiveC *primitiveC, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc new file mode 100644 index 0000000000..f41474f09c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -0,0 +1,286 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * distributed under the License is distributed on an AS + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tf/tf_model_parser.h" +#include +#include +#include "src/common/log_adapter.h" +#include "tools/converter/parser/tf/tf_util.h" +#include "tools/common/graph_util.h" +#include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include "src/param_value_lite.h" + +namespace mindspore { +namespace lite { +FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile, + const QuantType &quantType) { + auto status = ValidateFileStr(modelFile, ".prototxt"); + if (status != RET_OK) { + MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt"; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return nullptr; + } + if (!TensorFlowUtils::TfReadProtoFromBinary(modelFile.c_str(), tf_graph_def.get())) { + MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); + return nullptr; + } + funcGraphPtr = std::make_shared(); + status = ConvertGraphInputs(); + if (status != RET_OK) { + MS_LOG(ERROR) << "Convert graph inputs failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return nullptr; + } + status = ConvertOps(); + if (status != RET_OK) { + MS_LOG(ERROR) << "Convert ops failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return nullptr; + } + + status = ConvertGraphOutputs(); + if (status != RET_OK) { + MS_LOG(ERROR) << "Convert graph outputs failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return nullptr; + } + return funcGraphPtr; +} +STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef *node, ParameterPtr parameter) { + tensorflow::AttrValue attr_value; + if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) { + tensorflow::AttrValue data_type; + tensorflow::DataType type = tensorflow::DT_FLOAT; + // datatype + if (TensorFlowUtils::FindAttrValue(node, "dtype", &data_type)) { + type = data_type.type(); + } + const tensorflow::TensorProto &tensorProto = attr_value.tensor(); + const tensorflow::TensorShapeProto &tensorShape = tensorProto.tensor_shape(); + parameter = funcGraphPtr->add_parameter(); + std::vector shape_vector; + int shape_size = 1; + shape_vector.resize(tensorShape.dim_size()); + for (int i = 0; i < tensorShape.dim_size(); i++) { + shape_vector[i] = tensorShape.dim(i).size(); + shape_size *= shape_vector[i]; + } + // convert const to paramter + TypePtr ms_data_ype; + auto paramValue = std::make_shared(); + if (type == tensorflow::DT_FLOAT) { + ms_data_ype = kFloat32; + auto tensor_data = new (std::nothrow) float[shape_size]; + if (tensorProto.float_val_size() == 1) { + float value = tensorProto.float_val(0); + for (int i = 0; i < shape_size; i++) { + tensor_data[i] = value; + } + } + if (tensorProto.tensor_content().size() == shape_size * sizeof(float)) { + const auto addr = reinterpret_cast(tensorProto.tensor_content().data()); + auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(float), addr, shape_size * sizeof(float)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s failed"; + return RET_ERROR; + } + } + paramValue->set_tensor_addr(tensor_data); + paramValue->set_tensor_size(shape_size * sizeof(float)); + } else if (type == tensorflow::DT_INT32) { + ms_data_ype = kInt32; + auto tensor_data = new (std::nothrow) int[shape_size]; + if (tensorProto.int_val_size() == 1) { + int value = tensorProto.int_val(0); + for (int i = 0; i < shape_size; i++) { + tensor_data[i] = value; + } + } + if (tensorProto.tensor_content().size() == shape_size * sizeof(int32_t)) { + const auto addr = reinterpret_cast(tensorProto.tensor_content().data()); + auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(int32_t), addr, shape_size * sizeof(int32_t)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s failed"; + return RET_ERROR; + } + } + paramValue->set_tensor_addr(tensor_data); + paramValue->set_tensor_size(shape_size * sizeof(int)); + } else if (type == tensorflow::DT_BOOL) { + ms_data_ype = kFloat32; + auto tensor_data = new (std::nothrow) int[shape_size]; + if (tensorProto.bool_val_size() == 1) { + int value = tensorProto.bool_val(0); + for (int i = 0; i < shape_size; i++) { + tensor_data[i] = value; + } + } + paramValue->set_tensor_addr(tensor_data); + paramValue->set_tensor_size(shape_size * sizeof(int)); + } else { + MS_LOG(ERROR) << "Unsupport dataType," << node->name(); + return RET_ERROR; + } + auto abstract_tensor = std::make_shared(ms_data_ype, shape_vector); + parameter->set_abstract(abstract_tensor); + parameter->set_name("const_" + std::to_string(anf_node_map.size()) + "_parameter"); + + std::vector param_shape; + (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(param_shape), + [](const int64_t &value) { return static_cast(value); }); + + MS_ASSERT(paramValue != nullptr); + paramValue->set_tensor_shape(param_shape); + paramValue->set_tensor_type(ms_data_ype->type_id()); + paramValue->set_format(schema::Format::Format_NHWC); + paramValue->set_tensor_size(shape_size * sizeof(int)); + parameter->set_default_param(paramValue); + } + return RET_OK; +} +STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef *op, const CNodePtr &anf_node, int output_size) { + if (output_size == 1) { + std::vector shape_vector; + anf_node->set_abstract(std::make_shared(kFloat32, shape_vector)); + anf_node_map.insert(std::pair(op->name(), anf_node)); + } else { + AbstractBasePtrList abstractList; + for (int output_idx = 0; output_idx < output_size; output_idx++) { + std::vector shape_vector; + abstractList.emplace_back(std::make_shared(kFloat32, shape_vector)); + auto tupleGetItemPrimPtr = GetTupleGetItemPrim(); + if (tupleGetItemPrimPtr == nullptr) { + MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; + return RET_NULL_PTR; + } + auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr); + auto getItemValue = NewValueNode(MakeValue(output_idx)); + std::vector inputs{tupleGetItemPrim, anf_node, getItemValue}; + CNodePtr getItemCNode = funcGraphPtr->NewCNode(inputs); + std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); + getItemCNode->set_fullname_with_scope(output_item_name); + anf_node_map.insert(std::pair(output_item_name, getItemCNode)); + } + anf_node->set_abstract(std::make_shared(abstractList)); + } + return RET_OK; +} +STATUS TFModelParser::ConvertOps() { + NoSupportOp::GetInstance()->SetFmkType("TENSORFLOW"); + STATUS status = RET_OK; + + // redirect identity to it's input0 + ClipIdentityAndStopGradient(); + int op_idx = 0; + for (int i = 0; i < tf_graph_def->node_size(); i++) { + auto node_def = tf_graph_def->mutable_node(i); + tf_node_map[node_def->name()] = node_def; + auto tf_op_type = node_def->op(); + if (tf_op_type == "Placeholder" || tf_op_type == "Const") { + continue; + } + auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(tf_op_type); + if (node_parser == nullptr) { + NoSupportOp::GetInstance()->InsertOp(tf_op_type); + status = (status == RET_OK ? RET_NOT_FIND_OP : status); + MS_LOG(ERROR) << "cannot find node parser:" << tf_op_type; + continue; + } + PrimitiveC *primitiveC = nullptr; + if (status == RET_OK) { + int output_size = 1; + status = node_parser->Parse(node_def, tf_graph_def, primitiveC, &output_size); + if (status != RET_OK) { + MS_LOG(ERROR) << "node " << tf_op_type.c_str() << " parser failed"; + continue; + } + std::vector opInputs = {NewValueNode(std::shared_ptr(primitiveC))}; + // parse inputs + for (int j = 0; j < node_def->input_size(); j++) { + auto input_node = tf_node_map[node_def->input(i)]; + // last node output + if (anf_node_map.find(input_node->name()) != anf_node_map.end()) { + opInputs.emplace_back(anf_node_map[input_node->name()]); + continue; + } + // const tensor + if (input_node->op() == "Const") { + ParameterPtr parameter; + if (ConvertConstTensor(input_node, parameter) != RET_OK) { + MS_LOG(ERROR) << "convert const tensor failed," << input_node->name(); + return RET_ERROR; + } + opInputs.emplace_back(parameter); + anf_node_map[parameter->fullname_with_scope()] = parameter; + continue; + } + MS_LOG(ERROR) << "node" << node_def->name() << "has inputs neither a node output nor a weight tensor."; + return RET_ERROR; + } + auto anf_node = funcGraphPtr->NewCNode(opInputs); + anf_node->set_fullname_with_scope(tf_op_type + "-" + std::to_string(op_idx++)); + + // parse outputs + status = ConvertOutputTensor(node_def, anf_node, output_size); + if (status != RET_OK) { + MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return status; + } + } + // redirect identity to it's input0 + ClipIdentityAndStopGradient(); + } + return RET_OK; +} +STATUS TFModelParser::ConvertGraphInputs() { + for (int i = 0; i < tf_graph_def->node_size(); i++) { + auto node_def = tf_graph_def->mutable_node(i); + tf_node_map[node_def->name()] = node_def; + if (node_def->op() == "Placeholder") { + auto parameter = funcGraphPtr->add_parameter(); + if (ConvertConstTensor(node_def, parameter) != RET_OK) { + MS_LOG(ERROR) << "convert const tensor failed"; + return RET_ERROR; + } + anf_node_map[node_def->name()] = parameter; + graph_input_names.emplace_back(node_def->name()); + } + } + return RET_OK; +} +STATUS TFModelParser::ConvertGraphOutputs() { return RET_OK; } + +std::string TFModelParser::GetOriginInputName(const tensorflow::NodeDef &node) { + if (node.op() != "Identity" && node.op() != "StopGradient") { + return node.name(); + } + auto tmpNode = node; + while (tmpNode.op() == "Identity" || tmpNode.op() == "StopGradient") { + tmpNode = *tf_node_map[tmpNode.input(0)]; + } + return tmpNode.name(); +} + +void TFModelParser::ClipIdentityAndStopGradient() { + for (auto &pair : tf_node_map) { + pair.second = tf_node_map[GetOriginInputName(*pair.second)]; + } +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h new file mode 100644 index 0000000000..1d3229c19c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H + +#include +#include +#include +#include +#include +#include "securec/include/securec.h" +#include "tools/common/tensor_util.h" +#include "tools/converter/model_parser.h" +#include "schema/inner/model_generated.h" +#include "proto/node_def.pb.h" +#include "proto/graph.pb.h" + +namespace mindspore { +namespace lite { +class TFModelParser { + public: + TFModelParser() = default; + ~TFModelParser() = default; + + FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType); + + private: + STATUS ConvertConstTensor(const tensorflow::NodeDef *op, ParameterPtr parameter); + STATUS ConvertOutputTensor(const tensorflow::NodeDef *op, const CNodePtr &anf_node, int output_size); + STATUS ConvertOps(); + STATUS ConvertGraphInputs(); + STATUS ConvertGraphOutputs(); + + std::string GetOriginInputName(const tensorflow::NodeDef &node); + + void ClipIdentityAndStopGradient(); + + FuncGraphPtr funcGraphPtr; + std::unique_ptr tf_graph_def; + std::map tf_node_map; + std::unordered_map anf_node_map; + std::vector graph_input_names, graphOutputNames; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tf/tf_node_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_node_parser.h new file mode 100644 index 0000000000..804f31e59d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_node_parser.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H + +#include +#include +#include +#include "tools/converter/parser/tf/tf_util.h" +#include "proto/graph.pb.h" +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class TFNodeParser { + public: + TFNodeParser() = default; + + virtual ~TFNodeParser() = default; + + virtual STATUS Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr &tf_model, + PrimitiveC *primitiveC, int *output_size) { + return RET_OK; + } +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tf/tf_node_parser_registry.cc b/mindspore/lite/tools/converter/parser/tf/tf_node_parser_registry.cc new file mode 100644 index 0000000000..092f46146e --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_node_parser_registry.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * distributed under the License is distributed on an AS + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include +#include "src/common/log_adapter.h" + +namespace mindspore { +namespace lite { +TFNodeParserRegistry::~TFNodeParserRegistry() { + for (const auto &iter : parsers) { + delete iter.second; + } + this->parsers.clear(); +} + +TFNodeParserRegistry *TFNodeParserRegistry::GetInstance() { + static TFNodeParserRegistry instance; + return &instance; +} + +TFNodeParser *TFNodeParserRegistry::GetNodeParser(const std::string &name) { + auto it = parsers.find(name); + if (it != parsers.end()) { + return it->second; + } + return nullptr; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_node_parser_registry.h b/mindspore/lite/tools/converter/parser/tf/tf_node_parser_registry.h new file mode 100644 index 0000000000..bec92021ad --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_node_parser_registry.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_REGISTRY_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_REGISTRY_H + +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFNodeParserRegistry { + public: + TFNodeParserRegistry() = default; + + virtual ~TFNodeParserRegistry(); + + static TFNodeParserRegistry *GetInstance(); + TFNodeParser *GetNodeParser(const std::string &name); + + std::unordered_map parsers; +}; + +class TFNodeRegistrar { + public: + TFNodeRegistrar(const std::string &name, TFNodeParser *parser) { + TFNodeParserRegistry::GetInstance()->parsers[name] = parser; + } + ~TFNodeRegistrar() = default; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_REGISTRY_H diff --git a/mindspore/lite/tools/converter/parser/tf/tf_util.cc b/mindspore/lite/tools/converter/parser/tf/tf_util.cc new file mode 100644 index 0000000000..c7cb8fc161 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_util.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tf/tf_util.h" +#include +#include +#include +#include "google/protobuf/io/zero_copy_stream_impl.h" + +namespace mindspore { +namespace lite { +bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef *nodeDef, std::string attr_name, + tensorflow::AttrValue *attr_value) { + const google::protobuf::Map &attr = nodeDef->attr(); + const google::protobuf::Map::const_iterator it = attr.find(attr_name); + if (it != attr.end()) { + *attr_value = it->second; + return true; + } + return false; +} + +bool TensorFlowUtils::TfReadProtoFromBinary(const char *filepath, google::protobuf::Message *message) { + std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary); + if (!fs.is_open()) { + fprintf(stderr, "open failed %s\n", filepath); + return false; + } + + google::protobuf::io::IstreamInputStream input(&fs); + google::protobuf::io::CodedInputStream codedstr(&input); + + codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2); + + bool success = message->ParseFromCodedStream(&codedstr); + + fs.close(); + + return success; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_util.h b/mindspore/lite/tools/converter/parser/tf/tf_util.h new file mode 100644 index 0000000000..21888388f7 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_util.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_UTIL_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_UTIL_H + +#include +#include "proto/node_def.pb.h" +#include "ir/dtype/type_id.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +class TensorFlowUtils { + public: + static bool FindAttrValue(const tensorflow::NodeDef *nodeDef, std::string attr_name, + tensorflow::AttrValue *attr_value); + + static bool TfReadProtoFromBinary(const char *filepath, google::protobuf::Message *message); +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_UTIL_H diff --git a/mindspore/lite/tools/converter/parser/tf/types.proto b/mindspore/lite/tools/converter/parser/tf/types.proto new file mode 100644 index 0000000000..1beb2a1aa2 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/types.proto @@ -0,0 +1,66 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TypesProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// LINT.IfChange +enum DataType { + // Not a legal value for DataType. Used to indicate a DataType field + // has not been set. + DT_INVALID = 0; + + // Data types that all computation devices are expected to be + // capable to support. + DT_FLOAT = 1; + DT_DOUBLE = 2; + DT_INT32 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_INT8 = 6; + DT_STRING = 7; + DT_COMPLEX64 = 8; // Single-precision complex + DT_INT64 = 9; + DT_BOOL = 10; + DT_QINT8 = 11; // Quantized int8 + DT_QUINT8 = 12; // Quantized uint8 + DT_QINT32 = 13; // Quantized int32 + DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. + DT_QINT16 = 15; // Quantized int16 + DT_QUINT16 = 16; // Quantized uint16 + DT_UINT16 = 17; + DT_COMPLEX128 = 18; // Double-precision complex + DT_HALF = 19; + DT_RESOURCE = 20; + DT_VARIANT = 21; // Arbitrary C++ data types + + // TODO(josh11b): DT_GENERIC_PROTO = ??; + // TODO(jeff,josh11b): DT_UINT64? DT_UINT32? + + // Do not use! These are only for parameters. Every enum above + // should have a corresponding value below (verified by types_test). + DT_FLOAT_REF = 101; + DT_DOUBLE_REF = 102; + DT_INT32_REF = 103; + DT_UINT8_REF = 104; + DT_INT16_REF = 105; + DT_INT8_REF = 106; + DT_STRING_REF = 107; + DT_COMPLEX64_REF = 108; + DT_INT64_REF = 109; + DT_BOOL_REF = 110; + DT_QINT8_REF = 111; + DT_QUINT8_REF = 112; + DT_QINT32_REF = 113; + DT_BFLOAT16_REF = 114; + DT_QINT16_REF = 115; + DT_QUINT16_REF = 116; + DT_UINT16_REF = 117; + DT_COMPLEX128_REF = 118; + DT_HALF_REF = 119; + DT_RESOURCE_REF = 120; + DT_VARIANT_REF = 121; +} +// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go) diff --git a/mindspore/lite/tools/converter/parser/tf/versions.proto b/mindspore/lite/tools/converter/parser/tf/versions.proto new file mode 100644 index 0000000000..7d5e58ae7d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/versions.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "VersionsProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Version information for a piece of serialized data +// +// There are different types of versions for each type of data +// (GraphDef, etc.), but they all have the same common shape +// described here. +// +// Each consumer has "consumer" and "min_producer" versions (specified +// elsewhere). A consumer is allowed to consume this data if +// +// producer >= min_producer +// consumer >= min_consumer +// consumer not in bad_consumers +// +message VersionDef { + // The version of the code that produced this data. + int32 producer = 1; + + // Any consumer below this version is not allowed to consume this data. + int32 min_consumer = 2; + + // Specific consumer versions which are disallowed (e.g. due to bugs). + repeated int32 bad_consumers = 3; +};