| @@ -307,6 +307,7 @@ if (ENABLE_CONVERTER) | |||||
| tflite_parser_mid | tflite_parser_mid | ||||
| caffe_parser_mid | caffe_parser_mid | ||||
| onnx_parser_mid | onnx_parser_mid | ||||
| tf_parser_mid | |||||
| graph_pass_mid | graph_pass_mid | ||||
| fusion_mid | fusion_mid | ||||
| quantizer_mid | quantizer_mid | ||||
| @@ -61,6 +61,7 @@ add_subdirectory(../anf_exporter anf_exporter) | |||||
| add_subdirectory(parser/caffe) | add_subdirectory(parser/caffe) | ||||
| add_subdirectory(parser/tflite) | add_subdirectory(parser/tflite) | ||||
| add_subdirectory(parser/onnx) | add_subdirectory(parser/onnx) | ||||
| add_subdirectory(parser/tf) | |||||
| add_subdirectory(legacy_optimizer) | add_subdirectory(legacy_optimizer) | ||||
| add_subdirectory(quantizer) | add_subdirectory(quantizer) | ||||
| add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../core mindspore_core) | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../core mindspore_core) | ||||
| @@ -112,6 +113,7 @@ endif () | |||||
| file(GLOB PROTO_FILE "" | file(GLOB PROTO_FILE "" | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto | ${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/*.proto | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto) | ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto) | ||||
| ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE}) | ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE}) | ||||
| add_library(proto_mid OBJECT ${PROTO_SRCS}) | add_library(proto_mid OBJECT ${PROTO_SRCS}) | ||||
| @@ -139,6 +141,7 @@ add_dependencies(converter_lite fbs_inner_src) | |||||
| target_link_libraries(converter_lite PRIVATE | target_link_libraries(converter_lite PRIVATE | ||||
| tflite_parser_mid | tflite_parser_mid | ||||
| tf_parser_mid | |||||
| caffe_parser_mid | caffe_parser_mid | ||||
| onnx_parser_mid | onnx_parser_mid | ||||
| anf_importer_mid | anf_importer_mid | ||||
| @@ -0,0 +1,7 @@ | |||||
| file(GLOB_RECURSE TF_SRC_LIST ${CMAKE_CURRENT_SOURCE_DIR}/*.cc) | |||||
| set_property(SOURCE ${TF_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | |||||
| add_library(tf_parser_mid OBJECT ${TF_SRC_LIST}) | |||||
| add_dependencies(tf_parser_mid proto_mid) | |||||
| @@ -0,0 +1,62 @@ | |||||
| syntax = "proto3"; | |||||
| package tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "AttrValueProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "tensor.proto"; | |||||
| import "tensor_shape.proto"; | |||||
| import "types.proto"; | |||||
| // Protocol buffer representing the value for an attr used to configure an Op. | |||||
| // Comment indicates the corresponding attr type. Only the field matching the | |||||
| // attr type may be filled. | |||||
| message AttrValue { | |||||
| // LINT.IfChange | |||||
| message ListValue { | |||||
| repeated bytes s = 2; // "list(string)" | |||||
| repeated int64 i = 3 [packed = true]; // "list(int)" | |||||
| repeated float f = 4 [packed = true]; // "list(float)" | |||||
| repeated bool b = 5 [packed = true]; // "list(bool)" | |||||
| repeated DataType type = 6 [packed = true]; // "list(type)" | |||||
| repeated TensorShapeProto shape = 7; // "list(shape)" | |||||
| repeated TensorProto tensor = 8; // "list(tensor)" | |||||
| repeated NameAttrList func = 9; // "list(attr)" | |||||
| } | |||||
| // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) | |||||
| oneof value { | |||||
| bytes s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| DataType type = 6; // "type" | |||||
| TensorShapeProto shape = 7; // "shape" | |||||
| TensorProto tensor = 8; // "tensor" | |||||
| ListValue list = 1; // any "list(...)" | |||||
| // "func" represents a function. func.name is a function's name or | |||||
| // a primitive op's name. func.attr.first is the name of an attr | |||||
| // defined for that function. func.attr.second is the value for | |||||
| // that attr in the instantiation. | |||||
| NameAttrList func = 10; | |||||
| // This is a placeholder only used in anf_node_map defined inside a | |||||
| // function. It indicates the attr value will be supplied when | |||||
| // the function is instantiated. For example, let us suppose a | |||||
| // node "N" in function "FN". "N" has an attr "A" with value | |||||
| // placeholder = "foo". When FN is instantiated with attr "foo" | |||||
| // set to "bar", the instantiated node N's attr A will have been | |||||
| // given the value "bar". | |||||
| string placeholder = 9; | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NameAttrList { | |||||
| string name = 1; | |||||
| map<string, AttrValue> attr = 2; | |||||
| } | |||||
| @@ -0,0 +1,101 @@ | |||||
| syntax = "proto3"; | |||||
| package tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "FunctionProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "attr_value.proto"; | |||||
| import "node_def.proto"; | |||||
| import "op_def.proto"; | |||||
| // A library is a set of named functions. | |||||
| message FunctionDefLibrary { | |||||
| repeated FunctionDef function = 1; | |||||
| repeated GradientDef gradient = 2; | |||||
| } | |||||
| // A function can be instantiated when the runtime can bind every attr | |||||
| // with a value. When a GraphDef has a call to a function, it must | |||||
| // have binding for every attr defined in the signature. | |||||
| // | |||||
| // TODO(zhifengc): | |||||
| // * device spec, etc. | |||||
| message FunctionDef { | |||||
| // The definition of the function's name, arguments, return values, | |||||
| // attrs etc. | |||||
| OpDef signature = 1; | |||||
| // Attributes specific to this function definition. | |||||
| map<string, AttrValue> attr = 5; | |||||
| // NOTE: field id 2 deleted on Jan 11, 2016, GraphDef version 21. | |||||
| // In both of the following fields, there is the need to specify an | |||||
| // output that is used as either the input to another node (in | |||||
| // `node_def`) or as a return value of the function (in `ret`). | |||||
| // Unlike the NodeDefs in GraphDef, we need to be able to specify a | |||||
| // list in some cases (instead of just single outputs). Also, we | |||||
| // need to be able to deal with lists of unknown length (so the | |||||
| // output index may not be known at function definition time). So | |||||
| // we use the following format instead: | |||||
| // * "fun_in" where "fun_in" is the name of a function input arg in | |||||
| // the `signature` field above. This represents that input, whether | |||||
| // it is a single tensor or a list. | |||||
| // * "fun_in:0" gives the first element of a function input arg (a | |||||
| // non-list input is considered a list of length 1 for these | |||||
| // purposes). | |||||
| // * "node:out" where "node" is the name of a node in `node_def` and | |||||
| // "out" is the name one of its op's output arguments (the name | |||||
| // comes from the OpDef of the node's op). This represents that | |||||
| // node's output, whether it is a single tensor or a list. | |||||
| // Note: We enforce that an op's output arguments are never | |||||
| // renamed in the backwards-compatibility test. | |||||
| // * "node:out:0" gives the first element of a node output arg (a | |||||
| // non-list output is considered a list of length 1 for these | |||||
| // purposes). | |||||
| // | |||||
| // NOT CURRENTLY SUPPORTED (but may be in the future): | |||||
| // * "node:out:-1" gives last element in a node output list | |||||
| // * "node:out:1:" gives a list with all but the first element in a | |||||
| // node output list | |||||
| // * "node:out::-1" gives a list with all but the last element in a | |||||
| // node output list | |||||
| // The body of the function. Unlike the NodeDefs in a GraphDef, attrs | |||||
| // may have values of type `placeholder` and the `input` field uses | |||||
| // the "output" format above. | |||||
| // By convention, "op" in node_def is resolved by consulting with a | |||||
| // user-defined library first. If not resolved, "func" is assumed to | |||||
| // be a builtin op. | |||||
| repeated NodeDef node_def = 3; | |||||
| // A mapping from the output arg names from `signature` to the | |||||
| // outputs from `node_def` that should be returned by the function. | |||||
| map<string, string> ret = 4; | |||||
| } | |||||
| // GradientDef defines the gradient function of a function defined in | |||||
| // a function library. | |||||
| // | |||||
| // A gradient function g (specified by gradient_func) for a function f | |||||
| // (specified by function_name) must follow the following: | |||||
| // | |||||
| // The function 'f' must be a numerical function which takes N inputs | |||||
| // and produces M outputs. Its gradient function 'g', which is a | |||||
| // function taking N + M inputs and produces N outputs. | |||||
| // | |||||
| // I.e. if we have | |||||
| // (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), | |||||
| // then, g is | |||||
| // (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, | |||||
| // dL/dy1, dL/dy2, ..., dL/dy_M), | |||||
| // where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the | |||||
| // loss function). dL/dx_i is the partial derivative of L with respect | |||||
| // to x_i. | |||||
| message GradientDef { | |||||
| string function_name = 1; // The function name. | |||||
| string gradient_func = 2; // The gradient function's name. | |||||
| } | |||||
| @@ -0,0 +1,56 @@ | |||||
| syntax = "proto3"; | |||||
| package tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "GraphProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "node_def.proto"; | |||||
| import "function.proto"; | |||||
| import "versions.proto"; | |||||
| // Represents the graph of operations | |||||
| message GraphDef { | |||||
| repeated NodeDef node = 1; | |||||
| // Compatibility versions of the graph. See core/public/version.h for version | |||||
| // history. The GraphDef version is distinct from the TensorFlow version, and | |||||
| // each release of TensorFlow will support a range of GraphDef versions. | |||||
| VersionDef versions = 4; | |||||
| // Deprecated single version field; use versions above instead. Since all | |||||
| // GraphDef changes before "versions" was introduced were forward | |||||
| // compatible, this field is entirely ignored. | |||||
| int32 version = 3 [deprecated = true]; | |||||
| // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. | |||||
| // | |||||
| // "library" provides user-defined functions. | |||||
| // | |||||
| // Naming: | |||||
| // * library.function.name are in a flat namespace. | |||||
| // NOTE: We may need to change it to be hierarchical to support | |||||
| // different orgs. E.g., | |||||
| // { "/google/nn", { ... }}, | |||||
| // { "/google/vision", { ... }} | |||||
| // { "/org_foo/module_bar", { ... }} | |||||
| // map<string, FunctionDefLib> named_lib; | |||||
| // * If node[i].op is the name of one function in "library", | |||||
| // node[i] is deemed as a function call. Otherwise, node[i].op | |||||
| // must be a primitive operation supported by the runtime. | |||||
| // | |||||
| // | |||||
| // Function call semantics: | |||||
| // | |||||
| // * The callee may start execution as soon as some of its inputs | |||||
| // are ready. The caller may want to use Tuple() mechanism to | |||||
| // ensure all inputs are ready in the same time. | |||||
| // | |||||
| // * The consumer of return values may start executing as soon as | |||||
| // the return values the consumer depends on are ready. The | |||||
| // consumer may want to use Tuple() mechanism to ensure the | |||||
| // consumer does not start until all return values of the callee | |||||
| // function are ready. | |||||
| FunctionDefLibrary library = 2; | |||||
| }; | |||||
| @@ -0,0 +1,63 @@ | |||||
| syntax = "proto3"; | |||||
| package tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "NodeProto"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "attr_value.proto"; | |||||
| message NodeDef { | |||||
| // The name given to this operator. Used for naming inputs, | |||||
| // logging, visualization, etc. Unique within a single GraphDef. | |||||
| // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". | |||||
| string name = 1; | |||||
| // The operation name. There may be custom parameters in attrs. | |||||
| // Op names starting with an underscore are reserved for internal use. | |||||
| string op = 2; | |||||
| // Each input is "node:src_output" with "node" being a string name and | |||||
| // "src_output" indicating which output tensor to use from "node". If | |||||
| // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs | |||||
| // may optionally be followed by control inputs that have the format | |||||
| // "^node". | |||||
| repeated string input = 3; | |||||
| // A (possibly partial) specification for the device on which this | |||||
| // node should be placed. | |||||
| // The expected syntax for this string is as follows: | |||||
| // | |||||
| // DEVICE_SPEC ::= PARTIAL_SPEC | |||||
| // | |||||
| // PARTIAL_SPEC ::= ("/" CONSTRAINT) * | |||||
| // CONSTRAINT ::= ("job:" JOB_NAME) | |||||
| // | ("replica:" [1-9][0-9]*) | |||||
| // | ("task:" [1-9][0-9]*) | |||||
| // | ( ("gpu" | "cpu") ":" ([1-9][0-9]* | "*") ) | |||||
| // | |||||
| // Valid values for this string include: | |||||
| // * "/job:worker/replica:0/task:1/gpu:3" (full specification) | |||||
| // * "/job:worker/gpu:3" (partial specification) | |||||
| // * "" (no specification) | |||||
| // | |||||
| // If the constraints do not resolve to a single device (or if this | |||||
| // field is empty or not present), the runtime will attempt to | |||||
| // choose a device automatically. | |||||
| string device = 4; | |||||
| // Operation-specific graph-construction-time configuration. | |||||
| // Note that this should include all attrs defined in the | |||||
| // corresponding OpDef, including those with a value matching | |||||
| // the default -- this allows the default to change and makes | |||||
| // NodeDefs easier to interpret on their own. However, if | |||||
| // an attr with a default is not specified in this list, the | |||||
| // default will be used. | |||||
| // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and | |||||
| // one of the names from the corresponding OpDef's attr field). | |||||
| // The values must have a type matching the corresponding OpDef | |||||
| // attr's type field. | |||||
| // TODO(josh11b): Add some examples here showing best practices. | |||||
| map<string, AttrValue> attr = 5; | |||||
| }; | |||||
| @@ -0,0 +1,157 @@ | |||||
| syntax = "proto3"; | |||||
| package tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "OpDefProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "attr_value.proto"; | |||||
| import "types.proto"; | |||||
| // Defines an operation. A NodeDef in a GraphDef specifies an Op by | |||||
| // using the "op" field which should match the name of a OpDef. | |||||
| message OpDef { | |||||
| // Op names starting with an underscore are reserved for internal use. | |||||
| // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". | |||||
| string name = 1; | |||||
| // For describing inputs and outputs. | |||||
| message ArgDef { | |||||
| // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". | |||||
| string name = 1; | |||||
| // Human readable description. | |||||
| string description = 2; | |||||
| // Describes the type of one or more tensors that are accepted/produced | |||||
| // by this input/output arg. The only legal combinations are: | |||||
| // * For a single tensor: either the "type" field is set or the | |||||
| // "type_attr" field is set to the name of an attr with type "type". | |||||
| // * For a sequence of tensors with the same type: the "number_attr" | |||||
| // field will be set to the name of an attr with type "int", and | |||||
| // either the "type" or "type_attr" field will be set as for | |||||
| // single tensors. | |||||
| // * For a sequence of tensors, the "type_list_attr" field will be set | |||||
| // to the name of an attr with type "list(type)". | |||||
| DataType type = 3; | |||||
| string type_attr = 4; // if specified, attr must have type "type" | |||||
| string number_attr = 5; // if specified, attr must have type "int" | |||||
| // If specified, attr must have type "list(type)", and none of | |||||
| // type, type_attr, and number_attr may be specified. | |||||
| string type_list_attr = 6; | |||||
| // For inputs: if true, the inputs are required to be refs. | |||||
| // By default, inputs can be either refs or non-refs. | |||||
| // For outputs: if true, outputs are refs, otherwise they are not. | |||||
| bool is_ref = 16; | |||||
| }; | |||||
| // Description of the input(s). | |||||
| repeated ArgDef input_arg = 2; | |||||
| // Description of the output(s). | |||||
| repeated ArgDef output_arg = 3; | |||||
| // Description of the graph-construction-time configuration of this | |||||
| // Op. That is to say, this describes the attr fields that will | |||||
| // be specified in the NodeDef. | |||||
| message AttrDef { | |||||
| // A descriptive name for the argument. May be used, e.g. by the | |||||
| // Python client, as a keyword argument name, and so should match | |||||
| // the regexp "[a-z][a-z0-9_]+". | |||||
| string name = 1; | |||||
| // One of the type names from attr_value.proto ("string", "list(string)", | |||||
| // "int", etc.). | |||||
| string type = 2; | |||||
| // A reasonable default for this attribute if the user does not supply | |||||
| // a value. If not specified, the user must supply a value. | |||||
| AttrValue default_value = 3; | |||||
| // Human-readable description. | |||||
| string description = 4; | |||||
| // TODO(josh11b): bool is_optional? | |||||
| // --- Constraints --- | |||||
| // These constraints are only in effect if specified. Default is no | |||||
| // constraints. | |||||
| // For type == "int", this is a minimum value. For "list(___)" | |||||
| // types, this is the minimum length. | |||||
| bool has_minimum = 5; | |||||
| int64 minimum = 6; | |||||
| // The set of allowed values. Has type that is the "list" version | |||||
| // of the "type" field above (uses the "list" field of AttrValue). | |||||
| // If type == "type" or "list(type)" above, then the "type" field | |||||
| // of "allowed_values.list" has the set of allowed DataTypes. | |||||
| // If type == "string" or "list(string)", then the "s" field of | |||||
| // "allowed_values.list" has the set of allowed strings. | |||||
| AttrValue allowed_values = 7; | |||||
| } | |||||
| repeated AttrDef attr = 4; | |||||
| // Optional deprecation based on GraphDef versions. | |||||
| OpDeprecation deprecation = 8; | |||||
| // One-line human-readable description of what the Op does. | |||||
| string summary = 5; | |||||
| // Additional, longer human-readable description of what the Op does. | |||||
| string description = 6; | |||||
| // ------------------------------------------------------------------------- | |||||
| // Which optimizations this operation can participate in. | |||||
| // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) | |||||
| bool is_commutative = 18; | |||||
| // If is_aggregate is true, then this operation accepts N >= 2 | |||||
| // inputs and produces 1 output all of the same type. Should be | |||||
| // associative and commutative, and produce output with the same | |||||
| // shape as the input. The optimizer may replace an aggregate op | |||||
| // taking input from multiple devices with a tree of aggregate ops | |||||
| // that aggregate locally within each device (and possibly within | |||||
| // groups of nearby devices) before communicating. | |||||
| // TODO(josh11b): Implement that optimization. | |||||
| bool is_aggregate = 16; // for things like add | |||||
| // Other optimizations go here, like | |||||
| // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. | |||||
| // ------------------------------------------------------------------------- | |||||
| // Optimization constraints. | |||||
| // By default Ops may be moved between devices. Stateful ops should | |||||
| // either not be moved, or should only be moved if that state can also | |||||
| // be moved (e.g. via some sort of save / restore). | |||||
| // Stateful ops are guaranteed to never be optimized away by Common | |||||
| // Subexpression Elimination (CSE). | |||||
| bool is_stateful = 17; // for things like variables, queue | |||||
| // ------------------------------------------------------------------------- | |||||
| // Non-standard options. | |||||
| // By default, all inputs to an Op must be initialized Tensors. Ops | |||||
| // that may initialize tensors for the first time should set this | |||||
| // field to true, to allow the Op to take an uninitialized Tensor as | |||||
| // input. | |||||
| bool allows_uninitialized_input = 19; // for Assign, etc. | |||||
| }; | |||||
| // Information about version-dependent deprecation of an op | |||||
| message OpDeprecation { | |||||
| // First GraphDef version at which the op is disallowed. | |||||
| int32 version = 1; | |||||
| // Explanation of why it was deprecated and what to use instead. | |||||
| string explanation = 2; | |||||
| }; | |||||
| // A collection of OpDefs | |||||
| message OpList { | |||||
| repeated OpDef op = 1; | |||||
| }; | |||||
| @@ -0,0 +1,29 @@ | |||||
| syntax = "proto3"; | |||||
| package tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "ResourceHandle"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| // Protocol buffer representing a handle to a tensorflow resource. Handles are | |||||
| // not valid across executions, but can be serialized back and forth from within | |||||
| // a single run. | |||||
| message ResourceHandleProto { | |||||
| // Unique name for the device containing the resource. | |||||
| string device = 1; | |||||
| // Container in which this resource is placed. | |||||
| string container = 2; | |||||
| // Unique name of this resource. | |||||
| string name = 3; | |||||
| // Hash code for the type of the resource. Is only valid in the same device | |||||
| // and in the same execution. | |||||
| uint64 hash_code = 4; | |||||
| // For debug-only, the name of the type pointed to by this handle, if | |||||
| // available. | |||||
| string maybe_type_name = 5; | |||||
| }; | |||||
| @@ -0,0 +1,88 @@ | |||||
| syntax = "proto3"; | |||||
| package tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "TensorProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "resource_handle.proto"; | |||||
| import "tensor_shape.proto"; | |||||
| import "types.proto"; | |||||
| // Protocol buffer representing a tensor. | |||||
| message TensorProto { | |||||
| DataType dtype = 1; | |||||
| // Shape of the tensor. TODO(touts): sort out the 0-rank issues. | |||||
| TensorShapeProto tensor_shape = 2; | |||||
| // Only one of the representations below is set, one of "tensor_contents" and | |||||
| // the "xxx_val" attributes. We are not using oneof because as oneofs cannot | |||||
| // contain repeated fields it would require another extra set of messages. | |||||
| // Version number. | |||||
| // | |||||
| // In version 0, if the "repeated xxx" representations contain only one | |||||
| // element, that element is repeated to fill the shape. This makes it easy | |||||
| // to represent a constant Tensor with a single value. | |||||
| int32 version_number = 3; | |||||
| // Serialized raw tensor content from either Tensor::AsProtoTensorContent or | |||||
| // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation | |||||
| // can be used for all tensor types. The purpose of this representation is to | |||||
| // reduce serialization overhead during RPC call by avoiding serialization of | |||||
| // many repeated small items. | |||||
| bytes tensor_content = 4; | |||||
| // Type specific representations that make it easy to create tensor protos in | |||||
| // all languages. Only the representation corresponding to "dtype" can | |||||
| // be set. The values hold the flattened representation of the tensor in | |||||
| // row major order. | |||||
| // DT_HALF. Note that since protobuf has no int16 type, we'll have some | |||||
| // pointless zero padding for each value here. | |||||
| repeated int32 half_val = 13 [packed = true]; | |||||
| // DT_FLOAT. | |||||
| repeated float float_val = 5 [packed = true]; | |||||
| // DT_DOUBLE. | |||||
| repeated double double_val = 6 [packed = true]; | |||||
| // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. | |||||
| repeated int32 int_val = 7 [packed = true]; | |||||
| // DT_STRING | |||||
| repeated bytes string_val = 8; | |||||
| // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real | |||||
| // and imaginary parts of i-th single precision complex. | |||||
| repeated float scomplex_val = 9 [packed = true]; | |||||
| // DT_INT64 | |||||
| repeated int64 int64_val = 10 [packed = true]; | |||||
| // DT_BOOL | |||||
| repeated bool bool_val = 11 [packed = true]; | |||||
| // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real | |||||
| // and imaginary parts of i-th double precision complex. | |||||
| repeated double dcomplex_val = 12 [packed = true]; | |||||
| // DT_RESOURCE | |||||
| repeated ResourceHandleProto resource_handle_val = 14; | |||||
| // DT_VARIANT | |||||
| repeated VariantTensorDataProto variant_val = 15; | |||||
| }; | |||||
| // Protocol buffer representing the serialization format of DT_VARIANT tensors. | |||||
| message VariantTensorDataProto { | |||||
| // Name of the type of objects being serialized. | |||||
| string type_name = 1; | |||||
| // Portions of the object that are not Tensors. | |||||
| bytes metadata = 2; | |||||
| // Tensors contained within objects being serialized. | |||||
| repeated TensorProto tensors = 3; | |||||
| } | |||||
| @@ -0,0 +1,45 @@ | |||||
| // Protocol buffer representing the shape of tensors. | |||||
| syntax = "proto3"; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "TensorShapeProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| package tensorflow; | |||||
| // Dimensions of a tensor. | |||||
| message TensorShapeProto { | |||||
| // One dimension of the tensor. | |||||
| message Dim { | |||||
| // Size of the tensor in that dimension. | |||||
| // This value must be >= -1, but values of -1 are reserved for "unknown" | |||||
| // shapes (values of -1 mean "unknown" dimension). Certain wrappers | |||||
| // that work with TensorShapeProto may fail at runtime when deserializing | |||||
| // a TensorShapeProto containing a dim value of -1. | |||||
| int64 size = 1; | |||||
| // Optional name of the tensor dimension. | |||||
| string name = 2; | |||||
| }; | |||||
| // Dimensions of the tensor, such as {"input", 30}, {"output", 40} | |||||
| // for a 30 x 40 2D tensor. If an entry has size -1, this | |||||
| // corresponds to a dimension of unknown size. The names are | |||||
| // optional. | |||||
| // | |||||
| // The order of entries in "dim" matters: It indicates the layout of the | |||||
| // values in the tensor in-memory representation. | |||||
| // | |||||
| // The first entry in "dim" is the outermost dimension used to layout the | |||||
| // values, the last entry is the innermost dimension. This matches the | |||||
| // in-memory layout of RowMajor Eigen tensors. | |||||
| // | |||||
| // If "dim.size()" > 0, "unknown_rank" must be false. | |||||
| repeated Dim dim = 2; | |||||
| // If true, the number of dimensions in the shape is unknown. | |||||
| // | |||||
| // If true, "dim.size()" must be 0. | |||||
| bool unknown_rank = 3; | |||||
| }; | |||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_add_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFAddParser::Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model, | |||||
| PrimitiveC *primitiveC, int *output_size) { | |||||
| auto attr = std::make_unique<schema::PrimitiveT>(); | |||||
| attr->value.type = schema::PrimitiveType_Add; | |||||
| primitiveC = PrimitiveC::Create(attr.release()); | |||||
| MS_LOG(INFO) << "primitive name" << primitiveC->type_name(); | |||||
| return RET_OK; | |||||
| } | |||||
| TFNodeRegistrar g_tfAddParser("Add", new TFAddParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,35 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFAddParser : public TFNodeParser { | |||||
| public: | |||||
| TFAddParser() = default; | |||||
| ~TFAddParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model, | |||||
| PrimitiveC *primitiveC, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H | |||||
| @@ -0,0 +1,286 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * distributed under the License is distributed on an AS | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_model_parser.h" | |||||
| #include <map> | |||||
| #include <algorithm> | |||||
| #include "src/common/log_adapter.h" | |||||
| #include "tools/converter/parser/tf/tf_util.h" | |||||
| #include "tools/common/graph_util.h" | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| #include "src/param_value_lite.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile, | |||||
| const QuantType &quantType) { | |||||
| auto status = ValidateFileStr(modelFile, ".prototxt"); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return nullptr; | |||||
| } | |||||
| if (!TensorFlowUtils::TfReadProtoFromBinary(modelFile.c_str(), tf_graph_def.get())) { | |||||
| MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||||
| return nullptr; | |||||
| } | |||||
| funcGraphPtr = std::make_shared<FuncGraph>(); | |||||
| status = ConvertGraphInputs(); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Convert graph inputs failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return nullptr; | |||||
| } | |||||
| status = ConvertOps(); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Convert ops failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return nullptr; | |||||
| } | |||||
| status = ConvertGraphOutputs(); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Convert graph outputs failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return nullptr; | |||||
| } | |||||
| return funcGraphPtr; | |||||
| } | |||||
| STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef *node, ParameterPtr parameter) { | |||||
| tensorflow::AttrValue attr_value; | |||||
| if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) { | |||||
| tensorflow::AttrValue data_type; | |||||
| tensorflow::DataType type = tensorflow::DT_FLOAT; | |||||
| // datatype | |||||
| if (TensorFlowUtils::FindAttrValue(node, "dtype", &data_type)) { | |||||
| type = data_type.type(); | |||||
| } | |||||
| const tensorflow::TensorProto &tensorProto = attr_value.tensor(); | |||||
| const tensorflow::TensorShapeProto &tensorShape = tensorProto.tensor_shape(); | |||||
| parameter = funcGraphPtr->add_parameter(); | |||||
| std::vector<int64_t> shape_vector; | |||||
| int shape_size = 1; | |||||
| shape_vector.resize(tensorShape.dim_size()); | |||||
| for (int i = 0; i < tensorShape.dim_size(); i++) { | |||||
| shape_vector[i] = tensorShape.dim(i).size(); | |||||
| shape_size *= shape_vector[i]; | |||||
| } | |||||
| // convert const to paramter | |||||
| TypePtr ms_data_ype; | |||||
| auto paramValue = std::make_shared<ParamValueLite>(); | |||||
| if (type == tensorflow::DT_FLOAT) { | |||||
| ms_data_ype = kFloat32; | |||||
| auto tensor_data = new (std::nothrow) float[shape_size]; | |||||
| if (tensorProto.float_val_size() == 1) { | |||||
| float value = tensorProto.float_val(0); | |||||
| for (int i = 0; i < shape_size; i++) { | |||||
| tensor_data[i] = value; | |||||
| } | |||||
| } | |||||
| if (tensorProto.tensor_content().size() == shape_size * sizeof(float)) { | |||||
| const auto addr = reinterpret_cast<const float *>(tensorProto.tensor_content().data()); | |||||
| auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(float), addr, shape_size * sizeof(float)); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| paramValue->set_tensor_addr(tensor_data); | |||||
| paramValue->set_tensor_size(shape_size * sizeof(float)); | |||||
| } else if (type == tensorflow::DT_INT32) { | |||||
| ms_data_ype = kInt32; | |||||
| auto tensor_data = new (std::nothrow) int[shape_size]; | |||||
| if (tensorProto.int_val_size() == 1) { | |||||
| int value = tensorProto.int_val(0); | |||||
| for (int i = 0; i < shape_size; i++) { | |||||
| tensor_data[i] = value; | |||||
| } | |||||
| } | |||||
| if (tensorProto.tensor_content().size() == shape_size * sizeof(int32_t)) { | |||||
| const auto addr = reinterpret_cast<const int32_t *>(tensorProto.tensor_content().data()); | |||||
| auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(int32_t), addr, shape_size * sizeof(int32_t)); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| paramValue->set_tensor_addr(tensor_data); | |||||
| paramValue->set_tensor_size(shape_size * sizeof(int)); | |||||
| } else if (type == tensorflow::DT_BOOL) { | |||||
| ms_data_ype = kFloat32; | |||||
| auto tensor_data = new (std::nothrow) int[shape_size]; | |||||
| if (tensorProto.bool_val_size() == 1) { | |||||
| int value = tensorProto.bool_val(0); | |||||
| for (int i = 0; i < shape_size; i++) { | |||||
| tensor_data[i] = value; | |||||
| } | |||||
| } | |||||
| paramValue->set_tensor_addr(tensor_data); | |||||
| paramValue->set_tensor_size(shape_size * sizeof(int)); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupport dataType," << node->name(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(ms_data_ype, shape_vector); | |||||
| parameter->set_abstract(abstract_tensor); | |||||
| parameter->set_name("const_" + std::to_string(anf_node_map.size()) + "_parameter"); | |||||
| std::vector<int> param_shape; | |||||
| (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(param_shape), | |||||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||||
| MS_ASSERT(paramValue != nullptr); | |||||
| paramValue->set_tensor_shape(param_shape); | |||||
| paramValue->set_tensor_type(ms_data_ype->type_id()); | |||||
| paramValue->set_format(schema::Format::Format_NHWC); | |||||
| paramValue->set_tensor_size(shape_size * sizeof(int)); | |||||
| parameter->set_default_param(paramValue); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef *op, const CNodePtr &anf_node, int output_size) { | |||||
| if (output_size == 1) { | |||||
| std::vector<int64_t> shape_vector; | |||||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); | |||||
| anf_node_map.insert(std::pair(op->name(), anf_node)); | |||||
| } else { | |||||
| AbstractBasePtrList abstractList; | |||||
| for (int output_idx = 0; output_idx < output_size; output_idx++) { | |||||
| std::vector<int64_t> shape_vector; | |||||
| abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); | |||||
| auto tupleGetItemPrimPtr = GetTupleGetItemPrim(); | |||||
| if (tupleGetItemPrimPtr == nullptr) { | |||||
| MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr); | |||||
| auto getItemValue = NewValueNode(MakeValue<int>(output_idx)); | |||||
| std::vector<AnfNodePtr> inputs{tupleGetItemPrim, anf_node, getItemValue}; | |||||
| CNodePtr getItemCNode = funcGraphPtr->NewCNode(inputs); | |||||
| std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); | |||||
| getItemCNode->set_fullname_with_scope(output_item_name); | |||||
| anf_node_map.insert(std::pair(output_item_name, getItemCNode)); | |||||
| } | |||||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList)); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TFModelParser::ConvertOps() { | |||||
| NoSupportOp::GetInstance()->SetFmkType("TENSORFLOW"); | |||||
| STATUS status = RET_OK; | |||||
| // redirect identity to it's input0 | |||||
| ClipIdentityAndStopGradient(); | |||||
| int op_idx = 0; | |||||
| for (int i = 0; i < tf_graph_def->node_size(); i++) { | |||||
| auto node_def = tf_graph_def->mutable_node(i); | |||||
| tf_node_map[node_def->name()] = node_def; | |||||
| auto tf_op_type = node_def->op(); | |||||
| if (tf_op_type == "Placeholder" || tf_op_type == "Const") { | |||||
| continue; | |||||
| } | |||||
| auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(tf_op_type); | |||||
| if (node_parser == nullptr) { | |||||
| NoSupportOp::GetInstance()->InsertOp(tf_op_type); | |||||
| status = (status == RET_OK ? RET_NOT_FIND_OP : status); | |||||
| MS_LOG(ERROR) << "cannot find node parser:" << tf_op_type; | |||||
| continue; | |||||
| } | |||||
| PrimitiveC *primitiveC = nullptr; | |||||
| if (status == RET_OK) { | |||||
| int output_size = 1; | |||||
| status = node_parser->Parse(node_def, tf_graph_def, primitiveC, &output_size); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "node " << tf_op_type.c_str() << " parser failed"; | |||||
| continue; | |||||
| } | |||||
| std::vector<AnfNodePtr> opInputs = {NewValueNode(std::shared_ptr<PrimitiveC>(primitiveC))}; | |||||
| // parse inputs | |||||
| for (int j = 0; j < node_def->input_size(); j++) { | |||||
| auto input_node = tf_node_map[node_def->input(i)]; | |||||
| // last node output | |||||
| if (anf_node_map.find(input_node->name()) != anf_node_map.end()) { | |||||
| opInputs.emplace_back(anf_node_map[input_node->name()]); | |||||
| continue; | |||||
| } | |||||
| // const tensor | |||||
| if (input_node->op() == "Const") { | |||||
| ParameterPtr parameter; | |||||
| if (ConvertConstTensor(input_node, parameter) != RET_OK) { | |||||
| MS_LOG(ERROR) << "convert const tensor failed," << input_node->name(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| opInputs.emplace_back(parameter); | |||||
| anf_node_map[parameter->fullname_with_scope()] = parameter; | |||||
| continue; | |||||
| } | |||||
| MS_LOG(ERROR) << "node" << node_def->name() << "has inputs neither a node output nor a weight tensor."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto anf_node = funcGraphPtr->NewCNode(opInputs); | |||||
| anf_node->set_fullname_with_scope(tf_op_type + "-" + std::to_string(op_idx++)); | |||||
| // parse outputs | |||||
| status = ConvertOutputTensor(node_def, anf_node, output_size); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return status; | |||||
| } | |||||
| } | |||||
| // redirect identity to it's input0 | |||||
| ClipIdentityAndStopGradient(); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TFModelParser::ConvertGraphInputs() { | |||||
| for (int i = 0; i < tf_graph_def->node_size(); i++) { | |||||
| auto node_def = tf_graph_def->mutable_node(i); | |||||
| tf_node_map[node_def->name()] = node_def; | |||||
| if (node_def->op() == "Placeholder") { | |||||
| auto parameter = funcGraphPtr->add_parameter(); | |||||
| if (ConvertConstTensor(node_def, parameter) != RET_OK) { | |||||
| MS_LOG(ERROR) << "convert const tensor failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| anf_node_map[node_def->name()] = parameter; | |||||
| graph_input_names.emplace_back(node_def->name()); | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TFModelParser::ConvertGraphOutputs() { return RET_OK; } | |||||
| std::string TFModelParser::GetOriginInputName(const tensorflow::NodeDef &node) { | |||||
| if (node.op() != "Identity" && node.op() != "StopGradient") { | |||||
| return node.name(); | |||||
| } | |||||
| auto tmpNode = node; | |||||
| while (tmpNode.op() == "Identity" || tmpNode.op() == "StopGradient") { | |||||
| tmpNode = *tf_node_map[tmpNode.input(0)]; | |||||
| } | |||||
| return tmpNode.name(); | |||||
| } | |||||
| void TFModelParser::ClipIdentityAndStopGradient() { | |||||
| for (auto &pair : tf_node_map) { | |||||
| pair.second = tf_node_map[GetOriginInputName(*pair.second)]; | |||||
| } | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,60 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <unordered_map> | |||||
| #include "securec/include/securec.h" | |||||
| #include "tools/common/tensor_util.h" | |||||
| #include "tools/converter/model_parser.h" | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "proto/node_def.pb.h" | |||||
| #include "proto/graph.pb.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFModelParser { | |||||
| public: | |||||
| TFModelParser() = default; | |||||
| ~TFModelParser() = default; | |||||
| FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType); | |||||
| private: | |||||
| STATUS ConvertConstTensor(const tensorflow::NodeDef *op, ParameterPtr parameter); | |||||
| STATUS ConvertOutputTensor(const tensorflow::NodeDef *op, const CNodePtr &anf_node, int output_size); | |||||
| STATUS ConvertOps(); | |||||
| STATUS ConvertGraphInputs(); | |||||
| STATUS ConvertGraphOutputs(); | |||||
| std::string GetOriginInputName(const tensorflow::NodeDef &node); | |||||
| void ClipIdentityAndStopGradient(); | |||||
| FuncGraphPtr funcGraphPtr; | |||||
| std::unique_ptr<tensorflow::GraphDef> tf_graph_def; | |||||
| std::map<std::string, const tensorflow::NodeDef *> tf_node_map; | |||||
| std::unordered_map<std::string, AnfNodePtr> anf_node_map; | |||||
| std::vector<std::string> graph_input_names, graphOutputNames; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H | |||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H | |||||
| #include <string> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tf/tf_util.h" | |||||
| #include "proto/graph.pb.h" | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFNodeParser { | |||||
| public: | |||||
| TFNodeParser() = default; | |||||
| virtual ~TFNodeParser() = default; | |||||
| virtual STATUS Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model, | |||||
| PrimitiveC *primitiveC, int *output_size) { | |||||
| return RET_OK; | |||||
| } | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * distributed under the License is distributed on an AS | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| #include <map> | |||||
| #include "src/common/log_adapter.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| TFNodeParserRegistry::~TFNodeParserRegistry() { | |||||
| for (const auto &iter : parsers) { | |||||
| delete iter.second; | |||||
| } | |||||
| this->parsers.clear(); | |||||
| } | |||||
| TFNodeParserRegistry *TFNodeParserRegistry::GetInstance() { | |||||
| static TFNodeParserRegistry instance; | |||||
| return &instance; | |||||
| } | |||||
| TFNodeParser *TFNodeParserRegistry::GetNodeParser(const std::string &name) { | |||||
| auto it = parsers.find(name); | |||||
| if (it != parsers.end()) { | |||||
| return it->second; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_REGISTRY_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_REGISTRY_H | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFNodeParserRegistry { | |||||
| public: | |||||
| TFNodeParserRegistry() = default; | |||||
| virtual ~TFNodeParserRegistry(); | |||||
| static TFNodeParserRegistry *GetInstance(); | |||||
| TFNodeParser *GetNodeParser(const std::string &name); | |||||
| std::unordered_map<std::string, TFNodeParser *> parsers; | |||||
| }; | |||||
| class TFNodeRegistrar { | |||||
| public: | |||||
| TFNodeRegistrar(const std::string &name, TFNodeParser *parser) { | |||||
| TFNodeParserRegistry::GetInstance()->parsers[name] = parser; | |||||
| } | |||||
| ~TFNodeRegistrar() = default; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_REGISTRY_H | |||||
| @@ -0,0 +1,55 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_util.h" | |||||
| #include <cstdio> | |||||
| #include <fstream> | |||||
| #include <string> | |||||
| #include "google/protobuf/io/zero_copy_stream_impl.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef *nodeDef, std::string attr_name, | |||||
| tensorflow::AttrValue *attr_value) { | |||||
| const google::protobuf::Map<std::string, tensorflow::AttrValue> &attr = nodeDef->attr(); | |||||
| const google::protobuf::Map<std::string, tensorflow::AttrValue>::const_iterator it = attr.find(attr_name); | |||||
| if (it != attr.end()) { | |||||
| *attr_value = it->second; | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool TensorFlowUtils::TfReadProtoFromBinary(const char *filepath, google::protobuf::Message *message) { | |||||
| std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary); | |||||
| if (!fs.is_open()) { | |||||
| fprintf(stderr, "open failed %s\n", filepath); | |||||
| return false; | |||||
| } | |||||
| google::protobuf::io::IstreamInputStream input(&fs); | |||||
| google::protobuf::io::CodedInputStream codedstr(&input); | |||||
| codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2); | |||||
| bool success = message->ParseFromCodedStream(&codedstr); | |||||
| fs.close(); | |||||
| return success; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_UTIL_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_UTIL_H | |||||
| #include <string> | |||||
| #include "proto/node_def.pb.h" | |||||
| #include "ir/dtype/type_id.h" | |||||
| #include "include/errorcode.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TensorFlowUtils { | |||||
| public: | |||||
| static bool FindAttrValue(const tensorflow::NodeDef *nodeDef, std::string attr_name, | |||||
| tensorflow::AttrValue *attr_value); | |||||
| static bool TfReadProtoFromBinary(const char *filepath, google::protobuf::Message *message); | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_UTIL_H | |||||
| @@ -0,0 +1,66 @@ | |||||
| syntax = "proto3"; | |||||
| package tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "TypesProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| // LINT.IfChange | |||||
| enum DataType { | |||||
| // Not a legal value for DataType. Used to indicate a DataType field | |||||
| // has not been set. | |||||
| DT_INVALID = 0; | |||||
| // Data types that all computation devices are expected to be | |||||
| // capable to support. | |||||
| DT_FLOAT = 1; | |||||
| DT_DOUBLE = 2; | |||||
| DT_INT32 = 3; | |||||
| DT_UINT8 = 4; | |||||
| DT_INT16 = 5; | |||||
| DT_INT8 = 6; | |||||
| DT_STRING = 7; | |||||
| DT_COMPLEX64 = 8; // Single-precision complex | |||||
| DT_INT64 = 9; | |||||
| DT_BOOL = 10; | |||||
| DT_QINT8 = 11; // Quantized int8 | |||||
| DT_QUINT8 = 12; // Quantized uint8 | |||||
| DT_QINT32 = 13; // Quantized int32 | |||||
| DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. | |||||
| DT_QINT16 = 15; // Quantized int16 | |||||
| DT_QUINT16 = 16; // Quantized uint16 | |||||
| DT_UINT16 = 17; | |||||
| DT_COMPLEX128 = 18; // Double-precision complex | |||||
| DT_HALF = 19; | |||||
| DT_RESOURCE = 20; | |||||
| DT_VARIANT = 21; // Arbitrary C++ data types | |||||
| // TODO(josh11b): DT_GENERIC_PROTO = ??; | |||||
| // TODO(jeff,josh11b): DT_UINT64? DT_UINT32? | |||||
| // Do not use! These are only for parameters. Every enum above | |||||
| // should have a corresponding value below (verified by types_test). | |||||
| DT_FLOAT_REF = 101; | |||||
| DT_DOUBLE_REF = 102; | |||||
| DT_INT32_REF = 103; | |||||
| DT_UINT8_REF = 104; | |||||
| DT_INT16_REF = 105; | |||||
| DT_INT8_REF = 106; | |||||
| DT_STRING_REF = 107; | |||||
| DT_COMPLEX64_REF = 108; | |||||
| DT_INT64_REF = 109; | |||||
| DT_BOOL_REF = 110; | |||||
| DT_QINT8_REF = 111; | |||||
| DT_QUINT8_REF = 112; | |||||
| DT_QINT32_REF = 113; | |||||
| DT_BFLOAT16_REF = 114; | |||||
| DT_QINT16_REF = 115; | |||||
| DT_QUINT16_REF = 116; | |||||
| DT_UINT16_REF = 117; | |||||
| DT_COMPLEX128_REF = 118; | |||||
| DT_HALF_REF = 119; | |||||
| DT_RESOURCE_REF = 120; | |||||
| DT_VARIANT_REF = 121; | |||||
| } | |||||
| // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go) | |||||
| @@ -0,0 +1,31 @@ | |||||
| syntax = "proto3"; | |||||
| package tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "VersionsProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| // Version information for a piece of serialized data | |||||
| // | |||||
| // There are different types of versions for each type of data | |||||
| // (GraphDef, etc.), but they all have the same common shape | |||||
| // described here. | |||||
| // | |||||
| // Each consumer has "consumer" and "min_producer" versions (specified | |||||
| // elsewhere). A consumer is allowed to consume this data if | |||||
| // | |||||
| // producer >= min_producer | |||||
| // consumer >= min_consumer | |||||
| // consumer not in bad_consumers | |||||
| // | |||||
| message VersionDef { | |||||
| // The version of the code that produced this data. | |||||
| int32 producer = 1; | |||||
| // Any consumer below this version is not allowed to consume this data. | |||||
| int32 min_consumer = 2; | |||||
| // Specific consumer versions which are disallowed (e.g. due to bugs). | |||||
| repeated int32 bad_consumers = 3; | |||||
| }; | |||||