From: @zhengjun10 Reviewed-by: @hangangqiang,@HilbertDavid Signed-off-by: @hangangqiangtags/v1.1.0
| @@ -306,6 +306,7 @@ if (ENABLE_CONVERTER) | |||
| tflite_parser_mid | |||
| caffe_parser_mid | |||
| onnx_parser_mid | |||
| tf_parser_mid | |||
| graph_pass_mid | |||
| fusion_mid | |||
| quantizer_mid | |||
| @@ -61,6 +61,7 @@ add_subdirectory(../anf_exporter anf_exporter) | |||
| add_subdirectory(parser/caffe) | |||
| add_subdirectory(parser/tflite) | |||
| add_subdirectory(parser/onnx) | |||
| add_subdirectory(parser/tf) | |||
| add_subdirectory(legacy_optimizer) | |||
| add_subdirectory(quantizer) | |||
| add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../core mindspore_core) | |||
| @@ -111,6 +112,7 @@ endif () | |||
| file(GLOB PROTO_FILE "" | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/*.proto | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto) | |||
| ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE}) | |||
| add_library(proto_mid OBJECT ${PROTO_SRCS}) | |||
| @@ -138,6 +140,7 @@ add_dependencies(converter_lite fbs_inner_src) | |||
| target_link_libraries(converter_lite PRIVATE | |||
| tflite_parser_mid | |||
| tf_parser_mid | |||
| caffe_parser_mid | |||
| onnx_parser_mid | |||
| anf_importer_mid | |||
| @@ -0,0 +1,7 @@ | |||
| file(GLOB_RECURSE TF_SRC_LIST ${CMAKE_CURRENT_SOURCE_DIR}/*.cc) | |||
| set_property(SOURCE ${TF_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | |||
| add_library(tf_parser_mid OBJECT ${TF_SRC_LIST}) | |||
| add_dependencies(tf_parser_mid proto_mid) | |||
| @@ -0,0 +1,62 @@ | |||
| syntax = "proto3"; | |||
| package tensorflow; | |||
| option cc_enable_arenas = true; | |||
| option java_outer_classname = "AttrValueProtos"; | |||
| option java_multiple_files = true; | |||
| option java_package = "org.tensorflow.framework"; | |||
| import "tensor.proto"; | |||
| import "tensor_shape.proto"; | |||
| import "types.proto"; | |||
| // Protocol buffer representing the value for an attr used to configure an Op. | |||
| // Comment indicates the corresponding attr type. Only the field matching the | |||
| // attr type may be filled. | |||
| message AttrValue { | |||
| // LINT.IfChange | |||
| message ListValue { | |||
| repeated bytes s = 2; // "list(string)" | |||
| repeated int64 i = 3 [packed = true]; // "list(int)" | |||
| repeated float f = 4 [packed = true]; // "list(float)" | |||
| repeated bool b = 5 [packed = true]; // "list(bool)" | |||
| repeated DataType type = 6 [packed = true]; // "list(type)" | |||
| repeated TensorShapeProto shape = 7; // "list(shape)" | |||
| repeated TensorProto tensor = 8; // "list(tensor)" | |||
| repeated NameAttrList func = 9; // "list(attr)" | |||
| } | |||
| // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) | |||
| oneof value { | |||
| bytes s = 2; // "string" | |||
| int64 i = 3; // "int" | |||
| float f = 4; // "float" | |||
| bool b = 5; // "bool" | |||
| DataType type = 6; // "type" | |||
| TensorShapeProto shape = 7; // "shape" | |||
| TensorProto tensor = 8; // "tensor" | |||
| ListValue list = 1; // any "list(...)" | |||
| // "func" represents a function. func.name is a function's name or | |||
| // a primitive op's name. func.attr.first is the name of an attr | |||
| // defined for that function. func.attr.second is the value for | |||
| // that attr in the instantiation. | |||
| NameAttrList func = 10; | |||
| // This is a placeholder only used in anf_node_map defined inside a | |||
| // function. It indicates the attr value will be supplied when | |||
| // the function is instantiated. For example, let us suppose a | |||
| // node "N" in function "FN". "N" has an attr "A" with value | |||
| // placeholder = "foo". When FN is instantiated with attr "foo" | |||
| // set to "bar", the instantiated node N's attr A will have been | |||
| // given the value "bar". | |||
| string placeholder = 9; | |||
| } | |||
| } | |||
| // A list of attr names and their values. The whole list is attached | |||
| // with a string name. E.g., MatMul[T=float]. | |||
| message NameAttrList { | |||
| string name = 1; | |||
| map<string, AttrValue> attr = 2; | |||
| } | |||
| @@ -0,0 +1,101 @@ | |||
| syntax = "proto3"; | |||
| package tensorflow; | |||
| option cc_enable_arenas = true; | |||
| option java_outer_classname = "FunctionProtos"; | |||
| option java_multiple_files = true; | |||
| option java_package = "org.tensorflow.framework"; | |||
| import "attr_value.proto"; | |||
| import "node_def.proto"; | |||
| import "op_def.proto"; | |||
| // A library is a set of named functions. | |||
| message FunctionDefLibrary { | |||
| repeated FunctionDef function = 1; | |||
| repeated GradientDef gradient = 2; | |||
| } | |||
| // A function can be instantiated when the runtime can bind every attr | |||
| // with a value. When a GraphDef has a call to a function, it must | |||
| // have binding for every attr defined in the signature. | |||
| // | |||
| // TODO(zhifengc): | |||
| // * device spec, etc. | |||
| message FunctionDef { | |||
| // The definition of the function's name, arguments, return values, | |||
| // attrs etc. | |||
| OpDef signature = 1; | |||
| // Attributes specific to this function definition. | |||
| map<string, AttrValue> attr = 5; | |||
| // NOTE: field id 2 deleted on Jan 11, 2016, GraphDef version 21. | |||
| // In both of the following fields, there is the need to specify an | |||
| // output that is used as either the input to another node (in | |||
| // `node_def`) or as a return value of the function (in `ret`). | |||
| // Unlike the NodeDefs in GraphDef, we need to be able to specify a | |||
| // list in some cases (instead of just single outputs). Also, we | |||
| // need to be able to deal with lists of unknown length (so the | |||
| // output index may not be known at function definition time). So | |||
| // we use the following format instead: | |||
| // * "fun_in" where "fun_in" is the name of a function input arg in | |||
| // the `signature` field above. This represents that input, whether | |||
| // it is a single tensor or a list. | |||
| // * "fun_in:0" gives the first element of a function input arg (a | |||
| // non-list input is considered a list of length 1 for these | |||
| // purposes). | |||
| // * "node:out" where "node" is the name of a node in `node_def` and | |||
| // "out" is the name one of its op's output arguments (the name | |||
| // comes from the OpDef of the node's op). This represents that | |||
| // node's output, whether it is a single tensor or a list. | |||
| // Note: We enforce that an op's output arguments are never | |||
| // renamed in the backwards-compatibility test. | |||
| // * "node:out:0" gives the first element of a node output arg (a | |||
| // non-list output is considered a list of length 1 for these | |||
| // purposes). | |||
| // | |||
| // NOT CURRENTLY SUPPORTED (but may be in the future): | |||
| // * "node:out:-1" gives last element in a node output list | |||
| // * "node:out:1:" gives a list with all but the first element in a | |||
| // node output list | |||
| // * "node:out::-1" gives a list with all but the last element in a | |||
| // node output list | |||
| // The body of the function. Unlike the NodeDefs in a GraphDef, attrs | |||
| // may have values of type `placeholder` and the `input` field uses | |||
| // the "output" format above. | |||
| // By convention, "op" in node_def is resolved by consulting with a | |||
| // user-defined library first. If not resolved, "func" is assumed to | |||
| // be a builtin op. | |||
| repeated NodeDef node_def = 3; | |||
| // A mapping from the output arg names from `signature` to the | |||
| // outputs from `node_def` that should be returned by the function. | |||
| map<string, string> ret = 4; | |||
| } | |||
| // GradientDef defines the gradient function of a function defined in | |||
| // a function library. | |||
| // | |||
| // A gradient function g (specified by gradient_func) for a function f | |||
| // (specified by function_name) must follow the following: | |||
| // | |||
| // The function 'f' must be a numerical function which takes N inputs | |||
| // and produces M outputs. Its gradient function 'g', which is a | |||
| // function taking N + M inputs and produces N outputs. | |||
| // | |||
| // I.e. if we have | |||
| // (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), | |||
| // then, g is | |||
| // (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, | |||
| // dL/dy1, dL/dy2, ..., dL/dy_M), | |||
| // where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the | |||
| // loss function). dL/dx_i is the partial derivative of L with respect | |||
| // to x_i. | |||
| message GradientDef { | |||
| string function_name = 1; // The function name. | |||
| string gradient_func = 2; // The gradient function's name. | |||
| } | |||
| @@ -0,0 +1,56 @@ | |||
| syntax = "proto3"; | |||
| package tensorflow; | |||
| option cc_enable_arenas = true; | |||
| option java_outer_classname = "GraphProtos"; | |||
| option java_multiple_files = true; | |||
| option java_package = "org.tensorflow.framework"; | |||
| import "node_def.proto"; | |||
| import "function.proto"; | |||
| import "versions.proto"; | |||
| // Represents the graph of operations | |||
| message GraphDef { | |||
| repeated NodeDef node = 1; | |||
| // Compatibility versions of the graph. See core/public/version.h for version | |||
| // history. The GraphDef version is distinct from the TensorFlow version, and | |||
| // each release of TensorFlow will support a range of GraphDef versions. | |||
| VersionDef versions = 4; | |||
| // Deprecated single version field; use versions above instead. Since all | |||
| // GraphDef changes before "versions" was introduced were forward | |||
| // compatible, this field is entirely ignored. | |||
| int32 version = 3 [deprecated = true]; | |||
| // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. | |||
| // | |||
| // "library" provides user-defined functions. | |||
| // | |||
| // Naming: | |||
| // * library.function.name are in a flat namespace. | |||
| // NOTE: We may need to change it to be hierarchical to support | |||
| // different orgs. E.g., | |||
| // { "/google/nn", { ... }}, | |||
| // { "/google/vision", { ... }} | |||
| // { "/org_foo/module_bar", { ... }} | |||
| // map<string, FunctionDefLib> named_lib; | |||
| // * If node[i].op is the name of one function in "library", | |||
| // node[i] is deemed as a function call. Otherwise, node[i].op | |||
| // must be a primitive operation supported by the runtime. | |||
| // | |||
| // | |||
| // Function call semantics: | |||
| // | |||
| // * The callee may start execution as soon as some of its inputs | |||
| // are ready. The caller may want to use Tuple() mechanism to | |||
| // ensure all inputs are ready in the same time. | |||
| // | |||
| // * The consumer of return values may start executing as soon as | |||
| // the return values the consumer depends on are ready. The | |||
| // consumer may want to use Tuple() mechanism to ensure the | |||
| // consumer does not start until all return values of the callee | |||
| // function are ready. | |||
| FunctionDefLibrary library = 2; | |||
| }; | |||
| @@ -0,0 +1,63 @@ | |||
| syntax = "proto3"; | |||
| package tensorflow; | |||
| option cc_enable_arenas = true; | |||
| option java_outer_classname = "NodeProto"; | |||
| option java_multiple_files = true; | |||
| option java_package = "org.tensorflow.framework"; | |||
| import "attr_value.proto"; | |||
| message NodeDef { | |||
| // The name given to this operator. Used for naming inputs, | |||
| // logging, visualization, etc. Unique within a single GraphDef. | |||
| // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". | |||
| string name = 1; | |||
| // The operation name. There may be custom parameters in attrs. | |||
| // Op names starting with an underscore are reserved for internal use. | |||
| string op = 2; | |||
| // Each input is "node:src_output" with "node" being a string name and | |||
| // "src_output" indicating which output tensor to use from "node". If | |||
| // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs | |||
| // may optionally be followed by control inputs that have the format | |||
| // "^node". | |||
| repeated string input = 3; | |||
| // A (possibly partial) specification for the device on which this | |||
| // node should be placed. | |||
| // The expected syntax for this string is as follows: | |||
| // | |||
| // DEVICE_SPEC ::= PARTIAL_SPEC | |||
| // | |||
| // PARTIAL_SPEC ::= ("/" CONSTRAINT) * | |||
| // CONSTRAINT ::= ("job:" JOB_NAME) | |||
| // | ("replica:" [1-9][0-9]*) | |||
| // | ("task:" [1-9][0-9]*) | |||
| // | ( ("gpu" | "cpu") ":" ([1-9][0-9]* | "*") ) | |||
| // | |||
| // Valid values for this string include: | |||
| // * "/job:worker/replica:0/task:1/gpu:3" (full specification) | |||
| // * "/job:worker/gpu:3" (partial specification) | |||
| // * "" (no specification) | |||
| // | |||
| // If the constraints do not resolve to a single device (or if this | |||
| // field is empty or not present), the runtime will attempt to | |||
| // choose a device automatically. | |||
| string device = 4; | |||
| // Operation-specific graph-construction-time configuration. | |||
| // Note that this should include all attrs defined in the | |||
| // corresponding OpDef, including those with a value matching | |||
| // the default -- this allows the default to change and makes | |||
| // NodeDefs easier to interpret on their own. However, if | |||
| // an attr with a default is not specified in this list, the | |||
| // default will be used. | |||
| // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and | |||
| // one of the names from the corresponding OpDef's attr field). | |||
| // The values must have a type matching the corresponding OpDef | |||
| // attr's type field. | |||
| // TODO(josh11b): Add some examples here showing best practices. | |||
| map<string, AttrValue> attr = 5; | |||
| }; | |||
| @@ -0,0 +1,157 @@ | |||
| syntax = "proto3"; | |||
| package tensorflow; | |||
| option cc_enable_arenas = true; | |||
| option java_outer_classname = "OpDefProtos"; | |||
| option java_multiple_files = true; | |||
| option java_package = "org.tensorflow.framework"; | |||
| import "attr_value.proto"; | |||
| import "types.proto"; | |||
| // Defines an operation. A NodeDef in a GraphDef specifies an Op by | |||
| // using the "op" field which should match the name of a OpDef. | |||
| message OpDef { | |||
| // Op names starting with an underscore are reserved for internal use. | |||
| // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". | |||
| string name = 1; | |||
| // For describing inputs and outputs. | |||
| message ArgDef { | |||
| // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". | |||
| string name = 1; | |||
| // Human readable description. | |||
| string description = 2; | |||
| // Describes the type of one or more tensors that are accepted/produced | |||
| // by this input/output arg. The only legal combinations are: | |||
| // * For a single tensor: either the "type" field is set or the | |||
| // "type_attr" field is set to the name of an attr with type "type". | |||
| // * For a sequence of tensors with the same type: the "number_attr" | |||
| // field will be set to the name of an attr with type "int", and | |||
| // either the "type" or "type_attr" field will be set as for | |||
| // single tensors. | |||
| // * For a sequence of tensors, the "type_list_attr" field will be set | |||
| // to the name of an attr with type "list(type)". | |||
| DataType type = 3; | |||
| string type_attr = 4; // if specified, attr must have type "type" | |||
| string number_attr = 5; // if specified, attr must have type "int" | |||
| // If specified, attr must have type "list(type)", and none of | |||
| // type, type_attr, and number_attr may be specified. | |||
| string type_list_attr = 6; | |||
| // For inputs: if true, the inputs are required to be refs. | |||
| // By default, inputs can be either refs or non-refs. | |||
| // For outputs: if true, outputs are refs, otherwise they are not. | |||
| bool is_ref = 16; | |||
| }; | |||
| // Description of the input(s). | |||
| repeated ArgDef input_arg = 2; | |||
| // Description of the output(s). | |||
| repeated ArgDef output_arg = 3; | |||
| // Description of the graph-construction-time configuration of this | |||
| // Op. That is to say, this describes the attr fields that will | |||
| // be specified in the NodeDef. | |||
| message AttrDef { | |||
| // A descriptive name for the argument. May be used, e.g. by the | |||
| // Python client, as a keyword argument name, and so should match | |||
| // the regexp "[a-z][a-z0-9_]+". | |||
| string name = 1; | |||
| // One of the type names from attr_value.proto ("string", "list(string)", | |||
| // "int", etc.). | |||
| string type = 2; | |||
| // A reasonable default for this attribute if the user does not supply | |||
| // a value. If not specified, the user must supply a value. | |||
| AttrValue default_value = 3; | |||
| // Human-readable description. | |||
| string description = 4; | |||
| // TODO(josh11b): bool is_optional? | |||
| // --- Constraints --- | |||
| // These constraints are only in effect if specified. Default is no | |||
| // constraints. | |||
| // For type == "int", this is a minimum value. For "list(___)" | |||
| // types, this is the minimum length. | |||
| bool has_minimum = 5; | |||
| int64 minimum = 6; | |||
| // The set of allowed values. Has type that is the "list" version | |||
| // of the "type" field above (uses the "list" field of AttrValue). | |||
| // If type == "type" or "list(type)" above, then the "type" field | |||
| // of "allowed_values.list" has the set of allowed DataTypes. | |||
| // If type == "string" or "list(string)", then the "s" field of | |||
| // "allowed_values.list" has the set of allowed strings. | |||
| AttrValue allowed_values = 7; | |||
| } | |||
| repeated AttrDef attr = 4; | |||
| // Optional deprecation based on GraphDef versions. | |||
| OpDeprecation deprecation = 8; | |||
| // One-line human-readable description of what the Op does. | |||
| string summary = 5; | |||
| // Additional, longer human-readable description of what the Op does. | |||
| string description = 6; | |||
| // ------------------------------------------------------------------------- | |||
| // Which optimizations this operation can participate in. | |||
| // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) | |||
| bool is_commutative = 18; | |||
| // If is_aggregate is true, then this operation accepts N >= 2 | |||
| // inputs and produces 1 output all of the same type. Should be | |||
| // associative and commutative, and produce output with the same | |||
| // shape as the input. The optimizer may replace an aggregate op | |||
| // taking input from multiple devices with a tree of aggregate ops | |||
| // that aggregate locally within each device (and possibly within | |||
| // groups of nearby devices) before communicating. | |||
| // TODO(josh11b): Implement that optimization. | |||
| bool is_aggregate = 16; // for things like add | |||
| // Other optimizations go here, like | |||
| // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. | |||
| // ------------------------------------------------------------------------- | |||
| // Optimization constraints. | |||
| // By default Ops may be moved between devices. Stateful ops should | |||
| // either not be moved, or should only be moved if that state can also | |||
| // be moved (e.g. via some sort of save / restore). | |||
| // Stateful ops are guaranteed to never be optimized away by Common | |||
| // Subexpression Elimination (CSE). | |||
| bool is_stateful = 17; // for things like variables, queue | |||
| // ------------------------------------------------------------------------- | |||
| // Non-standard options. | |||
| // By default, all inputs to an Op must be initialized Tensors. Ops | |||
| // that may initialize tensors for the first time should set this | |||
| // field to true, to allow the Op to take an uninitialized Tensor as | |||
| // input. | |||
| bool allows_uninitialized_input = 19; // for Assign, etc. | |||
| }; | |||
| // Information about version-dependent deprecation of an op | |||
| message OpDeprecation { | |||
| // First GraphDef version at which the op is disallowed. | |||
| int32 version = 1; | |||
| // Explanation of why it was deprecated and what to use instead. | |||
| string explanation = 2; | |||
| }; | |||
| // A collection of OpDefs | |||
| message OpList { | |||
| repeated OpDef op = 1; | |||
| }; | |||
| @@ -0,0 +1,29 @@ | |||
| syntax = "proto3"; | |||
| package tensorflow; | |||
| option cc_enable_arenas = true; | |||
| option java_outer_classname = "ResourceHandle"; | |||
| option java_multiple_files = true; | |||
| option java_package = "org.tensorflow.framework"; | |||
| // Protocol buffer representing a handle to a tensorflow resource. Handles are | |||
| // not valid across executions, but can be serialized back and forth from within | |||
| // a single run. | |||
| message ResourceHandleProto { | |||
| // Unique name for the device containing the resource. | |||
| string device = 1; | |||
| // Container in which this resource is placed. | |||
| string container = 2; | |||
| // Unique name of this resource. | |||
| string name = 3; | |||
| // Hash code for the type of the resource. Is only valid in the same device | |||
| // and in the same execution. | |||
| uint64 hash_code = 4; | |||
| // For debug-only, the name of the type pointed to by this handle, if | |||
| // available. | |||
| string maybe_type_name = 5; | |||
| }; | |||
| @@ -0,0 +1,88 @@ | |||
| syntax = "proto3"; | |||
| package tensorflow; | |||
| option cc_enable_arenas = true; | |||
| option java_outer_classname = "TensorProtos"; | |||
| option java_multiple_files = true; | |||
| option java_package = "org.tensorflow.framework"; | |||
| import "resource_handle.proto"; | |||
| import "tensor_shape.proto"; | |||
| import "types.proto"; | |||
| // Protocol buffer representing a tensor. | |||
| message TensorProto { | |||
| DataType dtype = 1; | |||
| // Shape of the tensor. TODO(touts): sort out the 0-rank issues. | |||
| TensorShapeProto tensor_shape = 2; | |||
| // Only one of the representations below is set, one of "tensor_contents" and | |||
| // the "xxx_val" attributes. We are not using oneof because as oneofs cannot | |||
| // contain repeated fields it would require another extra set of messages. | |||
| // Version number. | |||
| // | |||
| // In version 0, if the "repeated xxx" representations contain only one | |||
| // element, that element is repeated to fill the shape. This makes it easy | |||
| // to represent a constant Tensor with a single value. | |||
| int32 version_number = 3; | |||
| // Serialized raw tensor content from either Tensor::AsProtoTensorContent or | |||
| // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation | |||
| // can be used for all tensor types. The purpose of this representation is to | |||
| // reduce serialization overhead during RPC call by avoiding serialization of | |||
| // many repeated small items. | |||
| bytes tensor_content = 4; | |||
| // Type specific representations that make it easy to create tensor protos in | |||
| // all languages. Only the representation corresponding to "dtype" can | |||
| // be set. The values hold the flattened representation of the tensor in | |||
| // row major order. | |||
| // DT_HALF. Note that since protobuf has no int16 type, we'll have some | |||
| // pointless zero padding for each value here. | |||
| repeated int32 half_val = 13 [packed = true]; | |||
| // DT_FLOAT. | |||
| repeated float float_val = 5 [packed = true]; | |||
| // DT_DOUBLE. | |||
| repeated double double_val = 6 [packed = true]; | |||
| // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. | |||
| repeated int32 int_val = 7 [packed = true]; | |||
| // DT_STRING | |||
| repeated bytes string_val = 8; | |||
| // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real | |||
| // and imaginary parts of i-th single precision complex. | |||
| repeated float scomplex_val = 9 [packed = true]; | |||
| // DT_INT64 | |||
| repeated int64 int64_val = 10 [packed = true]; | |||
| // DT_BOOL | |||
| repeated bool bool_val = 11 [packed = true]; | |||
| // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real | |||
| // and imaginary parts of i-th double precision complex. | |||
| repeated double dcomplex_val = 12 [packed = true]; | |||
| // DT_RESOURCE | |||
| repeated ResourceHandleProto resource_handle_val = 14; | |||
| // DT_VARIANT | |||
| repeated VariantTensorDataProto variant_val = 15; | |||
| }; | |||
| // Protocol buffer representing the serialization format of DT_VARIANT tensors. | |||
| message VariantTensorDataProto { | |||
| // Name of the type of objects being serialized. | |||
| string type_name = 1; | |||
| // Portions of the object that are not Tensors. | |||
| bytes metadata = 2; | |||
| // Tensors contained within objects being serialized. | |||
| repeated TensorProto tensors = 3; | |||
| } | |||
| @@ -0,0 +1,45 @@ | |||
| // Protocol buffer representing the shape of tensors. | |||
| syntax = "proto3"; | |||
| option cc_enable_arenas = true; | |||
| option java_outer_classname = "TensorShapeProtos"; | |||
| option java_multiple_files = true; | |||
| option java_package = "org.tensorflow.framework"; | |||
| package tensorflow; | |||
| // Dimensions of a tensor. | |||
| message TensorShapeProto { | |||
| // One dimension of the tensor. | |||
| message Dim { | |||
| // Size of the tensor in that dimension. | |||
| // This value must be >= -1, but values of -1 are reserved for "unknown" | |||
| // shapes (values of -1 mean "unknown" dimension). Certain wrappers | |||
| // that work with TensorShapeProto may fail at runtime when deserializing | |||
| // a TensorShapeProto containing a dim value of -1. | |||
| int64 size = 1; | |||
| // Optional name of the tensor dimension. | |||
| string name = 2; | |||
| }; | |||
| // Dimensions of the tensor, such as {"input", 30}, {"output", 40} | |||
| // for a 30 x 40 2D tensor. If an entry has size -1, this | |||
| // corresponds to a dimension of unknown size. The names are | |||
| // optional. | |||
| // | |||
| // The order of entries in "dim" matters: It indicates the layout of the | |||
| // values in the tensor in-memory representation. | |||
| // | |||
| // The first entry in "dim" is the outermost dimension used to layout the | |||
| // values, the last entry is the innermost dimension. This matches the | |||
| // in-memory layout of RowMajor Eigen tensors. | |||
| // | |||
| // If "dim.size()" > 0, "unknown_rank" must be false. | |||
| repeated Dim dim = 2; | |||
| // If true, the number of dimensions in the shape is unknown. | |||
| // | |||
| // If true, "dim.size()" must be 0. | |||
| bool unknown_rank = 3; | |||
| }; | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_add_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFAddParser::Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model, | |||
| PrimitiveC *primitiveC, int *output_size) { | |||
| auto attr = std::make_unique<schema::PrimitiveT>(); | |||
| attr->value.type = schema::PrimitiveType_Add; | |||
| primitiveC = PrimitiveC::Create(attr.release()); | |||
| MS_LOG(INFO) << "primitive name" << primitiveC->type_name(); | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfAddParser("Add", new TFAddParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,35 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H | |||
| #include <memory> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFAddParser : public TFNodeParser { | |||
| public: | |||
| TFAddParser() = default; | |||
| ~TFAddParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model, | |||
| PrimitiveC *primitiveC, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H | |||
| @@ -0,0 +1,286 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * distributed under the License is distributed on an AS | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_model_parser.h" | |||
| #include <map> | |||
| #include <algorithm> | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/converter/parser/tf/tf_util.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "src/param_value_lite.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType) { | |||
| auto status = ValidateFileStr(modelFile, ".prototxt"); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| if (!TensorFlowUtils::TfReadProtoFromBinary(modelFile.c_str(), tf_graph_def.get())) { | |||
| MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| funcGraphPtr = std::make_shared<FuncGraph>(); | |||
| status = ConvertGraphInputs(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert graph inputs failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| status = ConvertOps(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert ops failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| status = ConvertGraphOutputs(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert graph outputs failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| return funcGraphPtr; | |||
| } | |||
| STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef *node, ParameterPtr parameter) { | |||
| tensorflow::AttrValue attr_value; | |||
| if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) { | |||
| tensorflow::AttrValue data_type; | |||
| tensorflow::DataType type = tensorflow::DT_FLOAT; | |||
| // datatype | |||
| if (TensorFlowUtils::FindAttrValue(node, "dtype", &data_type)) { | |||
| type = data_type.type(); | |||
| } | |||
| const tensorflow::TensorProto &tensorProto = attr_value.tensor(); | |||
| const tensorflow::TensorShapeProto &tensorShape = tensorProto.tensor_shape(); | |||
| parameter = funcGraphPtr->add_parameter(); | |||
| std::vector<int64_t> shape_vector; | |||
| int shape_size = 1; | |||
| shape_vector.resize(tensorShape.dim_size()); | |||
| for (int i = 0; i < tensorShape.dim_size(); i++) { | |||
| shape_vector[i] = tensorShape.dim(i).size(); | |||
| shape_size *= shape_vector[i]; | |||
| } | |||
| // convert const to paramter | |||
| TypePtr ms_data_ype; | |||
| auto paramValue = std::make_shared<ParamValueLite>(); | |||
| if (type == tensorflow::DT_FLOAT) { | |||
| ms_data_ype = kFloat32; | |||
| auto tensor_data = new (std::nothrow) float[shape_size]; | |||
| if (tensorProto.float_val_size() == 1) { | |||
| float value = tensorProto.float_val(0); | |||
| for (int i = 0; i < shape_size; i++) { | |||
| tensor_data[i] = value; | |||
| } | |||
| } | |||
| if (tensorProto.tensor_content().size() == shape_size * sizeof(float)) { | |||
| const auto addr = reinterpret_cast<const float *>(tensorProto.tensor_content().data()); | |||
| auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(float), addr, shape_size * sizeof(float)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| paramValue->set_tensor_addr(tensor_data); | |||
| paramValue->set_tensor_size(shape_size * sizeof(float)); | |||
| } else if (type == tensorflow::DT_INT32) { | |||
| ms_data_ype = kInt32; | |||
| auto tensor_data = new (std::nothrow) int[shape_size]; | |||
| if (tensorProto.int_val_size() == 1) { | |||
| int value = tensorProto.int_val(0); | |||
| for (int i = 0; i < shape_size; i++) { | |||
| tensor_data[i] = value; | |||
| } | |||
| } | |||
| if (tensorProto.tensor_content().size() == shape_size * sizeof(int32_t)) { | |||
| const auto addr = reinterpret_cast<const int32_t *>(tensorProto.tensor_content().data()); | |||
| auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(int32_t), addr, shape_size * sizeof(int32_t)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| paramValue->set_tensor_addr(tensor_data); | |||
| paramValue->set_tensor_size(shape_size * sizeof(int)); | |||
| } else if (type == tensorflow::DT_BOOL) { | |||
| ms_data_ype = kFloat32; | |||
| auto tensor_data = new (std::nothrow) int[shape_size]; | |||
| if (tensorProto.bool_val_size() == 1) { | |||
| int value = tensorProto.bool_val(0); | |||
| for (int i = 0; i < shape_size; i++) { | |||
| tensor_data[i] = value; | |||
| } | |||
| } | |||
| paramValue->set_tensor_addr(tensor_data); | |||
| paramValue->set_tensor_size(shape_size * sizeof(int)); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport dataType," << node->name(); | |||
| return RET_ERROR; | |||
| } | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(ms_data_ype, shape_vector); | |||
| parameter->set_abstract(abstract_tensor); | |||
| parameter->set_name("const_" + std::to_string(anf_node_map.size()) + "_parameter"); | |||
| std::vector<int> param_shape; | |||
| (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(param_shape), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| MS_ASSERT(paramValue != nullptr); | |||
| paramValue->set_tensor_shape(param_shape); | |||
| paramValue->set_tensor_type(ms_data_ype->type_id()); | |||
| paramValue->set_format(schema::Format::Format_NHWC); | |||
| paramValue->set_tensor_size(shape_size * sizeof(int)); | |||
| parameter->set_default_param(paramValue); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef *op, const CNodePtr &anf_node, int output_size) { | |||
| if (output_size == 1) { | |||
| std::vector<int64_t> shape_vector; | |||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); | |||
| anf_node_map.insert(std::pair(op->name(), anf_node)); | |||
| } else { | |||
| AbstractBasePtrList abstractList; | |||
| for (int output_idx = 0; output_idx < output_size; output_idx++) { | |||
| std::vector<int64_t> shape_vector; | |||
| abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); | |||
| auto tupleGetItemPrimPtr = GetTupleGetItemPrim(); | |||
| if (tupleGetItemPrimPtr == nullptr) { | |||
| MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr); | |||
| auto getItemValue = NewValueNode(MakeValue<int>(output_idx)); | |||
| std::vector<AnfNodePtr> inputs{tupleGetItemPrim, anf_node, getItemValue}; | |||
| CNodePtr getItemCNode = funcGraphPtr->NewCNode(inputs); | |||
| std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); | |||
| getItemCNode->set_fullname_with_scope(output_item_name); | |||
| anf_node_map.insert(std::pair(output_item_name, getItemCNode)); | |||
| } | |||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::ConvertOps() { | |||
| NoSupportOp::GetInstance()->SetFmkType("TENSORFLOW"); | |||
| STATUS status = RET_OK; | |||
| // redirect identity to it's input0 | |||
| ClipIdentityAndStopGradient(); | |||
| int op_idx = 0; | |||
| for (int i = 0; i < tf_graph_def->node_size(); i++) { | |||
| auto node_def = tf_graph_def->mutable_node(i); | |||
| tf_node_map[node_def->name()] = node_def; | |||
| auto tf_op_type = node_def->op(); | |||
| if (tf_op_type == "Placeholder" || tf_op_type == "Const") { | |||
| continue; | |||
| } | |||
| auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(tf_op_type); | |||
| if (node_parser == nullptr) { | |||
| NoSupportOp::GetInstance()->InsertOp(tf_op_type); | |||
| status = (status == RET_OK ? RET_NOT_FIND_OP : status); | |||
| MS_LOG(ERROR) << "cannot find node parser:" << tf_op_type; | |||
| continue; | |||
| } | |||
| PrimitiveC *primitiveC = nullptr; | |||
| if (status == RET_OK) { | |||
| int output_size = 1; | |||
| status = node_parser->Parse(node_def, tf_graph_def, primitiveC, &output_size); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "node " << tf_op_type.c_str() << " parser failed"; | |||
| continue; | |||
| } | |||
| std::vector<AnfNodePtr> opInputs = {NewValueNode(std::shared_ptr<PrimitiveC>(primitiveC))}; | |||
| // parse inputs | |||
| for (int j = 0; j < node_def->input_size(); j++) { | |||
| auto input_node = tf_node_map[node_def->input(i)]; | |||
| // last node output | |||
| if (anf_node_map.find(input_node->name()) != anf_node_map.end()) { | |||
| opInputs.emplace_back(anf_node_map[input_node->name()]); | |||
| continue; | |||
| } | |||
| // const tensor | |||
| if (input_node->op() == "Const") { | |||
| ParameterPtr parameter; | |||
| if (ConvertConstTensor(input_node, parameter) != RET_OK) { | |||
| MS_LOG(ERROR) << "convert const tensor failed," << input_node->name(); | |||
| return RET_ERROR; | |||
| } | |||
| opInputs.emplace_back(parameter); | |||
| anf_node_map[parameter->fullname_with_scope()] = parameter; | |||
| continue; | |||
| } | |||
| MS_LOG(ERROR) << "node" << node_def->name() << "has inputs neither a node output nor a weight tensor."; | |||
| return RET_ERROR; | |||
| } | |||
| auto anf_node = funcGraphPtr->NewCNode(opInputs); | |||
| anf_node->set_fullname_with_scope(tf_op_type + "-" + std::to_string(op_idx++)); | |||
| // parse outputs | |||
| status = ConvertOutputTensor(node_def, anf_node, output_size); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return status; | |||
| } | |||
| } | |||
| // redirect identity to it's input0 | |||
| ClipIdentityAndStopGradient(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::ConvertGraphInputs() { | |||
| for (int i = 0; i < tf_graph_def->node_size(); i++) { | |||
| auto node_def = tf_graph_def->mutable_node(i); | |||
| tf_node_map[node_def->name()] = node_def; | |||
| if (node_def->op() == "Placeholder") { | |||
| auto parameter = funcGraphPtr->add_parameter(); | |||
| if (ConvertConstTensor(node_def, parameter) != RET_OK) { | |||
| MS_LOG(ERROR) << "convert const tensor failed"; | |||
| return RET_ERROR; | |||
| } | |||
| anf_node_map[node_def->name()] = parameter; | |||
| graph_input_names.emplace_back(node_def->name()); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::ConvertGraphOutputs() { return RET_OK; } | |||
| std::string TFModelParser::GetOriginInputName(const tensorflow::NodeDef &node) { | |||
| if (node.op() != "Identity" && node.op() != "StopGradient") { | |||
| return node.name(); | |||
| } | |||
| auto tmpNode = node; | |||
| while (tmpNode.op() == "Identity" || tmpNode.op() == "StopGradient") { | |||
| tmpNode = *tf_node_map[tmpNode.input(0)]; | |||
| } | |||
| return tmpNode.name(); | |||
| } | |||
| void TFModelParser::ClipIdentityAndStopGradient() { | |||
| for (auto &pair : tf_node_map) { | |||
| pair.second = tf_node_map[GetOriginInputName(*pair.second)]; | |||
| } | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,60 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <unordered_map> | |||
| #include "securec/include/securec.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/converter/model_parser.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "proto/node_def.pb.h" | |||
| #include "proto/graph.pb.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFModelParser { | |||
| public: | |||
| TFModelParser() = default; | |||
| ~TFModelParser() = default; | |||
| FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType); | |||
| private: | |||
| STATUS ConvertConstTensor(const tensorflow::NodeDef *op, ParameterPtr parameter); | |||
| STATUS ConvertOutputTensor(const tensorflow::NodeDef *op, const CNodePtr &anf_node, int output_size); | |||
| STATUS ConvertOps(); | |||
| STATUS ConvertGraphInputs(); | |||
| STATUS ConvertGraphOutputs(); | |||
| std::string GetOriginInputName(const tensorflow::NodeDef &node); | |||
| void ClipIdentityAndStopGradient(); | |||
| FuncGraphPtr funcGraphPtr; | |||
| std::unique_ptr<tensorflow::GraphDef> tf_graph_def; | |||
| std::map<std::string, const tensorflow::NodeDef *> tf_node_map; | |||
| std::unordered_map<std::string, AnfNodePtr> anf_node_map; | |||
| std::vector<std::string> graph_input_names, graphOutputNames; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H | |||
| #include <string> | |||
| #include <map> | |||
| #include <memory> | |||
| #include "tools/converter/parser/tf/tf_util.h" | |||
| #include "proto/graph.pb.h" | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFNodeParser { | |||
| public: | |||
| TFNodeParser() = default; | |||
| virtual ~TFNodeParser() = default; | |||
| virtual STATUS Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model, | |||
| PrimitiveC *primitiveC, int *output_size) { | |||
| return RET_OK; | |||
| } | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * distributed under the License is distributed on an AS | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include <map> | |||
| #include "src/common/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| TFNodeParserRegistry::~TFNodeParserRegistry() { | |||
| for (const auto &iter : parsers) { | |||
| delete iter.second; | |||
| } | |||
| this->parsers.clear(); | |||
| } | |||
| TFNodeParserRegistry *TFNodeParserRegistry::GetInstance() { | |||
| static TFNodeParserRegistry instance; | |||
| return &instance; | |||
| } | |||
| TFNodeParser *TFNodeParserRegistry::GetNodeParser(const std::string &name) { | |||
| auto it = parsers.find(name); | |||
| if (it != parsers.end()) { | |||
| return it->second; | |||
| } | |||
| return nullptr; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_REGISTRY_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_REGISTRY_H | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFNodeParserRegistry { | |||
| public: | |||
| TFNodeParserRegistry() = default; | |||
| virtual ~TFNodeParserRegistry(); | |||
| static TFNodeParserRegistry *GetInstance(); | |||
| TFNodeParser *GetNodeParser(const std::string &name); | |||
| std::unordered_map<std::string, TFNodeParser *> parsers; | |||
| }; | |||
| class TFNodeRegistrar { | |||
| public: | |||
| TFNodeRegistrar(const std::string &name, TFNodeParser *parser) { | |||
| TFNodeParserRegistry::GetInstance()->parsers[name] = parser; | |||
| } | |||
| ~TFNodeRegistrar() = default; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_REGISTRY_H | |||
| @@ -0,0 +1,55 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_util.h" | |||
| #include <cstdio> | |||
| #include <fstream> | |||
| #include <string> | |||
| #include "google/protobuf/io/zero_copy_stream_impl.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef *nodeDef, std::string attr_name, | |||
| tensorflow::AttrValue *attr_value) { | |||
| const google::protobuf::Map<std::string, tensorflow::AttrValue> &attr = nodeDef->attr(); | |||
| const google::protobuf::Map<std::string, tensorflow::AttrValue>::const_iterator it = attr.find(attr_name); | |||
| if (it != attr.end()) { | |||
| *attr_value = it->second; | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| bool TensorFlowUtils::TfReadProtoFromBinary(const char *filepath, google::protobuf::Message *message) { | |||
| std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary); | |||
| if (!fs.is_open()) { | |||
| fprintf(stderr, "open failed %s\n", filepath); | |||
| return false; | |||
| } | |||
| google::protobuf::io::IstreamInputStream input(&fs); | |||
| google::protobuf::io::CodedInputStream codedstr(&input); | |||
| codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2); | |||
| bool success = message->ParseFromCodedStream(&codedstr); | |||
| fs.close(); | |||
| return success; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_UTIL_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_UTIL_H | |||
| #include <string> | |||
| #include "proto/node_def.pb.h" | |||
| #include "ir/dtype/type_id.h" | |||
| #include "include/errorcode.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TensorFlowUtils { | |||
| public: | |||
| static bool FindAttrValue(const tensorflow::NodeDef *nodeDef, std::string attr_name, | |||
| tensorflow::AttrValue *attr_value); | |||
| static bool TfReadProtoFromBinary(const char *filepath, google::protobuf::Message *message); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_UTIL_H | |||
| @@ -0,0 +1,66 @@ | |||
| syntax = "proto3"; | |||
| package tensorflow; | |||
| option cc_enable_arenas = true; | |||
| option java_outer_classname = "TypesProtos"; | |||
| option java_multiple_files = true; | |||
| option java_package = "org.tensorflow.framework"; | |||
| // LINT.IfChange | |||
| enum DataType { | |||
| // Not a legal value for DataType. Used to indicate a DataType field | |||
| // has not been set. | |||
| DT_INVALID = 0; | |||
| // Data types that all computation devices are expected to be | |||
| // capable to support. | |||
| DT_FLOAT = 1; | |||
| DT_DOUBLE = 2; | |||
| DT_INT32 = 3; | |||
| DT_UINT8 = 4; | |||
| DT_INT16 = 5; | |||
| DT_INT8 = 6; | |||
| DT_STRING = 7; | |||
| DT_COMPLEX64 = 8; // Single-precision complex | |||
| DT_INT64 = 9; | |||
| DT_BOOL = 10; | |||
| DT_QINT8 = 11; // Quantized int8 | |||
| DT_QUINT8 = 12; // Quantized uint8 | |||
| DT_QINT32 = 13; // Quantized int32 | |||
| DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. | |||
| DT_QINT16 = 15; // Quantized int16 | |||
| DT_QUINT16 = 16; // Quantized uint16 | |||
| DT_UINT16 = 17; | |||
| DT_COMPLEX128 = 18; // Double-precision complex | |||
| DT_HALF = 19; | |||
| DT_RESOURCE = 20; | |||
| DT_VARIANT = 21; // Arbitrary C++ data types | |||
| // TODO(josh11b): DT_GENERIC_PROTO = ??; | |||
| // TODO(jeff,josh11b): DT_UINT64? DT_UINT32? | |||
| // Do not use! These are only for parameters. Every enum above | |||
| // should have a corresponding value below (verified by types_test). | |||
| DT_FLOAT_REF = 101; | |||
| DT_DOUBLE_REF = 102; | |||
| DT_INT32_REF = 103; | |||
| DT_UINT8_REF = 104; | |||
| DT_INT16_REF = 105; | |||
| DT_INT8_REF = 106; | |||
| DT_STRING_REF = 107; | |||
| DT_COMPLEX64_REF = 108; | |||
| DT_INT64_REF = 109; | |||
| DT_BOOL_REF = 110; | |||
| DT_QINT8_REF = 111; | |||
| DT_QUINT8_REF = 112; | |||
| DT_QINT32_REF = 113; | |||
| DT_BFLOAT16_REF = 114; | |||
| DT_QINT16_REF = 115; | |||
| DT_QUINT16_REF = 116; | |||
| DT_UINT16_REF = 117; | |||
| DT_COMPLEX128_REF = 118; | |||
| DT_HALF_REF = 119; | |||
| DT_RESOURCE_REF = 120; | |||
| DT_VARIANT_REF = 121; | |||
| } | |||
| // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go) | |||
| @@ -0,0 +1,31 @@ | |||
| syntax = "proto3"; | |||
| package tensorflow; | |||
| option cc_enable_arenas = true; | |||
| option java_outer_classname = "VersionsProtos"; | |||
| option java_multiple_files = true; | |||
| option java_package = "org.tensorflow.framework"; | |||
| // Version information for a piece of serialized data | |||
| // | |||
| // There are different types of versions for each type of data | |||
| // (GraphDef, etc.), but they all have the same common shape | |||
| // described here. | |||
| // | |||
| // Each consumer has "consumer" and "min_producer" versions (specified | |||
| // elsewhere). A consumer is allowed to consume this data if | |||
| // | |||
| // producer >= min_producer | |||
| // consumer >= min_consumer | |||
| // consumer not in bad_consumers | |||
| // | |||
| message VersionDef { | |||
| // The version of the code that produced this data. | |||
| int32 producer = 1; | |||
| // Any consumer below this version is not allowed to consume this data. | |||
| int32 min_consumer = 2; | |||
| // Specific consumer versions which are disallowed (e.g. due to bugs). | |||
| repeated int32 bad_consumers = 3; | |||
| }; | |||