| @@ -39,6 +39,11 @@ namespace Tensorflow | |||
| return buffer._handle; | |||
| } | |||
| public static implicit operator byte[](Buffer buffer) | |||
| { | |||
| return buffer.Data; | |||
| } | |||
| public void Dispose() | |||
| { | |||
| c_api.TF_DeleteBuffer(_handle); | |||
| @@ -38,6 +38,16 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | |||
| /// <summary> | |||
| /// Write out a serialized representation of `graph` (as a GraphDef protocol | |||
| /// message) to `output_graph_def` (allocated by TF_NewBuffer()). | |||
| /// </summary> | |||
| /// <param name="graph"></param> | |||
| /// <param name="output_graph_def"></param> | |||
| /// <param name="status"></param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphToGraphDef(IntPtr graph, IntPtr output_graph_def, IntPtr status); | |||
| /// <summary> | |||
| /// Returns the number of dimensions of the Tensor referenced by `output` | |||
| /// in `graph`. | |||
| @@ -26,15 +26,15 @@ namespace Tensorflow | |||
| public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status); | |||
| public int NumInputs => c_api.TF_OperationNumInputs(_handle); | |||
| public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | |||
| public TF_Input[] OutputConsumers(int index, int max_consumers) | |||
| public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | |||
| { | |||
| IntPtr handle = IntPtr.Zero; | |||
| int size = Marshal.SizeOf<TF_Input>(); | |||
| int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), ref handle, max_consumers); | |||
| var handle = (TF_Input*)Marshal.AllocHGlobal(size); | |||
| int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); | |||
| var consumers = new TF_Input[num]; | |||
| for(int i = 0; i < num; i++) | |||
| { | |||
| consumers[0] = Marshal.PtrToStructure<TF_Input>(handle + i * size); | |||
| consumers[i] = new TF_Input((*handle).oper + i * size, (*handle).index); | |||
| } | |||
| return consumers; | |||
| @@ -112,7 +112,7 @@ namespace Tensorflow | |||
| /// <param name="max_consumers"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_OperationOutputConsumers(TF_Output oper_out, ref IntPtr consumers, int max_consumers); | |||
| public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input * consumers, int max_consumers); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); | |||
| @@ -0,0 +1,604 @@ | |||
| // <auto-generated> | |||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| // source: function.proto | |||
| // </auto-generated> | |||
| #pragma warning disable 1591, 0612, 3021 | |||
| #region Designer generated code | |||
| using pb = global::Google.Protobuf; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| using pbr = global::Google.Protobuf.Reflection; | |||
| using scg = global::System.Collections.Generic; | |||
| namespace Tensorflow { | |||
| /// <summary>Holder for reflection information generated from function.proto</summary> | |||
| public static partial class FunctionReflection { | |||
| #region Descriptor | |||
| /// <summary>File descriptor for function.proto</summary> | |||
| public static pbr::FileDescriptor Descriptor { | |||
| get { return descriptor; } | |||
| } | |||
| private static pbr::FileDescriptor descriptor; | |||
| static FunctionReflection() { | |||
| byte[] descriptorData = global::System.Convert.FromBase64String( | |||
| string.Concat( | |||
| "Cg5mdW5jdGlvbi5wcm90bxIKdGVuc29yZmxvdxoQYXR0cl92YWx1ZS5wcm90", | |||
| "bxoObm9kZV9kZWYucHJvdG8aDG9wX2RlZi5wcm90byJqChJGdW5jdGlvbkRl", | |||
| "ZkxpYnJhcnkSKQoIZnVuY3Rpb24YASADKAsyFy50ZW5zb3JmbG93LkZ1bmN0", | |||
| "aW9uRGVmEikKCGdyYWRpZW50GAIgAygLMhcudGVuc29yZmxvdy5HcmFkaWVu", | |||
| "dERlZiKwAgoLRnVuY3Rpb25EZWYSJAoJc2lnbmF0dXJlGAEgASgLMhEudGVu", | |||
| "c29yZmxvdy5PcERlZhIvCgRhdHRyGAUgAygLMiEudGVuc29yZmxvdy5GdW5j", | |||
| "dGlvbkRlZi5BdHRyRW50cnkSJQoIbm9kZV9kZWYYAyADKAsyEy50ZW5zb3Jm", | |||
| "bG93Lk5vZGVEZWYSLQoDcmV0GAQgAygLMiAudGVuc29yZmxvdy5GdW5jdGlv", | |||
| "bkRlZi5SZXRFbnRyeRpCCglBdHRyRW50cnkSCwoDa2V5GAEgASgJEiQKBXZh", | |||
| "bHVlGAIgASgLMhUudGVuc29yZmxvdy5BdHRyVmFsdWU6AjgBGioKCFJldEVu", | |||
| "dHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1ZRgCIAEoCToCOAFKBAgCEAMiOwoL", | |||
| "R3JhZGllbnREZWYSFQoNZnVuY3Rpb25fbmFtZRgBIAEoCRIVCg1ncmFkaWVu", | |||
| "dF9mdW5jGAIgASgJQm4KGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IORnVu", | |||
| "Y3Rpb25Qcm90b3NQAVo9Z2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNvcmZs", | |||
| "b3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29ya/gBAWIGcHJvdG8z")); | |||
| descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
| new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, global::Tensorflow.NodeDefReflection.Descriptor, global::Tensorflow.OpDefReflection.Descriptor, }, | |||
| new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDefLibrary), global::Tensorflow.FunctionDefLibrary.Parser, new[]{ "Function", "Gradient" }, null, null, null), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDef), global::Tensorflow.FunctionDef.Parser, new[]{ "Signature", "Attr", "NodeDef", "Ret" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, null, }), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GradientDef), global::Tensorflow.GradientDef.Parser, new[]{ "FunctionName", "GradientFunc" }, null, null, null) | |||
| })); | |||
| } | |||
| #endregion | |||
| } | |||
| #region Messages | |||
| /// <summary> | |||
| /// A library is a set of named functions. | |||
| /// </summary> | |||
| public sealed partial class FunctionDefLibrary : pb::IMessage<FunctionDefLibrary> { | |||
| private static readonly pb::MessageParser<FunctionDefLibrary> _parser = new pb::MessageParser<FunctionDefLibrary>(() => new FunctionDefLibrary()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static pb::MessageParser<FunctionDefLibrary> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public FunctionDefLibrary() { | |||
| OnConstruction(); | |||
| } | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public FunctionDefLibrary(FunctionDefLibrary other) : this() { | |||
| function_ = other.function_.Clone(); | |||
| gradient_ = other.gradient_.Clone(); | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public FunctionDefLibrary Clone() { | |||
| return new FunctionDefLibrary(this); | |||
| } | |||
| /// <summary>Field number for the "function" field.</summary> | |||
| public const int FunctionFieldNumber = 1; | |||
| private static readonly pb::FieldCodec<global::Tensorflow.FunctionDef> _repeated_function_codec | |||
| = pb::FieldCodec.ForMessage(10, global::Tensorflow.FunctionDef.Parser); | |||
| private readonly pbc::RepeatedField<global::Tensorflow.FunctionDef> function_ = new pbc::RepeatedField<global::Tensorflow.FunctionDef>(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public pbc::RepeatedField<global::Tensorflow.FunctionDef> Function { | |||
| get { return function_; } | |||
| } | |||
| /// <summary>Field number for the "gradient" field.</summary> | |||
| public const int GradientFieldNumber = 2; | |||
| private static readonly pb::FieldCodec<global::Tensorflow.GradientDef> _repeated_gradient_codec | |||
| = pb::FieldCodec.ForMessage(18, global::Tensorflow.GradientDef.Parser); | |||
| private readonly pbc::RepeatedField<global::Tensorflow.GradientDef> gradient_ = new pbc::RepeatedField<global::Tensorflow.GradientDef>(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public pbc::RepeatedField<global::Tensorflow.GradientDef> Gradient { | |||
| get { return gradient_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as FunctionDefLibrary); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public bool Equals(FunctionDefLibrary other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| } | |||
| if (ReferenceEquals(other, this)) { | |||
| return true; | |||
| } | |||
| if(!function_.Equals(other.function_)) return false; | |||
| if(!gradient_.Equals(other.gradient_)) return false; | |||
| return Equals(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| hash ^= function_.GetHashCode(); | |||
| hash ^= gradient_.GetHashCode(); | |||
| if (_unknownFields != null) { | |||
| hash ^= _unknownFields.GetHashCode(); | |||
| } | |||
| return hash; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| function_.WriteTo(output, _repeated_function_codec); | |||
| gradient_.WriteTo(output, _repeated_gradient_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| size += function_.CalculateSize(_repeated_function_codec); | |||
| size += gradient_.CalculateSize(_repeated_gradient_codec); | |||
| if (_unknownFields != null) { | |||
| size += _unknownFields.CalculateSize(); | |||
| } | |||
| return size; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(FunctionDefLibrary other) { | |||
| if (other == null) { | |||
| return; | |||
| } | |||
| function_.Add(other.function_); | |||
| gradient_.Add(other.gradient_); | |||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
| break; | |||
| case 10: { | |||
| function_.AddEntriesFrom(input, _repeated_function_codec); | |||
| break; | |||
| } | |||
| case 18: { | |||
| gradient_.AddEntriesFrom(input, _repeated_gradient_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// A function can be instantiated when the runtime can bind every attr | |||
| /// with a value. When a GraphDef has a call to a function, it must | |||
| /// have binding for every attr defined in the signature. | |||
| /// | |||
| /// TODO(zhifengc): | |||
| /// * device spec, etc. | |||
| /// </summary> | |||
| public sealed partial class FunctionDef : pb::IMessage<FunctionDef> { | |||
| private static readonly pb::MessageParser<FunctionDef> _parser = new pb::MessageParser<FunctionDef>(() => new FunctionDef()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static pb::MessageParser<FunctionDef> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[1]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public FunctionDef() { | |||
| OnConstruction(); | |||
| } | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public FunctionDef(FunctionDef other) : this() { | |||
| signature_ = other.signature_ != null ? other.signature_.Clone() : null; | |||
| attr_ = other.attr_.Clone(); | |||
| nodeDef_ = other.nodeDef_.Clone(); | |||
| ret_ = other.ret_.Clone(); | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public FunctionDef Clone() { | |||
| return new FunctionDef(this); | |||
| } | |||
| /// <summary>Field number for the "signature" field.</summary> | |||
| public const int SignatureFieldNumber = 1; | |||
| private global::Tensorflow.OpDef signature_; | |||
| /// <summary> | |||
| /// The definition of the function's name, arguments, return values, | |||
| /// attrs etc. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public global::Tensorflow.OpDef Signature { | |||
| get { return signature_; } | |||
| set { | |||
| signature_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "attr" field.</summary> | |||
| public const int AttrFieldNumber = 5; | |||
| private static readonly pbc::MapField<string, global::Tensorflow.AttrValue>.Codec _map_attr_codec | |||
| = new pbc::MapField<string, global::Tensorflow.AttrValue>.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 42); | |||
| private readonly pbc::MapField<string, global::Tensorflow.AttrValue> attr_ = new pbc::MapField<string, global::Tensorflow.AttrValue>(); | |||
| /// <summary> | |||
| /// Attributes specific to this function definition. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public pbc::MapField<string, global::Tensorflow.AttrValue> Attr { | |||
| get { return attr_; } | |||
| } | |||
| /// <summary>Field number for the "node_def" field.</summary> | |||
| public const int NodeDefFieldNumber = 3; | |||
| private static readonly pb::FieldCodec<global::Tensorflow.NodeDef> _repeated_nodeDef_codec | |||
| = pb::FieldCodec.ForMessage(26, global::Tensorflow.NodeDef.Parser); | |||
| private readonly pbc::RepeatedField<global::Tensorflow.NodeDef> nodeDef_ = new pbc::RepeatedField<global::Tensorflow.NodeDef>(); | |||
| /// <summary> | |||
| /// By convention, "op" in node_def is resolved by consulting with a | |||
| /// user-defined library first. If not resolved, "func" is assumed to | |||
| /// be a builtin op. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public pbc::RepeatedField<global::Tensorflow.NodeDef> NodeDef { | |||
| get { return nodeDef_; } | |||
| } | |||
| /// <summary>Field number for the "ret" field.</summary> | |||
| public const int RetFieldNumber = 4; | |||
| private static readonly pbc::MapField<string, string>.Codec _map_ret_codec | |||
| = new pbc::MapField<string, string>.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForString(18), 34); | |||
| private readonly pbc::MapField<string, string> ret_ = new pbc::MapField<string, string>(); | |||
| /// <summary> | |||
| /// A mapping from the output arg names from `signature` to the | |||
| /// outputs from `node_def` that should be returned by the function. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public pbc::MapField<string, string> Ret { | |||
| get { return ret_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as FunctionDef); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public bool Equals(FunctionDef other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| } | |||
| if (ReferenceEquals(other, this)) { | |||
| return true; | |||
| } | |||
| if (!object.Equals(Signature, other.Signature)) return false; | |||
| if (!Attr.Equals(other.Attr)) return false; | |||
| if(!nodeDef_.Equals(other.nodeDef_)) return false; | |||
| if (!Ret.Equals(other.Ret)) return false; | |||
| return Equals(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (signature_ != null) hash ^= Signature.GetHashCode(); | |||
| hash ^= Attr.GetHashCode(); | |||
| hash ^= nodeDef_.GetHashCode(); | |||
| hash ^= Ret.GetHashCode(); | |||
| if (_unknownFields != null) { | |||
| hash ^= _unknownFields.GetHashCode(); | |||
| } | |||
| return hash; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| if (signature_ != null) { | |||
| output.WriteRawTag(10); | |||
| output.WriteMessage(Signature); | |||
| } | |||
| nodeDef_.WriteTo(output, _repeated_nodeDef_codec); | |||
| ret_.WriteTo(output, _map_ret_codec); | |||
| attr_.WriteTo(output, _map_attr_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (signature_ != null) { | |||
| size += 1 + pb::CodedOutputStream.ComputeMessageSize(Signature); | |||
| } | |||
| size += attr_.CalculateSize(_map_attr_codec); | |||
| size += nodeDef_.CalculateSize(_repeated_nodeDef_codec); | |||
| size += ret_.CalculateSize(_map_ret_codec); | |||
| if (_unknownFields != null) { | |||
| size += _unknownFields.CalculateSize(); | |||
| } | |||
| return size; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(FunctionDef other) { | |||
| if (other == null) { | |||
| return; | |||
| } | |||
| if (other.signature_ != null) { | |||
| if (signature_ == null) { | |||
| signature_ = new global::Tensorflow.OpDef(); | |||
| } | |||
| Signature.MergeFrom(other.Signature); | |||
| } | |||
| attr_.Add(other.attr_); | |||
| nodeDef_.Add(other.nodeDef_); | |||
| ret_.Add(other.ret_); | |||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
| break; | |||
| case 10: { | |||
| if (signature_ == null) { | |||
| signature_ = new global::Tensorflow.OpDef(); | |||
| } | |||
| input.ReadMessage(signature_); | |||
| break; | |||
| } | |||
| case 26: { | |||
| nodeDef_.AddEntriesFrom(input, _repeated_nodeDef_codec); | |||
| break; | |||
| } | |||
| case 34: { | |||
| ret_.AddEntriesFrom(input, _map_ret_codec); | |||
| break; | |||
| } | |||
| case 42: { | |||
| attr_.AddEntriesFrom(input, _map_attr_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// GradientDef defines the gradient function of a function defined in | |||
| /// a function library. | |||
| /// | |||
| /// A gradient function g (specified by gradient_func) for a function f | |||
| /// (specified by function_name) must follow the following: | |||
| /// | |||
| /// The function 'f' must be a numerical function which takes N inputs | |||
| /// and produces M outputs. Its gradient function 'g', which is a | |||
| /// function taking N + M inputs and produces N outputs. | |||
| /// | |||
| /// I.e. if we have | |||
| /// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), | |||
| /// then, g is | |||
| /// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, | |||
| /// dL/dy1, dL/dy2, ..., dL/dy_M), | |||
| /// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the | |||
| /// loss function). dL/dx_i is the partial derivative of L with respect | |||
| /// to x_i. | |||
| /// </summary> | |||
| public sealed partial class GradientDef : pb::IMessage<GradientDef> { | |||
| private static readonly pb::MessageParser<GradientDef> _parser = new pb::MessageParser<GradientDef>(() => new GradientDef()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static pb::MessageParser<GradientDef> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[2]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public GradientDef() { | |||
| OnConstruction(); | |||
| } | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public GradientDef(GradientDef other) : this() { | |||
| functionName_ = other.functionName_; | |||
| gradientFunc_ = other.gradientFunc_; | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public GradientDef Clone() { | |||
| return new GradientDef(this); | |||
| } | |||
| /// <summary>Field number for the "function_name" field.</summary> | |||
| public const int FunctionNameFieldNumber = 1; | |||
| private string functionName_ = ""; | |||
| /// <summary> | |||
| /// The function name. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public string FunctionName { | |||
| get { return functionName_; } | |||
| set { | |||
| functionName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||
| } | |||
| } | |||
| /// <summary>Field number for the "gradient_func" field.</summary> | |||
| public const int GradientFuncFieldNumber = 2; | |||
| private string gradientFunc_ = ""; | |||
| /// <summary> | |||
| /// The gradient function's name. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public string GradientFunc { | |||
| get { return gradientFunc_; } | |||
| set { | |||
| gradientFunc_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as GradientDef); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public bool Equals(GradientDef other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| } | |||
| if (ReferenceEquals(other, this)) { | |||
| return true; | |||
| } | |||
| if (FunctionName != other.FunctionName) return false; | |||
| if (GradientFunc != other.GradientFunc) return false; | |||
| return Equals(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (FunctionName.Length != 0) hash ^= FunctionName.GetHashCode(); | |||
| if (GradientFunc.Length != 0) hash ^= GradientFunc.GetHashCode(); | |||
| if (_unknownFields != null) { | |||
| hash ^= _unknownFields.GetHashCode(); | |||
| } | |||
| return hash; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| if (FunctionName.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(FunctionName); | |||
| } | |||
| if (GradientFunc.Length != 0) { | |||
| output.WriteRawTag(18); | |||
| output.WriteString(GradientFunc); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (FunctionName.Length != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeStringSize(FunctionName); | |||
| } | |||
| if (GradientFunc.Length != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeStringSize(GradientFunc); | |||
| } | |||
| if (_unknownFields != null) { | |||
| size += _unknownFields.CalculateSize(); | |||
| } | |||
| return size; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(GradientDef other) { | |||
| if (other == null) { | |||
| return; | |||
| } | |||
| if (other.FunctionName.Length != 0) { | |||
| FunctionName = other.FunctionName; | |||
| } | |||
| if (other.GradientFunc.Length != 0) { | |||
| GradientFunc = other.GradientFunc; | |||
| } | |||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
| break; | |||
| case 10: { | |||
| FunctionName = input.ReadString(); | |||
| break; | |||
| } | |||
| case 18: { | |||
| GradientFunc = input.ReadString(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endregion | |||
| } | |||
| #endregion Designer generated code | |||
| @@ -0,0 +1,309 @@ | |||
| // <auto-generated> | |||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| // source: graph.proto | |||
| // </auto-generated> | |||
| #pragma warning disable 1591, 0612, 3021 | |||
| #region Designer generated code | |||
| using pb = global::Google.Protobuf; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| using pbr = global::Google.Protobuf.Reflection; | |||
| using scg = global::System.Collections.Generic; | |||
| namespace Tensorflow { | |||
| /// <summary>Holder for reflection information generated from graph.proto</summary> | |||
| public static partial class GraphReflection { | |||
| #region Descriptor | |||
| /// <summary>File descriptor for graph.proto</summary> | |||
| public static pbr::FileDescriptor Descriptor { | |||
| get { return descriptor; } | |||
| } | |||
| private static pbr::FileDescriptor descriptor; | |||
| static GraphReflection() { | |||
| byte[] descriptorData = global::System.Convert.FromBase64String( | |||
| string.Concat( | |||
| "CgtncmFwaC5wcm90bxIKdGVuc29yZmxvdxoObm9kZV9kZWYucHJvdG8aDmZ1", | |||
| "bmN0aW9uLnByb3RvGg52ZXJzaW9ucy5wcm90byKdAQoIR3JhcGhEZWYSIQoE", | |||
| "bm9kZRgBIAMoCzITLnRlbnNvcmZsb3cuTm9kZURlZhIoCgh2ZXJzaW9ucxgE", | |||
| "IAEoCzIWLnRlbnNvcmZsb3cuVmVyc2lvbkRlZhITCgd2ZXJzaW9uGAMgASgF", | |||
| "QgIYARIvCgdsaWJyYXJ5GAIgASgLMh4udGVuc29yZmxvdy5GdW5jdGlvbkRl", | |||
| "ZkxpYnJhcnlCawoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3JrQgtHcmFwaFBy", | |||
| "b3Rvc1ABWj1naXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5z", | |||
| "b3JmbG93L2dvL2NvcmUvZnJhbWV3b3Jr+AEBYgZwcm90bzM=")); | |||
| descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
| new pbr::FileDescriptor[] { global::Tensorflow.NodeDefReflection.Descriptor, global::Tensorflow.FunctionReflection.Descriptor, global::Tensorflow.VersionsReflection.Descriptor, }, | |||
| new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphDef), global::Tensorflow.GraphDef.Parser, new[]{ "Node", "Versions", "Version", "Library" }, null, null, null) | |||
| })); | |||
| } | |||
| #endregion | |||
| } | |||
| #region Messages | |||
| /// <summary> | |||
| /// Represents the graph of operations | |||
| /// </summary> | |||
| public sealed partial class GraphDef : pb::IMessage<GraphDef> { | |||
| private static readonly pb::MessageParser<GraphDef> _parser = new pb::MessageParser<GraphDef>(() => new GraphDef()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static pb::MessageParser<GraphDef> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.GraphReflection.Descriptor.MessageTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public GraphDef() { | |||
| OnConstruction(); | |||
| } | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public GraphDef(GraphDef other) : this() { | |||
| node_ = other.node_.Clone(); | |||
| versions_ = other.versions_ != null ? other.versions_.Clone() : null; | |||
| version_ = other.version_; | |||
| library_ = other.library_ != null ? other.library_.Clone() : null; | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public GraphDef Clone() { | |||
| return new GraphDef(this); | |||
| } | |||
| /// <summary>Field number for the "node" field.</summary> | |||
| public const int NodeFieldNumber = 1; | |||
| private static readonly pb::FieldCodec<global::Tensorflow.NodeDef> _repeated_node_codec | |||
| = pb::FieldCodec.ForMessage(10, global::Tensorflow.NodeDef.Parser); | |||
| private readonly pbc::RepeatedField<global::Tensorflow.NodeDef> node_ = new pbc::RepeatedField<global::Tensorflow.NodeDef>(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public pbc::RepeatedField<global::Tensorflow.NodeDef> Node { | |||
| get { return node_; } | |||
| } | |||
| /// <summary>Field number for the "versions" field.</summary> | |||
| public const int VersionsFieldNumber = 4; | |||
| private global::Tensorflow.VersionDef versions_; | |||
| /// <summary> | |||
| /// Compatibility versions of the graph. See core/public/version.h for version | |||
| /// history. The GraphDef version is distinct from the TensorFlow version, and | |||
| /// each release of TensorFlow will support a range of GraphDef versions. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public global::Tensorflow.VersionDef Versions { | |||
| get { return versions_; } | |||
| set { | |||
| versions_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "version" field.</summary> | |||
| public const int VersionFieldNumber = 3; | |||
| private int version_; | |||
| /// <summary> | |||
| /// Deprecated single version field; use versions above instead. Since all | |||
| /// GraphDef changes before "versions" was introduced were forward | |||
| /// compatible, this field is entirely ignored. | |||
| /// </summary> | |||
| [global::System.ObsoleteAttribute] | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public int Version { | |||
| get { return version_; } | |||
| set { | |||
| version_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "library" field.</summary> | |||
| public const int LibraryFieldNumber = 2; | |||
| private global::Tensorflow.FunctionDefLibrary library_; | |||
| /// <summary> | |||
| /// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. | |||
| /// | |||
| /// "library" provides user-defined functions. | |||
| /// | |||
| /// Naming: | |||
| /// * library.function.name are in a flat namespace. | |||
| /// NOTE: We may need to change it to be hierarchical to support | |||
| /// different orgs. E.g., | |||
| /// { "/google/nn", { ... }}, | |||
| /// { "/google/vision", { ... }} | |||
| /// { "/org_foo/module_bar", { ... }} | |||
| /// map<string, FunctionDefLib> named_lib; | |||
| /// * If node[i].op is the name of one function in "library", | |||
| /// node[i] is deemed as a function call. Otherwise, node[i].op | |||
| /// must be a primitive operation supported by the runtime. | |||
| /// | |||
| /// Function call semantics: | |||
| /// | |||
| /// * The callee may start execution as soon as some of its inputs | |||
| /// are ready. The caller may want to use Tuple() mechanism to | |||
| /// ensure all inputs are ready in the same time. | |||
| /// | |||
| /// * The consumer of return values may start executing as soon as | |||
| /// the return values the consumer depends on are ready. The | |||
| /// consumer may want to use Tuple() mechanism to ensure the | |||
| /// consumer does not start until all return values of the callee | |||
| /// function are ready. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public global::Tensorflow.FunctionDefLibrary Library { | |||
| get { return library_; } | |||
| set { | |||
| library_ = value; | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as GraphDef); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public bool Equals(GraphDef other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| } | |||
| if (ReferenceEquals(other, this)) { | |||
| return true; | |||
| } | |||
| if(!node_.Equals(other.node_)) return false; | |||
| if (!object.Equals(Versions, other.Versions)) return false; | |||
| if (Version != other.Version) return false; | |||
| if (!object.Equals(Library, other.Library)) return false; | |||
| return Equals(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| hash ^= node_.GetHashCode(); | |||
| if (versions_ != null) hash ^= Versions.GetHashCode(); | |||
| if (Version != 0) hash ^= Version.GetHashCode(); | |||
| if (library_ != null) hash ^= Library.GetHashCode(); | |||
| if (_unknownFields != null) { | |||
| hash ^= _unknownFields.GetHashCode(); | |||
| } | |||
| return hash; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| node_.WriteTo(output, _repeated_node_codec); | |||
| if (library_ != null) { | |||
| output.WriteRawTag(18); | |||
| output.WriteMessage(Library); | |||
| } | |||
| if (Version != 0) { | |||
| output.WriteRawTag(24); | |||
| output.WriteInt32(Version); | |||
| } | |||
| if (versions_ != null) { | |||
| output.WriteRawTag(34); | |||
| output.WriteMessage(Versions); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| size += node_.CalculateSize(_repeated_node_codec); | |||
| if (versions_ != null) { | |||
| size += 1 + pb::CodedOutputStream.ComputeMessageSize(Versions); | |||
| } | |||
| if (Version != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeInt32Size(Version); | |||
| } | |||
| if (library_ != null) { | |||
| size += 1 + pb::CodedOutputStream.ComputeMessageSize(Library); | |||
| } | |||
| if (_unknownFields != null) { | |||
| size += _unknownFields.CalculateSize(); | |||
| } | |||
| return size; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(GraphDef other) { | |||
| if (other == null) { | |||
| return; | |||
| } | |||
| node_.Add(other.node_); | |||
| if (other.versions_ != null) { | |||
| if (versions_ == null) { | |||
| versions_ = new global::Tensorflow.VersionDef(); | |||
| } | |||
| Versions.MergeFrom(other.Versions); | |||
| } | |||
| if (other.Version != 0) { | |||
| Version = other.Version; | |||
| } | |||
| if (other.library_ != null) { | |||
| if (library_ == null) { | |||
| library_ = new global::Tensorflow.FunctionDefLibrary(); | |||
| } | |||
| Library.MergeFrom(other.Library); | |||
| } | |||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
| break; | |||
| case 10: { | |||
| node_.AddEntriesFrom(input, _repeated_node_codec); | |||
| break; | |||
| } | |||
| case 18: { | |||
| if (library_ == null) { | |||
| library_ = new global::Tensorflow.FunctionDefLibrary(); | |||
| } | |||
| input.ReadMessage(library_); | |||
| break; | |||
| } | |||
| case 24: { | |||
| Version = input.ReadInt32(); | |||
| break; | |||
| } | |||
| case 34: { | |||
| if (versions_ == null) { | |||
| versions_ = new global::Tensorflow.VersionDef(); | |||
| } | |||
| input.ReadMessage(versions_); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endregion | |||
| } | |||
| #endregion Designer generated code | |||
| @@ -1,12 +1,15 @@ | |||
| ### Download compiler from https://github.com/protocolbuffers/protobuf/releases | |||
| ```shell | |||
| set SRC_DIR=D:\Projects\tensorflow\tensorflow\core\framework | |||
| set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Tensorflow | |||
| set SRC_DIR=D:\Projects\tensorflow-1.12.0\tensorflow\core\framework | |||
| set DST_DIR=D:\Projects\TensorFlow.NET\src\TensorFlowNET.Core\Protobuf | |||
| .\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto | |||
| .\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto | |||
| .\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto | |||
| .\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto | |||
| .\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto | |||
| .\protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% resource_handle.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor_shape.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% types.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensor.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% attr_value.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% node_def.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% versions.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% function.proto | |||
| protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% graph.proto | |||
| ``` | |||
| @@ -0,0 +1,247 @@ | |||
| // <auto-generated> | |||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| // source: versions.proto | |||
| // </auto-generated> | |||
| #pragma warning disable 1591, 0612, 3021 | |||
| #region Designer generated code | |||
| using pb = global::Google.Protobuf; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| using pbr = global::Google.Protobuf.Reflection; | |||
| using scg = global::System.Collections.Generic; | |||
| namespace Tensorflow { | |||
| /// <summary>Holder for reflection information generated from versions.proto</summary> | |||
| public static partial class VersionsReflection { | |||
| #region Descriptor | |||
| /// <summary>File descriptor for versions.proto</summary> | |||
| public static pbr::FileDescriptor Descriptor { | |||
| get { return descriptor; } | |||
| } | |||
| private static pbr::FileDescriptor descriptor; | |||
| static VersionsReflection() { | |||
| byte[] descriptorData = global::System.Convert.FromBase64String( | |||
| string.Concat( | |||
| "Cg52ZXJzaW9ucy5wcm90bxIKdGVuc29yZmxvdyJLCgpWZXJzaW9uRGVmEhAK", | |||
| "CHByb2R1Y2VyGAEgASgFEhQKDG1pbl9jb25zdW1lchgCIAEoBRIVCg1iYWRf", | |||
| "Y29uc3VtZXJzGAMgAygFQm4KGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IO", | |||
| "VmVyc2lvbnNQcm90b3NQAVo9Z2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNv", | |||
| "cmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29ya/gBAWIGcHJvdG8z")); | |||
| descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
| new pbr::FileDescriptor[] { }, | |||
| new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.VersionDef), global::Tensorflow.VersionDef.Parser, new[]{ "Producer", "MinConsumer", "BadConsumers" }, null, null, null) | |||
| })); | |||
| } | |||
| #endregion | |||
| } | |||
| #region Messages | |||
| /// <summary> | |||
| /// Version information for a piece of serialized data | |||
| /// | |||
| /// There are different types of versions for each type of data | |||
| /// (GraphDef, etc.), but they all have the same common shape | |||
| /// described here. | |||
| /// | |||
| /// Each consumer has "consumer" and "min_producer" versions (specified | |||
| /// elsewhere). A consumer is allowed to consume this data if | |||
| /// | |||
| /// producer >= min_producer | |||
| /// consumer >= min_consumer | |||
| /// consumer not in bad_consumers | |||
| /// </summary> | |||
| public sealed partial class VersionDef : pb::IMessage<VersionDef> { | |||
| private static readonly pb::MessageParser<VersionDef> _parser = new pb::MessageParser<VersionDef>(() => new VersionDef()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static pb::MessageParser<VersionDef> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.VersionsReflection.Descriptor.MessageTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public VersionDef() { | |||
| OnConstruction(); | |||
| } | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public VersionDef(VersionDef other) : this() { | |||
| producer_ = other.producer_; | |||
| minConsumer_ = other.minConsumer_; | |||
| badConsumers_ = other.badConsumers_.Clone(); | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public VersionDef Clone() { | |||
| return new VersionDef(this); | |||
| } | |||
| /// <summary>Field number for the "producer" field.</summary> | |||
| public const int ProducerFieldNumber = 1; | |||
| private int producer_; | |||
| /// <summary> | |||
| /// The version of the code that produced this data. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public int Producer { | |||
| get { return producer_; } | |||
| set { | |||
| producer_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "min_consumer" field.</summary> | |||
| public const int MinConsumerFieldNumber = 2; | |||
| private int minConsumer_; | |||
| /// <summary> | |||
| /// Any consumer below this version is not allowed to consume this data. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public int MinConsumer { | |||
| get { return minConsumer_; } | |||
| set { | |||
| minConsumer_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "bad_consumers" field.</summary> | |||
| public const int BadConsumersFieldNumber = 3; | |||
| private static readonly pb::FieldCodec<int> _repeated_badConsumers_codec | |||
| = pb::FieldCodec.ForInt32(26); | |||
| private readonly pbc::RepeatedField<int> badConsumers_ = new pbc::RepeatedField<int>(); | |||
| /// <summary> | |||
| /// Specific consumer versions which are disallowed (e.g. due to bugs). | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public pbc::RepeatedField<int> BadConsumers { | |||
| get { return badConsumers_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as VersionDef); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public bool Equals(VersionDef other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| } | |||
| if (ReferenceEquals(other, this)) { | |||
| return true; | |||
| } | |||
| if (Producer != other.Producer) return false; | |||
| if (MinConsumer != other.MinConsumer) return false; | |||
| if(!badConsumers_.Equals(other.badConsumers_)) return false; | |||
| return Equals(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (Producer != 0) hash ^= Producer.GetHashCode(); | |||
| if (MinConsumer != 0) hash ^= MinConsumer.GetHashCode(); | |||
| hash ^= badConsumers_.GetHashCode(); | |||
| if (_unknownFields != null) { | |||
| hash ^= _unknownFields.GetHashCode(); | |||
| } | |||
| return hash; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| if (Producer != 0) { | |||
| output.WriteRawTag(8); | |||
| output.WriteInt32(Producer); | |||
| } | |||
| if (MinConsumer != 0) { | |||
| output.WriteRawTag(16); | |||
| output.WriteInt32(MinConsumer); | |||
| } | |||
| badConsumers_.WriteTo(output, _repeated_badConsumers_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (Producer != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeInt32Size(Producer); | |||
| } | |||
| if (MinConsumer != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeInt32Size(MinConsumer); | |||
| } | |||
| size += badConsumers_.CalculateSize(_repeated_badConsumers_codec); | |||
| if (_unknownFields != null) { | |||
| size += _unknownFields.CalculateSize(); | |||
| } | |||
| return size; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(VersionDef other) { | |||
| if (other == null) { | |||
| return; | |||
| } | |||
| if (other.Producer != 0) { | |||
| Producer = other.Producer; | |||
| } | |||
| if (other.MinConsumer != 0) { | |||
| MinConsumer = other.MinConsumer; | |||
| } | |||
| badConsumers_.Add(other.badConsumers_); | |||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
| break; | |||
| case 8: { | |||
| Producer = input.ReadInt32(); | |||
| break; | |||
| } | |||
| case 16: { | |||
| MinConsumer = input.ReadInt32(); | |||
| break; | |||
| } | |||
| case 26: | |||
| case 24: { | |||
| badConsumers_.AddEntriesFrom(input, _repeated_badConsumers_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endregion | |||
| } | |||
| #endregion Designer generated code | |||
| @@ -17,7 +17,7 @@ namespace Tensorflow | |||
| /// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) | |||
| /// struct => struct (TF_Output output) => (TF_Output output) | |||
| /// struct* => struct (TF_Output* output) => (TF_Output[] output) | |||
| /// struct* => ref IntPtr (TF_Input* consumers) => (ref IntPtr handle), if output is struct[] | |||
| /// struct* => struct* for ref | |||
| /// const char* => string | |||
| /// int32_t => int | |||
| /// int64_t* => long[] | |||
| @@ -83,6 +83,42 @@ namespace TensorFlowNET.UnitTest | |||
| Assert.AreEqual(1, feed_port.Length); | |||
| Assert.AreEqual(add, feed_port[0].oper); | |||
| Assert.AreEqual(0, feed_port[0].index); | |||
| // The scalar const oper also has a consumer. | |||
| Assert.AreEqual(1, three.OutputNumConsumers(0)); | |||
| TF_Input[] three_port = three.OutputConsumers(0, 1); | |||
| Assert.AreEqual(add, three_port[0].oper); | |||
| Assert.AreEqual(1, three_port[0].index); | |||
| // Serialize to GraphDef. | |||
| var graph_def = c_test_util.GetGraphDef(graph); | |||
| // Validate GraphDef is what we expect. | |||
| bool found_placeholder = false; | |||
| bool found_scalar_const = false; | |||
| bool found_add = false; | |||
| foreach (var n in graph_def.Node) | |||
| { | |||
| if (c_test_util.IsPlaceholder(n)) | |||
| { | |||
| Assert.IsFalse(found_placeholder); | |||
| found_placeholder = true; | |||
| } | |||
| /*else if (IsScalarConst(n, 3)) | |||
| { | |||
| Assert.IsFalse(found_scalar_const); | |||
| found_scalar_const = true; | |||
| } | |||
| else if (IsAddN(n, 2)) | |||
| { | |||
| Assert.IsFalse(found_add); | |||
| found_add = true; | |||
| } | |||
| else | |||
| { | |||
| ADD_FAILURE() << "Unexpected NodeDef: " << ProtoDebugString(n); | |||
| }*/ | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -19,7 +19,7 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| var handle = c_api.TF_GetAllOpList(); | |||
| var buffer = new Buffer(handle); | |||
| Assert.IsTrue(buffer.Length == buffer.Data.Length); | |||
| Assert.IsTrue(buffer.Length == buffer.Length); | |||
| } | |||
| [TestMethod] | |||
| @@ -39,11 +39,20 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| var buffer = new Buffer(); | |||
| c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||
| attr_value = AttrValue.Parser.ParseFrom(buffer.Data); | |||
| attr_value = AttrValue.Parser.ParseFrom(buffer); | |||
| buffer.Dispose(); | |||
| return s.Code == TF_Code.TF_OK; | |||
| } | |||
| public static GraphDef GetGraphDef(Graph graph) | |||
| { | |||
| var s = new Status(); | |||
| var buffer = new Buffer(); | |||
| c_api.TF_GraphToGraphDef(graph, buffer, s); | |||
| s.Check(); | |||
| return GraphDef.Parser.ParseFrom(buffer); | |||
| } | |||
| public static bool GetNodeDef(Operation oper, ref NodeDef node_def) | |||
| { | |||
| var s = new Status(); | |||
| @@ -53,6 +62,37 @@ namespace TensorFlowNET.UnitTest | |||
| return s.Code == TF_Code.TF_OK; | |||
| } | |||
| public static bool IsPlaceholder(NodeDef node_def) | |||
| { | |||
| if (node_def.Op != "Placeholder" || node_def.Name != "feed") | |||
| { | |||
| return false; | |||
| } | |||
| bool found_dtype = false; | |||
| bool found_shape = false; | |||
| foreach (var attr in node_def.Attr) | |||
| { | |||
| if (attr.Key == "dtype") | |||
| { | |||
| if (attr.Value.Type == DataType.DtInt32) | |||
| { | |||
| found_dtype = true; | |||
| } | |||
| else | |||
| { | |||
| return false; | |||
| } | |||
| } | |||
| else if (attr.Key == "shape") | |||
| { | |||
| found_shape = true; | |||
| } | |||
| } | |||
| return found_dtype && found_shape; | |||
| } | |||
| public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) | |||
| { | |||
| var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | |||