diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln
index 16f524a4..9eaf8143 100644
--- a/TensorFlow.NET.sln
+++ b/TensorFlow.NET.sln
@@ -15,8 +15,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Hub", "src\Te
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Datasets", "src\TensorFlowNET.Datasets\TensorFlowNET.Datasets.csproj", "{494D6CAD-2C0D-4C0B-90E2-B097DB039383}"
EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}"
-EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -99,18 +97,6 @@ Global
{494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Release|Any CPU.Build.0 = Release|Any CPU
{494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Release|x64.ActiveCfg = Release|Any CPU
{494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Release|x64.Build.0 = Release|Any CPU
- {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
- {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Debug|Any CPU.Build.0 = Debug|Any CPU
- {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Debug|x64.ActiveCfg = Debug|x64
- {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Debug|x64.Build.0 = Debug|x64
- {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Publish|Any CPU.ActiveCfg = Publish|Any CPU
- {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Publish|Any CPU.Build.0 = Publish|Any CPU
- {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Publish|x64.ActiveCfg = Publish|x64
- {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Publish|x64.Build.0 = Publish|x64
- {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Release|Any CPU.ActiveCfg = Release|Any CPU
- {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Release|Any CPU.Build.0 = Release|Any CPU
- {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Release|x64.ActiveCfg = Release|x64
- {9249BCC4-3FEB-4EF5-8AB9-789FFE4040B4}.Release|x64.Build.0 = Release|x64
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
diff --git a/src/SciSharp.TensorFlow.Redist/README.md b/src/SciSharp.TensorFlow.Redist/README.md
index 2c608020..6dfce3e1 100644
--- a/src/SciSharp.TensorFlow.Redist/README.md
+++ b/src/SciSharp.TensorFlow.Redist/README.md
@@ -26,7 +26,7 @@ Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5
On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries.
-1. Run `dotnet pack SciSharp.TensorFlow.Redist-CPU.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux.
-2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.1.14.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json`
+1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux.
+2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.1.15.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json`
diff --git a/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-CPU.nupkgproj b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist.nupkgproj
similarity index 97%
rename from src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-CPU.nupkgproj
rename to src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist.nupkgproj
index 5de76105..85ca2898 100644
--- a/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-CPU.nupkgproj
+++ b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist.nupkgproj
@@ -7,7 +7,7 @@
x64
netstandard2.0
- 1.14.0
+ 1.15.0
1
$(BinDir)packages\
@@ -38,7 +38,7 @@
new GradientDescentOptimizer(learning_rate);
+ public Optimizer GradientDescentOptimizer(Tensor learning_rate)
+ => new GradientDescentOptimizer(learning_rate);
+
public Optimizer AdamOptimizer(float learning_rate, string name = "Adam")
=> new AdamOptimizer(learning_rate, name: name);
diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs
index 334f4f74..150fa89a 100644
--- a/src/TensorFlowNET.Core/Binding.Util.cs
+++ b/src/TensorFlowNET.Core/Binding.Util.cs
@@ -129,6 +129,7 @@ namespace Tensorflow
}
}
+ [DebuggerStepThrough]
[DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception
public static TOut tf_with(TIn py, Func action) where TIn : IObjectLife
{
diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs
index 3f5a2777..05092581 100644
--- a/src/TensorFlowNET.Core/Framework/meta_graph.cs
+++ b/src/TensorFlowNET.Core/Framework/meta_graph.cs
@@ -134,7 +134,7 @@ namespace Tensorflow
}
break;
default:
- Console.WriteLine("import_scoped_meta_graph_with_return_elements");
+ Console.WriteLine($"import_scoped_meta_graph_with_return_elements {col.Key}");
continue;
}
}
diff --git a/src/TensorFlowNET.Core/Framework/smart_module.cs b/src/TensorFlowNET.Core/Framework/smart_module.cs
index 67102cab..0f1cb76e 100644
--- a/src/TensorFlowNET.Core/Framework/smart_module.cs
+++ b/src/TensorFlowNET.Core/Framework/smart_module.cs
@@ -15,6 +15,7 @@
******************************************************************************/
using System;
+using static Tensorflow.Binding;
namespace Tensorflow.Framework
{
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index 0c43582d..9b6906aa 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -75,7 +75,7 @@ namespace Tensorflow
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
///
/// https://www.tensorflow.org/guide/graphs
https://www.tensorflow.org/api_docs/python/tf/Graph
- public partial class Graph : DisposableObject, IEnumerable
+ public partial class Graph : DisposableObject//, IEnumerable
{
private Dictionary _nodes_by_id;
public Dictionary _nodes_by_name;
@@ -257,17 +257,17 @@ namespace Tensorflow
if (inputs == null)
inputs = new Tensor[0];
- foreach ((int idx, Tensor a) in enumerate(inputs))
- {
-
- }
-
- if (String.IsNullOrEmpty(name))
+ if (string.IsNullOrEmpty(name))
name = op_type;
// If a names ends with a '/' it is a "name scope" and we use it as-is,
// after removing the trailing '/'.
name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name);
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);
+
+ if (name.Contains("define_loss/bigger_box_loss/mul_13"))
+ {
+
+ }
var input_ops = inputs.Select(x => x.op).ToArray();
var control_inputs = _control_dependencies_for_inputs(input_ops);
@@ -526,14 +526,14 @@ namespace Tensorflow
return debugString;*/
}
- private IEnumerable GetEnumerable()
+ /*private IEnumerable GetEnumerable()
=> c_api_util.tf_operations(this);
IEnumerator IEnumerable.GetEnumerator()
=> GetEnumerable().GetEnumerator();
IEnumerator IEnumerable.GetEnumerator()
- => throw new NotImplementedException();
+ => throw new NotImplementedException();*/
public static implicit operator IntPtr(Graph graph)
{
diff --git a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs
index e9f85530..9aa6d619 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs
@@ -56,9 +56,9 @@ namespace Tensorflow.Keras.Engine
{
// Instantiate an input layer.
var x = keras.layers.Input(
- batch_shape: batch_shape,
- dtype: dtype,
- name: layer.name + "_input");
+ batch_shape: batch_shape,
+ dtype: dtype,
+ name: layer.name + "_input");
// This will build the current layer
// and create the node connecting the current layer
@@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Engine
if (set_inputs)
{
// If an input layer (placeholder) is available.
- // outputs = layer._inbound_nodes;
+ // outputs = layer.inbound_nodes;
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
index 22cef8e1..25161721 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
@@ -106,6 +106,7 @@ namespace Tensorflow.Keras.Layers
VariableScope scope = null)
{
var input_list = inputs;
+ var input = inputs[0];
Tensor outputs = null;
// We will attempt to build a TF graph if & only if all inputs are symbolic.
@@ -139,6 +140,7 @@ namespace Tensorflow.Keras.Layers
_maybe_build(inputs[0]);
outputs = call(inputs[0], training: training);
+ (input, outputs) = _set_connectivity_metadata_(input, outputs);
_handle_activity_regularization(inputs[0], outputs);
_set_mask_metadata(inputs[0], outputs, null);
});
@@ -147,6 +149,12 @@ namespace Tensorflow.Keras.Layers
return outputs;
}
+ private (Tensor, Tensor) _set_connectivity_metadata_(Tensor inputs, Tensor outputs)
+ {
+ //_add_inbound_node(input_tensors: inputs, output_tensors: outputs);
+ return (inputs, outputs);
+ }
+
private void _handle_activity_regularization(Tensor inputs, Tensor outputs)
{
//if(_activity_regularizer != null)
@@ -224,7 +232,7 @@ namespace Tensorflow.Keras.Layers
overwrite: true,
initializer: initializer,
trainable: trainable.Value);
- backend.track_variable(variable);
+ //backend.track_variable(variable);
_trainable_weights.Add(variable);
return variable;
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
index 4e376d19..f3a63d68 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/
+using static Tensorflow.Binding;
+
namespace Tensorflow.Operations
{
public class gen_nn_ops
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
index 8e317df9..2f61f954 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
@@ -54,10 +54,6 @@ namespace Tensorflow
public void _set_control_flow_context(ControlFlowContext ctx)
{
- if(name == "define_loss/conv_sobj_branch/batch_normalization/cond/FusedBatchNorm_1")
- {
-
- }
_control_flow_context = ctx;
}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs
index c80e99f6..f518c726 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs
@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/
+using Newtonsoft.Json;
using System;
using System.Linq;
using System.Runtime.InteropServices;
@@ -37,7 +38,9 @@ namespace Tensorflow
}
return num;
}
-
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public int NumInputs => c_api.TF_OperationNumInputs(_handle);
private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray();
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs
index 6844c892..f4dcdfd6 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs
@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/
+using Newtonsoft.Json;
using System;
using System.Linq;
using System.Runtime.InteropServices;
@@ -23,6 +24,9 @@ namespace Tensorflow
{
public partial class Operation
{
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle);
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(_tf_output(index));
@@ -40,7 +44,9 @@ namespace Tensorflow
private Tensor[] _outputs;
public Tensor[] outputs => _outputs;
-
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public Tensor output => _outputs.FirstOrDefault();
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index 6118602c..2e653e51 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -15,6 +15,9 @@
******************************************************************************/
using Google.Protobuf.Collections;
+#if SERIALIZABLE
+using Newtonsoft.Json;
+#endif
using System;
using System.Collections.Generic;
using System.IO;
@@ -43,20 +46,37 @@ namespace Tensorflow
///
public partial class Operation : ITensorOrOperation
{
- private readonly IntPtr _handle; // _c_op in python
- private readonly Graph _graph;
- private NodeDef _node_def;
+ private readonly IntPtr _handle; // _c_op in python
- public string type => OpType;
- public Graph graph => _graph;
- public int _id => _id_value;
- public int _id_value;
+ private readonly Graph _graph;
+ private NodeDef _node_def;
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
+ public string type => OpType;
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
+ public Graph graph => _graph;
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
+ public int _id => _id_value;
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
+ public int _id_value;
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public Operation op => this;
public TF_DataType dtype => TF_DataType.DtInvalid;
public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle));
public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle));
- public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle));
-
+ public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle));
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public NodeDef node_def
{
get
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
index b7ef6440..86ab150f 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
@@ -601,7 +601,16 @@ namespace Tensorflow
}
public static Tensor gather(T1 @params, T2 indices, string name = null, int axis = 0)
- => gen_array_ops.gather_v2(@params, indices, axis, name: name);
+ {
+ if (axis != 0)
+ return gen_array_ops.gather_v2(@params, indices, axis, name: name);
+
+ if (@params is ResourceVariable variable &&
+ indices is Tensor indices_tensor)
+ return variable.sparse_read(indices_tensor, name);
+
+ return gen_array_ops.gather_v2(@params, indices, axis, name: name);
+ }
public static Tensor transpose(T1 a, T2 perm, string name = "transpose", bool conjugate = false)
{
diff --git a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
index 7b00b604..664572a5 100644
--- a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
@@ -31,7 +31,62 @@ namespace Tensorflow
{
var _op = _op_def_lib._apply_op_helper("VarIsInitializedOp", name, new { resource });
- return _op;
+ return _op.output;
+ }
+
+ ///
+ /// Creates a handle to a Variable resource.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Tensor var_handle_op(TF_DataType dtype, TensorShape shape,
+ string container ="", string shared_name = "", string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("VarHandleOp", name, new {
+ dtype,
+ shape,
+ container,
+ shared_name
+ });
+
+ return _op.output;
+ }
+
+ ///
+ /// Reads the value of a variable.
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Tensor read_variable_op(Tensor resource, TF_DataType dtype, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("ReadVariableOp", name, new
+ {
+ resource,
+ dtype
+ });
+
+ return _op.output;
+ }
+
+ public static Tensor resource_gather(Tensor resource, Tensor indices, TF_DataType dtype,
+ int batch_dims = 0, bool validate_indices = true, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("ResourceGather", name, new
+ {
+ resource,
+ indices,
+ dtype,
+ batch_dims,
+ validate_indices
+ });
+
+ return _op.output;
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
index b301063c..f591402e 100644
--- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
@@ -16,6 +16,7 @@
using System;
using Tensorflow.Framework;
+using static Tensorflow.CppShapeInferenceResult.Types;
namespace Tensorflow
{
@@ -91,12 +92,80 @@ namespace Tensorflow
shape, dtype, shared_name, name, graph_mode, initial_value);
}
+ ///
+ /// Create a new variable handle, optionally copying in `extra_handle_data`
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
public static Tensor variable_handle_from_shape_and_dtype(TensorShape shape, TF_DataType dtype,
string shared_name, string name, bool graph_mode, Tensor extra_handle_data = null)
{
+ var container = "";// ops.get_default_graph().container;
+ var handle = gen_resource_variable_ops.var_handle_op(shape: shape,
+ dtype: dtype,
+ shared_name: shared_name,
+ name: name,
+ container: container);
+
+ if (extra_handle_data == null)
+ extra_handle_data = handle;
+
+ if (graph_mode)
+ {
+ var full_handle_data = _combine_handle_data(handle, extra_handle_data);
+ _set_handle_shapes_and_types(handle, full_handle_data, graph_mode);
+ return handle;
+ }
+ else
+ {
+ throw new NotImplementedException("");
+ }
+ }
+
+ private static void _set_handle_shapes_and_types(Tensor handle, HandleData full_handle_data, bool graph_mode)
+ {
+
+ }
+
+ ///
+ /// Concats HandleData from tensors `handle` and `initial_value`.
+ ///
+ ///
+ ///
+ ///
+ private static HandleData _combine_handle_data(Tensor handle, Tensor initial_value)
+ {
+ var variable_handle_data = get_eager_safe_handle_data(initial_value);
+
+ if (initial_value.dtype != dtypes.variant)
+ return variable_handle_data;
+
throw new NotImplementedException("");
}
+ private static HandleData get_eager_safe_handle_data(Tensor handle)
+ {
+ if(handle == IntPtr.Zero)
+ {
+ var data = new HandleData();
+ data.ShapeAndType.Add(new HandleShapeAndType
+ {
+ Shape = handle.TensorShape.as_proto(),
+ Dtype = handle.dtype.as_datatype_enum()
+ });
+ return data;
+ }
+ else
+ {
+ return HandleData.Parser.ParseFrom(handle.BufferToArray());
+ }
+ }
+
///
/// Represents a future for a read of a variable.
/// Pretends to be the tensor if anyone looks.
diff --git a/src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs b/src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs
new file mode 100644
index 00000000..c6574895
--- /dev/null
+++ b/src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs
@@ -0,0 +1,692 @@
+//
+// Generated by the protocol buffer compiler. DO NOT EDIT!
+// source: tensorflow/python/framework/cpp_shape_inference.proto
+//
+#pragma warning disable 1591, 0612, 3021
+#region Designer generated code
+
+using pb = global::Google.Protobuf;
+using pbc = global::Google.Protobuf.Collections;
+using pbr = global::Google.Protobuf.Reflection;
+using scg = global::System.Collections.Generic;
+namespace Tensorflow {
+
+ /// Holder for reflection information generated from tensorflow/python/framework/cpp_shape_inference.proto
+ public static partial class CppShapeInferenceReflection {
+
+ #region Descriptor
+ /// File descriptor for tensorflow/python/framework/cpp_shape_inference.proto
+ public static pbr::FileDescriptor Descriptor {
+ get { return descriptor; }
+ }
+ private static pbr::FileDescriptor descriptor;
+
+ static CppShapeInferenceReflection() {
+ byte[] descriptorData = global::System.Convert.FromBase64String(
+ string.Concat(
+ "CjV0ZW5zb3JmbG93L3B5dGhvbi9mcmFtZXdvcmsvY3BwX3NoYXBlX2luZmVy",
+ "ZW5jZS5wcm90bxIKdGVuc29yZmxvdxoldGVuc29yZmxvdy9jb3JlL2ZyYW1l",
+ "d29yay90eXBlcy5wcm90bxosdGVuc29yZmxvdy9jb3JlL2ZyYW1ld29yay90",
+ "ZW5zb3Jfc2hhcGUucHJvdG8i7QIKF0NwcFNoYXBlSW5mZXJlbmNlUmVzdWx0",
+ "EisKBXNoYXBlGAEgASgLMhwudGVuc29yZmxvdy5UZW5zb3JTaGFwZVByb3Rv",
+ "EkMKC2hhbmRsZV9kYXRhGAQgASgLMi4udGVuc29yZmxvdy5DcHBTaGFwZUlu",
+ "ZmVyZW5jZVJlc3VsdC5IYW5kbGVEYXRhGmYKEkhhbmRsZVNoYXBlQW5kVHlw",
+ "ZRIrCgVzaGFwZRgBIAEoCzIcLnRlbnNvcmZsb3cuVGVuc29yU2hhcGVQcm90",
+ "bxIjCgVkdHlwZRgCIAEoDjIULnRlbnNvcmZsb3cuRGF0YVR5cGUabAoKSGFu",
+ "ZGxlRGF0YRIOCgZpc19zZXQYASABKAgSTgoOc2hhcGVfYW5kX3R5cGUYAiAD",
+ "KAsyNi50ZW5zb3JmbG93LkNwcFNoYXBlSW5mZXJlbmNlUmVzdWx0LkhhbmRs",
+ "ZVNoYXBlQW5kVHlwZUoECAIQA0oECAMQBCJlCh1DcHBTaGFwZUluZmVyZW5j",
+ "ZUlucHV0c05lZWRlZBIcChRpbnB1dF90ZW5zb3JzX25lZWRlZBgBIAMoBRIm",
+ "Ch5pbnB1dF90ZW5zb3JzX2FzX3NoYXBlc19uZWVkZWQYAiADKAVCA/gBAWIG",
+ "cHJvdG8z"));
+ descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
+ new pbr::FileDescriptor[] { global::Tensorflow.TypesReflection.Descriptor, global::Tensorflow.TensorShapeReflection.Descriptor, },
+ new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
+ new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CppShapeInferenceResult), global::Tensorflow.CppShapeInferenceResult.Parser, new[]{ "Shape", "HandleData" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType), global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType.Parser, new[]{ "Shape", "Dtype" }, null, null, null),
+ new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CppShapeInferenceResult.Types.HandleData), global::Tensorflow.CppShapeInferenceResult.Types.HandleData.Parser, new[]{ "IsSet", "ShapeAndType" }, null, null, null)}),
+ new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CppShapeInferenceInputsNeeded), global::Tensorflow.CppShapeInferenceInputsNeeded.Parser, new[]{ "InputTensorsNeeded", "InputTensorsAsShapesNeeded" }, null, null, null)
+ }));
+ }
+ #endregion
+
+ }
+ #region Messages
+ public sealed partial class CppShapeInferenceResult : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CppShapeInferenceResult());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::Tensorflow.CppShapeInferenceReflection.Descriptor.MessageTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public CppShapeInferenceResult() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public CppShapeInferenceResult(CppShapeInferenceResult other) : this() {
+ shape_ = other.shape_ != null ? other.shape_.Clone() : null;
+ handleData_ = other.handleData_ != null ? other.handleData_.Clone() : null;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public CppShapeInferenceResult Clone() {
+ return new CppShapeInferenceResult(this);
+ }
+
+ /// Field number for the "shape" field.
+ public const int ShapeFieldNumber = 1;
+ private global::Tensorflow.TensorShapeProto shape_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::Tensorflow.TensorShapeProto Shape {
+ get { return shape_; }
+ set {
+ shape_ = value;
+ }
+ }
+
+ /// Field number for the "handle_data" field.
+ public const int HandleDataFieldNumber = 4;
+ private global::Tensorflow.CppShapeInferenceResult.Types.HandleData handleData_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::Tensorflow.CppShapeInferenceResult.Types.HandleData HandleData {
+ get { return handleData_; }
+ set {
+ handleData_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as CppShapeInferenceResult);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(CppShapeInferenceResult other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (!object.Equals(Shape, other.Shape)) return false;
+ if (!object.Equals(HandleData, other.HandleData)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (shape_ != null) hash ^= Shape.GetHashCode();
+ if (handleData_ != null) hash ^= HandleData.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (shape_ != null) {
+ output.WriteRawTag(10);
+ output.WriteMessage(Shape);
+ }
+ if (handleData_ != null) {
+ output.WriteRawTag(34);
+ output.WriteMessage(HandleData);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (shape_ != null) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape);
+ }
+ if (handleData_ != null) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(HandleData);
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(CppShapeInferenceResult other) {
+ if (other == null) {
+ return;
+ }
+ if (other.shape_ != null) {
+ if (shape_ == null) {
+ shape_ = new global::Tensorflow.TensorShapeProto();
+ }
+ Shape.MergeFrom(other.Shape);
+ }
+ if (other.handleData_ != null) {
+ if (handleData_ == null) {
+ handleData_ = new global::Tensorflow.CppShapeInferenceResult.Types.HandleData();
+ }
+ HandleData.MergeFrom(other.HandleData);
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ if (shape_ == null) {
+ shape_ = new global::Tensorflow.TensorShapeProto();
+ }
+ input.ReadMessage(shape_);
+ break;
+ }
+ case 34: {
+ if (handleData_ == null) {
+ handleData_ = new global::Tensorflow.CppShapeInferenceResult.Types.HandleData();
+ }
+ input.ReadMessage(handleData_);
+ break;
+ }
+ }
+ }
+ }
+
+ #region Nested types
+ /// Container for nested types declared in the CppShapeInferenceResult message type.
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static partial class Types {
+ public sealed partial class HandleShapeAndType : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HandleShapeAndType());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::Tensorflow.CppShapeInferenceResult.Descriptor.NestedTypes[0]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public HandleShapeAndType() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public HandleShapeAndType(HandleShapeAndType other) : this() {
+ shape_ = other.shape_ != null ? other.shape_.Clone() : null;
+ dtype_ = other.dtype_;
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public HandleShapeAndType Clone() {
+ return new HandleShapeAndType(this);
+ }
+
+ /// Field number for the "shape" field.
+ public const int ShapeFieldNumber = 1;
+ private global::Tensorflow.TensorShapeProto shape_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::Tensorflow.TensorShapeProto Shape {
+ get { return shape_; }
+ set {
+ shape_ = value;
+ }
+ }
+
+ /// Field number for the "dtype" field.
+ public const int DtypeFieldNumber = 2;
+ private global::Tensorflow.DataType dtype_ = 0;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public global::Tensorflow.DataType Dtype {
+ get { return dtype_; }
+ set {
+ dtype_ = value;
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as HandleShapeAndType);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(HandleShapeAndType other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (!object.Equals(Shape, other.Shape)) return false;
+ if (Dtype != other.Dtype) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (shape_ != null) hash ^= Shape.GetHashCode();
+ if (Dtype != 0) hash ^= Dtype.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (shape_ != null) {
+ output.WriteRawTag(10);
+ output.WriteMessage(Shape);
+ }
+ if (Dtype != 0) {
+ output.WriteRawTag(16);
+ output.WriteEnum((int) Dtype);
+ }
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (shape_ != null) {
+ size += 1 + pb::CodedOutputStream.ComputeMessageSize(Shape);
+ }
+ if (Dtype != 0) {
+ size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Dtype);
+ }
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(HandleShapeAndType other) {
+ if (other == null) {
+ return;
+ }
+ if (other.shape_ != null) {
+ if (shape_ == null) {
+ shape_ = new global::Tensorflow.TensorShapeProto();
+ }
+ Shape.MergeFrom(other.Shape);
+ }
+ if (other.Dtype != 0) {
+ Dtype = other.Dtype;
+ }
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10: {
+ if (shape_ == null) {
+ shape_ = new global::Tensorflow.TensorShapeProto();
+ }
+ input.ReadMessage(shape_);
+ break;
+ }
+ case 16: {
+ dtype_ = (global::Tensorflow.DataType) input.ReadEnum();
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ public sealed partial class HandleData : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HandleData());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::Tensorflow.CppShapeInferenceResult.Descriptor.NestedTypes[1]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public HandleData() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public HandleData(HandleData other) : this() {
+ isSet_ = other.isSet_;
+ shapeAndType_ = other.shapeAndType_.Clone();
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public HandleData Clone() {
+ return new HandleData(this);
+ }
+
+ /// Field number for the "is_set" field.
+ public const int IsSetFieldNumber = 1;
+ private bool isSet_;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool IsSet {
+ get { return isSet_; }
+ set {
+ isSet_ = value;
+ }
+ }
+
+ /// Field number for the "shape_and_type" field.
+ public const int ShapeAndTypeFieldNumber = 2;
+ private static readonly pb::FieldCodec _repeated_shapeAndType_codec
+ = pb::FieldCodec.ForMessage(18, global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType.Parser);
+ private readonly pbc::RepeatedField shapeAndType_ = new pbc::RepeatedField();
+ ///
+ /// Only valid if <is_set>.
+ ///
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField ShapeAndType {
+ get { return shapeAndType_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as HandleData);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(HandleData other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if (IsSet != other.IsSet) return false;
+ if(!shapeAndType_.Equals(other.shapeAndType_)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ if (IsSet != false) hash ^= IsSet.GetHashCode();
+ hash ^= shapeAndType_.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ if (IsSet != false) {
+ output.WriteRawTag(8);
+ output.WriteBool(IsSet);
+ }
+ shapeAndType_.WriteTo(output, _repeated_shapeAndType_codec);
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ if (IsSet != false) {
+ size += 1 + 1;
+ }
+ size += shapeAndType_.CalculateSize(_repeated_shapeAndType_codec);
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(HandleData other) {
+ if (other == null) {
+ return;
+ }
+ if (other.IsSet != false) {
+ IsSet = other.IsSet;
+ }
+ shapeAndType_.Add(other.shapeAndType_);
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 8: {
+ IsSet = input.ReadBool();
+ break;
+ }
+ case 18: {
+ shapeAndType_.AddEntriesFrom(input, _repeated_shapeAndType_codec);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ }
+ #endregion
+
+ }
+
+ public sealed partial class CppShapeInferenceInputsNeeded : pb::IMessage {
+ private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CppShapeInferenceInputsNeeded());
+ private pb::UnknownFieldSet _unknownFields;
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pb::MessageParser Parser { get { return _parser; } }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public static pbr::MessageDescriptor Descriptor {
+ get { return global::Tensorflow.CppShapeInferenceReflection.Descriptor.MessageTypes[1]; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ pbr::MessageDescriptor pb::IMessage.Descriptor {
+ get { return Descriptor; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public CppShapeInferenceInputsNeeded() {
+ OnConstruction();
+ }
+
+ partial void OnConstruction();
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public CppShapeInferenceInputsNeeded(CppShapeInferenceInputsNeeded other) : this() {
+ inputTensorsNeeded_ = other.inputTensorsNeeded_.Clone();
+ inputTensorsAsShapesNeeded_ = other.inputTensorsAsShapesNeeded_.Clone();
+ _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public CppShapeInferenceInputsNeeded Clone() {
+ return new CppShapeInferenceInputsNeeded(this);
+ }
+
+ /// Field number for the "input_tensors_needed" field.
+ public const int InputTensorsNeededFieldNumber = 1;
+ private static readonly pb::FieldCodec _repeated_inputTensorsNeeded_codec
+ = pb::FieldCodec.ForInt32(10);
+ private readonly pbc::RepeatedField inputTensorsNeeded_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField InputTensorsNeeded {
+ get { return inputTensorsNeeded_; }
+ }
+
+ /// Field number for the "input_tensors_as_shapes_needed" field.
+ public const int InputTensorsAsShapesNeededFieldNumber = 2;
+ private static readonly pb::FieldCodec _repeated_inputTensorsAsShapesNeeded_codec
+ = pb::FieldCodec.ForInt32(18);
+ private readonly pbc::RepeatedField inputTensorsAsShapesNeeded_ = new pbc::RepeatedField();
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public pbc::RepeatedField InputTensorsAsShapesNeeded {
+ get { return inputTensorsAsShapesNeeded_; }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override bool Equals(object other) {
+ return Equals(other as CppShapeInferenceInputsNeeded);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool Equals(CppShapeInferenceInputsNeeded other) {
+ if (ReferenceEquals(other, null)) {
+ return false;
+ }
+ if (ReferenceEquals(other, this)) {
+ return true;
+ }
+ if(!inputTensorsNeeded_.Equals(other.inputTensorsNeeded_)) return false;
+ if(!inputTensorsAsShapesNeeded_.Equals(other.inputTensorsAsShapesNeeded_)) return false;
+ return Equals(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override int GetHashCode() {
+ int hash = 1;
+ hash ^= inputTensorsNeeded_.GetHashCode();
+ hash ^= inputTensorsAsShapesNeeded_.GetHashCode();
+ if (_unknownFields != null) {
+ hash ^= _unknownFields.GetHashCode();
+ }
+ return hash;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public override string ToString() {
+ return pb::JsonFormatter.ToDiagnosticString(this);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void WriteTo(pb::CodedOutputStream output) {
+ inputTensorsNeeded_.WriteTo(output, _repeated_inputTensorsNeeded_codec);
+ inputTensorsAsShapesNeeded_.WriteTo(output, _repeated_inputTensorsAsShapesNeeded_codec);
+ if (_unknownFields != null) {
+ _unknownFields.WriteTo(output);
+ }
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public int CalculateSize() {
+ int size = 0;
+ size += inputTensorsNeeded_.CalculateSize(_repeated_inputTensorsNeeded_codec);
+ size += inputTensorsAsShapesNeeded_.CalculateSize(_repeated_inputTensorsAsShapesNeeded_codec);
+ if (_unknownFields != null) {
+ size += _unknownFields.CalculateSize();
+ }
+ return size;
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(CppShapeInferenceInputsNeeded other) {
+ if (other == null) {
+ return;
+ }
+ inputTensorsNeeded_.Add(other.inputTensorsNeeded_);
+ inputTensorsAsShapesNeeded_.Add(other.inputTensorsAsShapesNeeded_);
+ _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+ }
+
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public void MergeFrom(pb::CodedInputStream input) {
+ uint tag;
+ while ((tag = input.ReadTag()) != 0) {
+ switch(tag) {
+ default:
+ _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+ break;
+ case 10:
+ case 8: {
+ inputTensorsNeeded_.AddEntriesFrom(input, _repeated_inputTensorsNeeded_codec);
+ break;
+ }
+ case 18:
+ case 16: {
+ inputTensorsAsShapesNeeded_.AddEntriesFrom(input, _repeated_inputTensorsAsShapesNeeded_codec);
+ break;
+ }
+ }
+ }
+ }
+
+ }
+
+ #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
index 33914c3a..bb5d889c 100644
--- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
+++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -5,7 +5,7 @@
TensorFlow.NET
Tensorflow
1.14.0
- 0.11.6
+ 0.11.8
Haiping Chen, Meinrad Recheis, Eli Belash
SciSharp STACK
true
@@ -17,7 +17,7 @@
TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#
Google's TensorFlow full binding in .NET Standard.
Docs: https://tensorflownet.readthedocs.io
- 0.11.6.0
+ 0.11.8.0
Changes since v0.10.0:
1. Upgrade NumSharp to v0.20.3.
2. Add DisposableObject class to manage object lifetime.
@@ -31,9 +31,10 @@ Docs: https://tensorflownet.readthedocs.io
10. Support n-dim indexing for tensor.
11. Add RegisterNoGradients
12. Add CumsumGrad, BroadcastToGrad.
-13. Return VariableV1 instead of RefVariable.
+13. Return VariableV1 instead of RefVariable.
+14. Add Tensor overload to GradientDescentOptimizer.
7.3
- 0.11.6.0
+ 0.11.8.0
LICENSE
true
true
@@ -42,7 +43,7 @@ Docs: https://tensorflownet.readthedocs.io
true
- TRACE;DEBUG
+ TRACE;DEBUG;SERIALIZABLE
@@ -65,6 +66,8 @@ Docs: https://tensorflownet.readthedocs.io
+
+
@@ -72,8 +75,4 @@ Docs: https://tensorflownet.readthedocs.io
-
-
-
-
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
index 34edcb4f..7c5054d3 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
@@ -25,6 +25,7 @@ using System.Text;
using NumSharp.Backends;
using NumSharp.Backends.Unmanaged;
using static Tensorflow.c_api;
+using Newtonsoft.Json;
namespace Tensorflow
{
@@ -44,11 +45,17 @@ namespace Tensorflow
///
/// True if this Tensor holds data allocated by C#.
///
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public bool IsMemoryOwner => AllocationType >= AllocationType.Marshal;
///
/// The allocation method used to create this Tensor.
///
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public AllocationType AllocationType { get; protected set; }
///
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index fb8e2457..ef053651 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -28,6 +28,7 @@ using NumSharp.Backends;
using NumSharp.Backends.Unmanaged;
using NumSharp.Utilities;
using Tensorflow.Framework;
+using Newtonsoft.Json;
namespace Tensorflow
{
@@ -43,19 +44,29 @@ namespace Tensorflow
private readonly int _value_index;
private TF_Output? _tf_output;
private readonly TF_DataType _override_dtype;
-
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public int Id => _id;
///
/// The Graph that contains this tensor.
///
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public Graph graph => op?.graph;
///
/// The Operation that produces this tensor as an output.
///
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public Operation op => _op;
-
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public Tensor[] outputs => op.outputs;
///
@@ -72,24 +83,40 @@ namespace Tensorflow
/// The DType of elements in this tensor.
///
public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle);
-
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle);
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize;
- public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
+ private IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public int NDims => rank;
///
/// The name of the device on which this tensor will be produced, or null.
///
public string Device => op.Device;
-
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public int[] dims => shape;
///
/// Used for keep other pointer when do implicit operating
///
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public object Tag { get; set; }
@@ -139,6 +166,9 @@ namespace Tensorflow
return rank < 0 ? null : shape;
}
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape);
///
@@ -479,9 +509,11 @@ namespace Tensorflow
} else
throw new InvalidOperationException($"Tensor.AllocationHandle is not null ({AllocationHandle}) but AllocationType is not matched to a C# allocation type ({AllocationType}).");
}
-
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public bool IsDisposed => _disposed;
- public int tensor_int_val { get; set; }
+ // public int tensor_int_val { get; set; }
}
}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
index 1e239d50..80bb31c1 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
@@ -1,4 +1,5 @@
-using NumSharp;
+using Newtonsoft.Json;
+using NumSharp;
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
@@ -35,6 +36,9 @@ namespace Tensorflow
///
/// Returns the size this shape represents.
///
+#if SERIALIZABLE
+ [JsonIgnore]
+#endif
public int size
{
get
diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs
index 3827229d..2a2a9bfa 100644
--- a/src/TensorFlowNET.Core/Tensors/dtypes.cs
+++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs
@@ -33,6 +33,7 @@ namespace Tensorflow
public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32?
public static TF_DataType float16 = TF_DataType.TF_HALF;
public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
+ public static TF_DataType variant = TF_DataType.TF_VARIANT;
public static TF_DataType resource = TF_DataType.TF_RESOURCE;
///
diff --git a/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs b/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs
index 2d4effca..cc3527c2 100644
--- a/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs
+++ b/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs
@@ -46,7 +46,7 @@ namespace Tensorflow.Train
value,
name,
colocate_with_primary: true);
- ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var);
+ ops.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var);
_averages[var] = avg;
}
else
diff --git a/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs b/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs
index 2d472f5a..1a2821bb 100644
--- a/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs
+++ b/src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs
@@ -39,6 +39,12 @@ namespace Tensorflow.Train
: base(learning_rate, use_locking, name)
{
_lr = learning_rate;
+ }
+
+ public GradientDescentOptimizer(Tensor learning_rate, bool use_locking = false, string name = "GradientDescent")
+ : base(learning_rate, use_locking, name)
+ {
+ _lr_t = learning_rate;
}
public override void _prepare()
diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs
index 83774734..7b887e22 100644
--- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs
@@ -115,6 +115,7 @@ namespace Tensorflow
dtype: dtype);
});
_shape = shape ?? (initial_value as Tensor).TensorShape;
+ _initial_value = initial_value as Tensor;
_handle = resource_variable_ops.eager_safe_variable_handle(
initial_value: _initial_value,
shape: _shape,
@@ -122,7 +123,6 @@ namespace Tensorflow
name: name,
graph_mode: _in_graph_mode);
_unique_id = unique_id;
- _initial_value = initial_value as Tensor;
_handle_name = handle_name + ":0";
_dtype = _initial_value.dtype.as_base_dtype();
// _constraint = constraint;
@@ -133,6 +133,7 @@ namespace Tensorflow
{
_is_initialized_op = gen_resource_variable_ops.var_is_initialized_op(_handle);
});
+
if(initial_value != null)
{
tf_with(ops.name_scope("Assign"), scope1 =>
@@ -143,10 +144,25 @@ namespace Tensorflow
name: n);
});
}
+
+ // Manually assign reads to the handle's device to avoid log
+ // messages.
+ tf_with(ops.name_scope("Read"), delegate
+ {
+ var value = _read_variable_op();
+ _graph_element = value;
+ });
}
+
+ ops.add_to_collections(collections, this);
});
+ }
- throw new NotImplementedException("");
+ private Tensor _read_variable_op()
+ {
+ var result = gen_resource_variable_ops.read_variable_op(_handle, _dtype);
+ // _maybe_set_handle_data(_dtype, _handle, result);
+ return result;
}
private void _init_from_proto(VariableDef variable_def, string import_scope = null)
@@ -200,6 +216,18 @@ namespace Tensorflow
_dtype = dtypes.as_tf_dtype((DataType)_handle.op.get_attr("dtype"));
}
+ public Tensor sparse_read(Tensor indices, string name = "Gather")
+ {
+ return tf_with(ops.name_scope(name), scope =>
+ {
+ name = scope;
+ var value = gen_resource_variable_ops.resource_gather(
+ _handle, indices, dtype: _dtype, name: name);
+
+ return array_ops.identity(value);
+ });
+ }
+
public override string ToString()
{
return $"tf.ResourceVariable '{name}' shape={shape} dtype={dtype}";
diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs
index d898a4aa..0e056949 100644
--- a/src/TensorFlowNET.Core/Variables/variables.py.cs
+++ b/src/TensorFlowNET.Core/Variables/variables.py.cs
@@ -135,6 +135,28 @@ namespace Tensorflow
}
// If at least one input was modified, replace the op.
+ if(modified)
+ {
+ var new_op_type = op_type;
+ if (new_op_type == "RefSwitch")
+ new_op_type = "Switch";
+ var new_op_name = op.node_def.Name + "_" + name;
+ new_op_name = new_op_name.Replace(":", "_");
+ var _output_types = op._output_types;
+
+ // Convert attr values to AttrValue protos.
+ var attr_protos = new Dictionary();
+ foreach (var attr_def in op.node_def.Attr)
+ attr_protos[attr_def.Key] = attr_def.Value;
+
+ return op.graph.create_op(
+ new_op_type,
+ new_op_inputs.ToArray(),
+ _output_types,
+ name: new_op_name,
+ attrs: attr_protos);
+ }
+
return op;
}
diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs
index 4e7235bc..f4b4b77f 100644
--- a/src/TensorFlowNET.Core/ops.GraphKeys.cs
+++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs
@@ -30,8 +30,10 @@ namespace Tensorflow
public class GraphKeys
{
#region const
-
-
+ ///
+ /// Key to collect concatenated sharded variables.
+ ///
+ public const string CONCATENATED_VARIABLES_ = "concatenated_variables";
///
/// the subset of `Variable` objects that will be trained by an optimizer.
///
@@ -52,7 +54,12 @@ namespace Tensorflow
///
public const string LOSSES_ = "losses";
- public const string MOVING_AVERAGE_VARIABLES = "moving_average_variables";
+ public const string LOCAL_VARIABLES_ = "local_variables";
+
+ public const string METRIC_VARIABLES_ = "metric_variables";
+ public const string MODEL_VARIABLES_ = "model_variables";
+
+ public const string MOVING_AVERAGE_VARIABLES_ = "moving_average_variables";
///
/// Key to collect Variable objects that are global (shared across machines).
@@ -64,7 +71,21 @@ namespace Tensorflow
public const string GLOBAL_STEP_ = "global_step";
- public string[] _VARIABLE_COLLECTIONS_ = new string[] { "variables", "trainable_variables", "model_variables" };
+ ///
+ /// List of all collections that keep track of variables.
+ ///
+ public string[] _VARIABLE_COLLECTIONS_ = new string[]
+ {
+ GLOBAL_VARIABLES_,
+ LOCAL_VARIABLES_,
+ METRIC_VARIABLES_,
+ MODEL_VARIABLES_,
+ TRAINABLE_VARIABLES_,
+ MOVING_AVERAGE_VARIABLES_,
+ CONCATENATED_VARIABLES_,
+ TRAINABLE_RESOURCE_VARIABLES_
+ };
+
///
/// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
///
@@ -86,7 +107,8 @@ namespace Tensorflow
#endregion
-
+
+ public string CONCATENATED_VARIABLES => CONCATENATED_VARIABLES_;
///
/// the subset of `Variable` objects that will be trained by an optimizer.
///
@@ -106,13 +128,16 @@ namespace Tensorflow
/// Key to collect local variables that are local to the machine and are not
/// saved/restored.
///
- public string LOCAL_VARIABLES = "local_variables";
+ public string LOCAL_VARIABLES = LOCAL_VARIABLES_;
///
/// Key to collect losses
///
public string LOSSES => LOSSES_;
+ public string METRIC_VARIABLES => METRIC_VARIABLES_;
+ public string MOVING_AVERAGE_VARIABLES = MOVING_AVERAGE_VARIABLES_;
+
///
/// Key to collect Variable objects that are global (shared across machines).
/// Default collection for all variables, except local ones.
diff --git a/src/TensorFlowNET.Hub/TensorFlowNET.Hub.csproj b/src/TensorFlowNET.Hub/TensorFlowNET.Hub.csproj
index 10d27a5c..ef3c7e18 100644
--- a/src/TensorFlowNET.Hub/TensorFlowNET.Hub.csproj
+++ b/src/TensorFlowNET.Hub/TensorFlowNET.Hub.csproj
@@ -18,6 +18,6 @@
TensorFlow.Hub
-
+
\ No newline at end of file
diff --git a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs
index f0a79ed6..f4f3f141 100644
--- a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs
+++ b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs
@@ -82,8 +82,7 @@ namespace TensorFlowNET.UnitTest
var sess_graph = sess.GetPrivate("_graph");
sess_graph.Should().NotBeNull();
default_graph.Should().NotBeNull()
- .And.BeEquivalentTo(sess_graph)
- .And.BeEquivalentTo(beforehand);
+ .And.BeEquivalentTo(sess_graph);
Console.WriteLine($"{tid}-{default_graph.graph_key}");