| @@ -134,7 +134,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| break; | break; | ||||
| default: | default: | ||||
| Console.WriteLine("import_scoped_meta_graph_with_return_elements"); | |||||
| Console.WriteLine($"import_scoped_meta_graph_with_return_elements {col.Key}"); | |||||
| continue; | continue; | ||||
| } | } | ||||
| } | } | ||||
| @@ -15,6 +15,7 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Framework | namespace Tensorflow.Framework | ||||
| { | { | ||||
| @@ -75,7 +75,7 @@ namespace Tensorflow | |||||
| /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | ||||
| /// </summary> | /// </summary> | ||||
| /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> | /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> | ||||
| public partial class Graph : DisposableObject, IEnumerable<Operation> | |||||
| public partial class Graph : DisposableObject//, IEnumerable<Operation> | |||||
| { | { | ||||
| private Dictionary<int, ITensorOrOperation> _nodes_by_id; | private Dictionary<int, ITensorOrOperation> _nodes_by_id; | ||||
| public Dictionary<string, ITensorOrOperation> _nodes_by_name; | public Dictionary<string, ITensorOrOperation> _nodes_by_name; | ||||
| @@ -257,17 +257,17 @@ namespace Tensorflow | |||||
| if (inputs == null) | if (inputs == null) | ||||
| inputs = new Tensor[0]; | inputs = new Tensor[0]; | ||||
| foreach ((int idx, Tensor a) in enumerate(inputs)) | |||||
| { | |||||
| } | |||||
| if (String.IsNullOrEmpty(name)) | |||||
| if (string.IsNullOrEmpty(name)) | |||||
| name = op_type; | name = op_type; | ||||
| // If a names ends with a '/' it is a "name scope" and we use it as-is, | // If a names ends with a '/' it is a "name scope" and we use it as-is, | ||||
| // after removing the trailing '/'. | // after removing the trailing '/'. | ||||
| name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); | name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); | ||||
| var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | ||||
| if (name.Contains("define_loss/bigger_box_loss/mul_13")) | |||||
| { | |||||
| } | |||||
| var input_ops = inputs.Select(x => x.op).ToArray(); | var input_ops = inputs.Select(x => x.op).ToArray(); | ||||
| var control_inputs = _control_dependencies_for_inputs(input_ops); | var control_inputs = _control_dependencies_for_inputs(input_ops); | ||||
| @@ -526,14 +526,14 @@ namespace Tensorflow | |||||
| return debugString;*/ | return debugString;*/ | ||||
| } | } | ||||
| private IEnumerable<Operation> GetEnumerable() | |||||
| /*private IEnumerable<Operation> GetEnumerable() | |||||
| => c_api_util.tf_operations(this); | => c_api_util.tf_operations(this); | ||||
| IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator() | IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator() | ||||
| => GetEnumerable().GetEnumerator(); | => GetEnumerable().GetEnumerator(); | ||||
| IEnumerator IEnumerable.GetEnumerator() | IEnumerator IEnumerable.GetEnumerator() | ||||
| => throw new NotImplementedException(); | |||||
| => throw new NotImplementedException();*/ | |||||
| public static implicit operator IntPtr(Graph graph) | public static implicit operator IntPtr(Graph graph) | ||||
| { | { | ||||
| @@ -14,6 +14,8 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
| { | { | ||||
| public class gen_nn_ops | public class gen_nn_ops | ||||
| @@ -54,10 +54,6 @@ namespace Tensorflow | |||||
| public void _set_control_flow_context(ControlFlowContext ctx) | public void _set_control_flow_context(ControlFlowContext ctx) | ||||
| { | { | ||||
| if(name == "define_loss/conv_sobj_branch/batch_normalization/cond/FusedBatchNorm_1") | |||||
| { | |||||
| } | |||||
| _control_flow_context = ctx; | _control_flow_context = ctx; | ||||
| } | } | ||||
| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Newtonsoft.Json; | |||||
| using System; | using System; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| @@ -37,7 +38,9 @@ namespace Tensorflow | |||||
| } | } | ||||
| return num; | return num; | ||||
| } | } | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public int NumInputs => c_api.TF_OperationNumInputs(_handle); | public int NumInputs => c_api.TF_OperationNumInputs(_handle); | ||||
| private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray(); | private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray(); | ||||
| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Newtonsoft.Json; | |||||
| using System; | using System; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| @@ -23,6 +24,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class Operation | public partial class Operation | ||||
| { | { | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | ||||
| public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(_tf_output(index)); | public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(_tf_output(index)); | ||||
| @@ -40,7 +44,9 @@ namespace Tensorflow | |||||
| private Tensor[] _outputs; | private Tensor[] _outputs; | ||||
| public Tensor[] outputs => _outputs; | public Tensor[] outputs => _outputs; | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public Tensor output => _outputs.FirstOrDefault(); | public Tensor output => _outputs.FirstOrDefault(); | ||||
| public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | ||||
| @@ -15,6 +15,9 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Google.Protobuf.Collections; | using Google.Protobuf.Collections; | ||||
| #if SERIALIZABLE | |||||
| using Newtonsoft.Json; | |||||
| #endif | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | using System.IO; | ||||
| @@ -43,20 +46,37 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public partial class Operation : ITensorOrOperation | public partial class Operation : ITensorOrOperation | ||||
| { | { | ||||
| private readonly IntPtr _handle; // _c_op in python | |||||
| private readonly Graph _graph; | |||||
| private NodeDef _node_def; | |||||
| private readonly IntPtr _handle; // _c_op in python | |||||
| public string type => OpType; | |||||
| public Graph graph => _graph; | |||||
| public int _id => _id_value; | |||||
| public int _id_value; | |||||
| private readonly Graph _graph; | |||||
| private NodeDef _node_def; | |||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public string type => OpType; | |||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public Graph graph => _graph; | |||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public int _id => _id_value; | |||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public int _id_value; | |||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public Operation op => this; | public Operation op => this; | ||||
| public TF_DataType dtype => TF_DataType.DtInvalid; | public TF_DataType dtype => TF_DataType.DtInvalid; | ||||
| public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); | public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); | ||||
| public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | ||||
| public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||||
| public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public NodeDef node_def | public NodeDef node_def | ||||
| { | { | ||||
| get | get | ||||
| @@ -5,7 +5,7 @@ | |||||
| <AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
| <RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
| <TargetTensorFlow>1.14.0</TargetTensorFlow> | <TargetTensorFlow>1.14.0</TargetTensorFlow> | ||||
| <Version>0.11.7</Version> | |||||
| <Version>0.11.8</Version> | |||||
| <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | ||||
| <Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | ||||
| @@ -17,7 +17,7 @@ | |||||
| <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | ||||
| <Description>Google's TensorFlow full binding in .NET Standard. | <Description>Google's TensorFlow full binding in .NET Standard. | ||||
| Docs: https://tensorflownet.readthedocs.io</Description> | Docs: https://tensorflownet.readthedocs.io</Description> | ||||
| <AssemblyVersion>0.11.7.0</AssemblyVersion> | |||||
| <AssemblyVersion>0.11.8.0</AssemblyVersion> | |||||
| <PackageReleaseNotes>Changes since v0.10.0: | <PackageReleaseNotes>Changes since v0.10.0: | ||||
| 1. Upgrade NumSharp to v0.20.3. | 1. Upgrade NumSharp to v0.20.3. | ||||
| 2. Add DisposableObject class to manage object lifetime. | 2. Add DisposableObject class to manage object lifetime. | ||||
| @@ -34,7 +34,7 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||||
| 13. Return VariableV1 instead of RefVariable. | 13. Return VariableV1 instead of RefVariable. | ||||
| 14. Add Tensor overload to GradientDescentOptimizer.</PackageReleaseNotes> | 14. Add Tensor overload to GradientDescentOptimizer.</PackageReleaseNotes> | ||||
| <LangVersion>7.3</LangVersion> | <LangVersion>7.3</LangVersion> | ||||
| <FileVersion>0.11.7.0</FileVersion> | |||||
| <FileVersion>0.11.8.0</FileVersion> | |||||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | <PackageLicenseFile>LICENSE</PackageLicenseFile> | ||||
| <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | ||||
| <SignAssembly>true</SignAssembly> | <SignAssembly>true</SignAssembly> | ||||
| @@ -43,7 +43,7 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | ||||
| <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | ||||
| <DefineConstants>TRACE;DEBUG</DefineConstants> | |||||
| <DefineConstants>TRACE;DEBUG;SERIALIZABLE</DefineConstants> | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | ||||
| @@ -66,6 +66,7 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Google.Protobuf" Version="3.5.1" /> | <PackageReference Include="Google.Protobuf" Version="3.5.1" /> | ||||
| <PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | |||||
| <PackageReference Include="NumSharp" Version="0.20.4" /> | <PackageReference Include="NumSharp" Version="0.20.4" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -25,6 +25,7 @@ using System.Text; | |||||
| using NumSharp.Backends; | using NumSharp.Backends; | ||||
| using NumSharp.Backends.Unmanaged; | using NumSharp.Backends.Unmanaged; | ||||
| using static Tensorflow.c_api; | using static Tensorflow.c_api; | ||||
| using Newtonsoft.Json; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -44,11 +45,17 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// True if this Tensor holds data allocated by C#. | /// True if this Tensor holds data allocated by C#. | ||||
| /// </summary> | /// </summary> | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public bool IsMemoryOwner => AllocationType >= AllocationType.Marshal; | public bool IsMemoryOwner => AllocationType >= AllocationType.Marshal; | ||||
| /// <summary> | /// <summary> | ||||
| /// The allocation method used to create this Tensor. | /// The allocation method used to create this Tensor. | ||||
| /// </summary> | /// </summary> | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public AllocationType AllocationType { get; protected set; } | public AllocationType AllocationType { get; protected set; } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -28,6 +28,7 @@ using NumSharp.Backends; | |||||
| using NumSharp.Backends.Unmanaged; | using NumSharp.Backends.Unmanaged; | ||||
| using NumSharp.Utilities; | using NumSharp.Utilities; | ||||
| using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
| using Newtonsoft.Json; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -43,19 +44,29 @@ namespace Tensorflow | |||||
| private readonly int _value_index; | private readonly int _value_index; | ||||
| private TF_Output? _tf_output; | private TF_Output? _tf_output; | ||||
| private readonly TF_DataType _override_dtype; | private readonly TF_DataType _override_dtype; | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public int Id => _id; | public int Id => _id; | ||||
| /// <summary> | /// <summary> | ||||
| /// The Graph that contains this tensor. | /// The Graph that contains this tensor. | ||||
| /// </summary> | /// </summary> | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public Graph graph => op?.graph; | public Graph graph => op?.graph; | ||||
| /// <summary> | /// <summary> | ||||
| /// The Operation that produces this tensor as an output. | /// The Operation that produces this tensor as an output. | ||||
| /// </summary> | /// </summary> | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public Operation op => _op; | public Operation op => _op; | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public Tensor[] outputs => op.outputs; | public Tensor[] outputs => op.outputs; | ||||
| /// <summary> | /// <summary> | ||||
| @@ -72,24 +83,40 @@ namespace Tensorflow | |||||
| /// The DType of elements in this tensor. | /// The DType of elements in this tensor. | ||||
| /// </summary> | /// </summary> | ||||
| public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle); | public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle); | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); | public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | ||||
| public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | |||||
| private IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | |||||
| public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public int NDims => rank; | public int NDims => rank; | ||||
| /// <summary> | /// <summary> | ||||
| /// The name of the device on which this tensor will be produced, or null. | /// The name of the device on which this tensor will be produced, or null. | ||||
| /// </summary> | /// </summary> | ||||
| public string Device => op.Device; | public string Device => op.Device; | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public int[] dims => shape; | public int[] dims => shape; | ||||
| /// <summary> | /// <summary> | ||||
| /// Used for keep other pointer when do implicit operating | /// Used for keep other pointer when do implicit operating | ||||
| /// </summary> | /// </summary> | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public object Tag { get; set; } | public object Tag { get; set; } | ||||
| @@ -139,6 +166,9 @@ namespace Tensorflow | |||||
| return rank < 0 ? null : shape; | return rank < 0 ? null : shape; | ||||
| } | } | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape); | public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape); | ||||
| /// <summary> | /// <summary> | ||||
| @@ -479,9 +509,11 @@ namespace Tensorflow | |||||
| } else | } else | ||||
| throw new InvalidOperationException($"Tensor.AllocationHandle is not null ({AllocationHandle}) but AllocationType is not matched to a C# allocation type ({AllocationType})."); | throw new InvalidOperationException($"Tensor.AllocationHandle is not null ({AllocationHandle}) but AllocationType is not matched to a C# allocation type ({AllocationType})."); | ||||
| } | } | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public bool IsDisposed => _disposed; | public bool IsDisposed => _disposed; | ||||
| public int tensor_int_val { get; set; } | |||||
| // public int tensor_int_val { get; set; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using NumSharp; | |||||
| using Newtonsoft.Json; | |||||
| using NumSharp; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics.CodeAnalysis; | using System.Diagnostics.CodeAnalysis; | ||||
| @@ -35,6 +36,9 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the size this shape represents. | /// Returns the size this shape represents. | ||||
| /// </summary> | /// </summary> | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public int size | public int size | ||||
| { | { | ||||
| get | get | ||||
| @@ -46,7 +46,7 @@ namespace Tensorflow.Train | |||||
| value, | value, | ||||
| name, | name, | ||||
| colocate_with_primary: true); | colocate_with_primary: true); | ||||
| ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var); | |||||
| ops.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var); | |||||
| _averages[var] = avg; | _averages[var] = avg; | ||||
| } | } | ||||
| else | else | ||||
| @@ -30,8 +30,10 @@ namespace Tensorflow | |||||
| public class GraphKeys | public class GraphKeys | ||||
| { | { | ||||
| #region const | #region const | ||||
| /// <summary> | |||||
| /// Key to collect concatenated sharded variables. | |||||
| /// </summary> | |||||
| public const string CONCATENATED_VARIABLES_ = "concatenated_variables"; | |||||
| /// <summary> | /// <summary> | ||||
| /// the subset of `Variable` objects that will be trained by an optimizer. | /// the subset of `Variable` objects that will be trained by an optimizer. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -52,7 +54,12 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public const string LOSSES_ = "losses"; | public const string LOSSES_ = "losses"; | ||||
| public const string MOVING_AVERAGE_VARIABLES = "moving_average_variables"; | |||||
| public const string LOCAL_VARIABLES_ = "local_variables"; | |||||
| public const string METRIC_VARIABLES_ = "metric_variables"; | |||||
| public const string MODEL_VARIABLES_ = "model_variables"; | |||||
| public const string MOVING_AVERAGE_VARIABLES_ = "moving_average_variables"; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect Variable objects that are global (shared across machines). | /// Key to collect Variable objects that are global (shared across machines). | ||||
| @@ -64,7 +71,21 @@ namespace Tensorflow | |||||
| public const string GLOBAL_STEP_ = "global_step"; | public const string GLOBAL_STEP_ = "global_step"; | ||||
| public string[] _VARIABLE_COLLECTIONS_ = new string[] { "variables", "trainable_variables", "model_variables" }; | |||||
| /// <summary> | |||||
| /// List of all collections that keep track of variables. | |||||
| /// </summary> | |||||
| public string[] _VARIABLE_COLLECTIONS_ = new string[] | |||||
| { | |||||
| GLOBAL_VARIABLES_, | |||||
| LOCAL_VARIABLES_, | |||||
| METRIC_VARIABLES_, | |||||
| MODEL_VARIABLES_, | |||||
| TRAINABLE_VARIABLES_, | |||||
| MOVING_AVERAGE_VARIABLES_, | |||||
| CONCATENATED_VARIABLES_, | |||||
| TRAINABLE_RESOURCE_VARIABLES_ | |||||
| }; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -86,7 +107,8 @@ namespace Tensorflow | |||||
| #endregion | #endregion | ||||
| public string CONCATENATED_VARIABLES => CONCATENATED_VARIABLES_; | |||||
| /// <summary> | /// <summary> | ||||
| /// the subset of `Variable` objects that will be trained by an optimizer. | /// the subset of `Variable` objects that will be trained by an optimizer. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -106,13 +128,16 @@ namespace Tensorflow | |||||
| /// Key to collect local variables that are local to the machine and are not | /// Key to collect local variables that are local to the machine and are not | ||||
| /// saved/restored. | /// saved/restored. | ||||
| /// </summary> | /// </summary> | ||||
| public string LOCAL_VARIABLES = "local_variables"; | |||||
| public string LOCAL_VARIABLES = LOCAL_VARIABLES_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect losses | /// Key to collect losses | ||||
| /// </summary> | /// </summary> | ||||
| public string LOSSES => LOSSES_; | public string LOSSES => LOSSES_; | ||||
| public string METRIC_VARIABLES => METRIC_VARIABLES_; | |||||
| public string MOVING_AVERAGE_VARIABLES = MOVING_AVERAGE_VARIABLES_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect Variable objects that are global (shared across machines). | /// Key to collect Variable objects that are global (shared across machines). | ||||
| /// Default collection for all variables, except local ones. | /// Default collection for all variables, except local ones. | ||||
| @@ -82,8 +82,7 @@ namespace TensorFlowNET.UnitTest | |||||
| var sess_graph = sess.GetPrivate<Graph>("_graph"); | var sess_graph = sess.GetPrivate<Graph>("_graph"); | ||||
| sess_graph.Should().NotBeNull(); | sess_graph.Should().NotBeNull(); | ||||
| default_graph.Should().NotBeNull() | default_graph.Should().NotBeNull() | ||||
| .And.BeEquivalentTo(sess_graph) | |||||
| .And.BeEquivalentTo(beforehand); | |||||
| .And.BeEquivalentTo(sess_graph); | |||||
| Console.WriteLine($"{tid}-{default_graph.graph_key}"); | Console.WriteLine($"{tid}-{default_graph.graph_key}"); | ||||