Partially Support the function loadingtags/v0.100.5-BERT-load
| @@ -1,7 +1,7 @@ | |||
| | |||
| Microsoft Visual Studio Solution File, Format Version 12.00 | |||
| # Visual Studio Version 16 | |||
| VisualStudioVersion = 16.0.31624.102 | |||
| # Visual Studio Version 17 | |||
| VisualStudioVersion = 17.4.33213.308 | |||
| MinimumVisualStudioVersion = 10.0.40219.1 | |||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" | |||
| EndProject | |||
| @@ -23,6 +23,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", | |||
| EndProject | |||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}" | |||
| EndProject | |||
| Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tensorflow.Common", "Tensorflow.Common\Tensorflow.Common.csproj", "{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}" | |||
| EndProject | |||
| Global | |||
| GlobalSection(SolutionConfigurationPlatforms) = preSolution | |||
| Debug|Any CPU = Debug|Any CPU | |||
| @@ -153,6 +155,18 @@ Global | |||
| {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64 | |||
| {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU | |||
| {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x64.Build.0 = Debug|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x86.ActiveCfg = Debug|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x86.Build.0 = Debug|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|Any CPU.Build.0 = Release|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x64.ActiveCfg = Release|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x64.Build.0 = Release|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x86.ActiveCfg = Release|Any CPU | |||
| {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x86.Build.0 = Release|Any CPU | |||
| EndGlobalSection | |||
| GlobalSection(SolutionProperties) = preSolution | |||
| HideSolutionNode = FALSE | |||
| @@ -0,0 +1,31 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Extensions | |||
| { | |||
| public static class DictionaryExtension | |||
| { | |||
| public static void Deconstruct<T1, T2>(this KeyValuePair<T1, T2> pair, out T1 first, out T2 second) | |||
| { | |||
| first = pair.Key; | |||
| second = pair.Value; | |||
| } | |||
| public static void Update<T1, T2>(this Dictionary<T1, T2> dic, IDictionary<T1, T2> other) | |||
| { | |||
| foreach(var (key, value) in other) | |||
| { | |||
| dic[key] = value; | |||
| } | |||
| } | |||
| public static T2 GetOrDefault<T1, T2>(this Dictionary<T1, T2> dic, T1 key, T2 defaultValue) | |||
| { | |||
| if (dic.ContainsKey(key)) | |||
| { | |||
| return dic[key]; | |||
| } | |||
| return defaultValue; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,13 @@ | |||
| using OneOf; | |||
| using System; | |||
| namespace Tensorflow.Common.Extensions | |||
| { | |||
| public static class OneofExtension | |||
| { | |||
| public static bool IsTypeOrDeriveFrom<T>(this IOneOf src) | |||
| { | |||
| return src.Value is T; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,11 @@ | |||
| <Project Sdk="Microsoft.NET.Sdk"> | |||
| <PropertyGroup> | |||
| <TargetFramework>netstandard2.0</TargetFramework> | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="OneOf" Version="3.0.223" /> | |||
| </ItemGroup> | |||
| </Project> | |||
| @@ -0,0 +1,13 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public class NamedTuple | |||
| { | |||
| public string Name { get; set; } | |||
| public Dictionary<string, object> ValueDict { get; set; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,17 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class c_api | |||
| { | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); | |||
| } | |||
| } | |||
| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Google.Protobuf; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| @@ -45,6 +46,23 @@ namespace Tensorflow | |||
| { | |||
| return as_text(bytes_or_text, encoding); | |||
| } | |||
| public ByteString as_bytes(ByteString bytes, Encoding encoding = null) | |||
| { | |||
| return bytes; | |||
| } | |||
| public ByteString as_bytes(byte[] bytes, Encoding encoding = null) | |||
| { | |||
| return ByteString.CopyFrom(bytes); | |||
| } | |||
| public ByteString as_bytes(string text, Encoding encoding = null) | |||
| { | |||
| if(encoding is null) | |||
| { | |||
| encoding = Encoding.UTF8; | |||
| } | |||
| return ByteString.CopyFrom(encoding.GetBytes(text)); | |||
| } | |||
| } | |||
| public bool executing_eagerly() | |||
| @@ -54,6 +54,6 @@ namespace Tensorflow | |||
| Dictionary<string, Tensor> input_map = null, | |||
| string[] return_elements = null, | |||
| string name = null, | |||
| OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name, producer_op_list); | |||
| OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name: name, producer_op_list: producer_op_list); | |||
| } | |||
| } | |||
| @@ -14,6 +14,8 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class tensorflow | |||
| @@ -79,5 +81,10 @@ namespace Tensorflow | |||
| num_split: num_split, | |||
| axis: axis, | |||
| name: name); | |||
| public Tensor ensure_shape(Tensor x, Shape shape, string name = null) | |||
| { | |||
| return gen_ops.ensure_shape(x, shape, name); | |||
| } | |||
| } | |||
| } | |||
| @@ -61,7 +61,7 @@ namespace Tensorflow | |||
| public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); | |||
| public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, ulong proto_len, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Set `num_dims` to -1 to represent "unknown rank". | |||
| @@ -22,6 +22,7 @@ using System.ComponentModel; | |||
| using System.Diagnostics; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -107,6 +107,12 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| public void Release() | |||
| { | |||
| _handle.Dispose(); | |||
| _handle = null; | |||
| } | |||
| public override string ToString() | |||
| => $"0x{_handle.DangerousGetHandle():x16}"; | |||
| @@ -25,5 +25,32 @@ namespace Tensorflow | |||
| public IntPtr data; | |||
| public ulong length; | |||
| public IntPtr data_deallocator; | |||
| public unsafe Span<T> AsSpan<T>() where T: unmanaged | |||
| { | |||
| if(length > int.MaxValue) | |||
| { | |||
| throw new ValueError($"The length {length} is too large to use in the span."); | |||
| } | |||
| return new Span<T>(data.ToPointer(), (int)length); | |||
| } | |||
| public unsafe byte[] ToByteArray() | |||
| { | |||
| byte[] res = new byte[length]; | |||
| if(length > int.MaxValue) | |||
| { | |||
| byte* root = (byte*)data; | |||
| for(ulong i = 0; i < length; i++) | |||
| { | |||
| res[i] = *(root++); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| new Span<byte>(data.ToPointer(), (int)length).CopyTo(res.AsSpan()); | |||
| } | |||
| return res; | |||
| } | |||
| } | |||
| } | |||
| @@ -161,7 +161,7 @@ public static class CheckPointUtils | |||
| internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list) | |||
| { | |||
| return full_list.TakeWhile(x => | |||
| return full_list.Where(x => | |||
| { | |||
| var saveables = x.gather_saveables_for_checkpoint(); | |||
| return saveables is not null && saveables.Count > 0; | |||
| @@ -1,10 +1,12 @@ | |||
| using System; | |||
| using OneOf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using Tensorflow.Common.Extensions; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| namespace Tensorflow.Checkpoint | |||
| @@ -28,7 +30,7 @@ namespace Tensorflow.Checkpoint | |||
| ); | |||
| public static class SaveUtil | |||
| { | |||
| public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
| public static (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
| serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) | |||
| { | |||
| var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); | |||
| @@ -104,7 +106,10 @@ namespace Tensorflow.Checkpoint | |||
| { | |||
| var td = trackable_data[i]; | |||
| Debug.Assert(td.node_id == i); | |||
| object_graph_proto.Nodes.Add(new TrackableObjectGraph.Types.TrackableObject(td.slot_variable_proto, td.children_proto)); | |||
| TrackableObjectGraph.Types.TrackableObject trackable_object = new(); | |||
| trackable_object.SlotVariables.AddRange(td.slot_variable_proto); | |||
| trackable_object.Children.AddRange(td.children_proto); | |||
| object_graph_proto.Nodes.Add(trackable_object); | |||
| } | |||
| return object_graph_proto; | |||
| } | |||
| @@ -117,16 +122,16 @@ namespace Tensorflow.Checkpoint | |||
| /// <param name="call_with_mapped_captures"></param> | |||
| /// <param name="cache"></param> | |||
| /// <param name="object_graph_proto"></param> | |||
| private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||
| private static IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||
| bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) | |||
| { | |||
| Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||
| Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new(); | |||
| foreach(var td in tensor_trackables) | |||
| { | |||
| // TODO: deal with cache. | |||
| var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; | |||
| Trackable trackable = null; | |||
| IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict; | |||
| IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict; | |||
| if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) | |||
| { | |||
| (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); | |||
| @@ -148,12 +153,12 @@ namespace Tensorflow.Checkpoint | |||
| return serialized_tensors; | |||
| } | |||
| private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
| private static IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
| { | |||
| var trackable = trackable_data.object_to_save; | |||
| // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. | |||
| IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict; | |||
| IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> ret_tensor_dict; | |||
| if (call_with_mapped_captures) | |||
| { | |||
| throw new NotImplementedException(); | |||
| @@ -163,8 +168,7 @@ namespace Tensorflow.Checkpoint | |||
| ret_tensor_dict = trackable.serialize_to_tensors(); | |||
| } | |||
| // TODO: deal with the type `SaveSpce` (currently it will never be it). | |||
| Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||
| Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict = new(); | |||
| foreach(var pair in ret_tensor_dict) | |||
| { | |||
| var local_name = TrackableUtils.escape_local_name(pair.Key); | |||
| @@ -173,10 +177,12 @@ namespace Tensorflow.Checkpoint | |||
| tensor_dict[checkpoint_key] = maybe_tensor; | |||
| if(maybe_tensor.IsTypeOrDeriveFrom<SaveSpec>()) | |||
| foreach(var key in maybe_tensor.Keys) | |||
| { | |||
| throw new NotImplementedException(); | |||
| //((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; | |||
| if (maybe_tensor[key].IsTypeOrDeriveFrom<SaveSpec>()) | |||
| { | |||
| maybe_tensor[key].AsT1.name = local_name + maybe_tensor[key].AsT1.name; | |||
| } | |||
| } | |||
| if(object_graph_proto is not null) | |||
| @@ -200,7 +206,7 @@ namespace Tensorflow.Checkpoint | |||
| /// <param name="call_with_mapped_captures"></param> | |||
| /// <param name="object_graph_proto"></param> | |||
| /// <returns></returns> | |||
| private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||
| private static (Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||
| bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||
| { | |||
| Dictionary<Trackable, string> object_names = new(); | |||
| @@ -8,6 +8,7 @@ using Tensorflow.Training; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| using static Tensorflow.Binding; | |||
| using Google.Protobuf; | |||
| using OneOf; | |||
| namespace Tensorflow.Checkpoint; | |||
| @@ -114,14 +115,10 @@ public static class SaveUtilV1 | |||
| { | |||
| var trackable = trackable_objects[i]; | |||
| Debug.Assert(node_ids[trackable] == i); | |||
| TrackableObjectGraph.Types.TrackableObject object_proto; | |||
| var object_proto = new TrackableObjectGraph.Types.TrackableObject(); | |||
| if (slot_variables.TryGetValue(trackable, out var slots)) | |||
| { | |||
| object_proto = new TrackableObjectGraph.Types.TrackableObject(slots); | |||
| } | |||
| else | |||
| { | |||
| object_proto = new TrackableObjectGraph.Types.TrackableObject(); | |||
| object_proto.SlotVariables.AddRange(slots); | |||
| } | |||
| object_graph_proto.Nodes.Add(object_proto); | |||
| foreach (var child in graph_view.list_children(trackable)) | |||
| @@ -184,13 +181,13 @@ public static class SaveUtilV1 | |||
| // TODO: tensorflow python has a process with callable `saveable_factory`. | |||
| List<MySaveableObject> saveables = new(); | |||
| if (maybe_saveable.TryGet<MySaveableObject>(out var s)) | |||
| if (maybe_saveable.TryPickT1(out var s, out var variable)) | |||
| { | |||
| saveables.Add(s); | |||
| } | |||
| else | |||
| { | |||
| saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValue<BaseResourceVariable>() as Trackable, key)); | |||
| saveables.AddRange(saveable_object_util.saveable_objects_for_op(variable as Trackable, key)); | |||
| } | |||
| foreach (var saveable in saveables) | |||
| @@ -222,7 +219,7 @@ public static class SaveUtilV1 | |||
| public record class CheckpointFactoryData | |||
| ( | |||
| Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory, | |||
| Func<string, OneOf<BaseResourceVariable, MySaveableObject>> factory, | |||
| string name, | |||
| string checkpoint_key | |||
| ); | |||
| @@ -12,6 +12,7 @@ using static Tensorflow.Binding; | |||
| using Tensorflow.Operations; | |||
| using Newtonsoft.Json; | |||
| using Tensorflow.Training; | |||
| using OneOf; | |||
| namespace Tensorflow.Checkpoint; | |||
| @@ -44,12 +45,12 @@ public class TrackableSaver | |||
| _graph_view = graph_view; | |||
| // TODO: cache when not executing eagerly. | |||
| // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, | |||
| // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder` | |||
| // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` | |||
| } | |||
| private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
| private (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
| gather_serialized_tensors(Tensor? object_graph_tensor = null) | |||
| { | |||
| var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); | |||
| @@ -70,9 +71,10 @@ public class TrackableSaver | |||
| Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||
| if (!serialized_tensors.ContainsKey(Trackable.None)) | |||
| { | |||
| serialized_tensors[Trackable.None] = new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>(); | |||
| serialized_tensors[Trackable.None] = new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>(); | |||
| } | |||
| serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; | |||
| serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = new Dictionary<string, OneOf<Tensor, SaveSpec>>(); | |||
| serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY].Add(saveable_object_util.NO_SLICE_SPEC_KEY, object_graph_tensor); | |||
| return (serialized_tensors, feed_additions, registered_savers, graph_proto); | |||
| } | |||
| @@ -392,6 +394,7 @@ public class CheckpointRestoreCoordinator | |||
| /// </summary> | |||
| public List<Trackable> AllTrackables => _all_trackables; | |||
| public HashSet<int> MatchedProtoIds => _matched_proto_ids; | |||
| // TODO(Rinne): change to weak ref. | |||
| public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id; | |||
| public int RestoreUid => _restore_uid; | |||
| public TrackableObjectGraph ObjectGraphProto => _object_graph_proto; | |||
| @@ -406,7 +409,7 @@ public class CheckpointRestoreCoordinator | |||
| // skip the callback. | |||
| } | |||
| public List<Operation> restore_saveables(Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null) | |||
| public List<Operation> restore_saveables(Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null) | |||
| { | |||
| List<Operation> restore_ops = new(); | |||
| foreach(var position in positions) | |||
| @@ -418,7 +421,7 @@ public class CheckpointRestoreCoordinator | |||
| Dictionary<string, BaseResourceVariable> variable_dict = new(); | |||
| foreach(var item in tensor_saveables) | |||
| { | |||
| if(item.Value.TryGet<BaseResourceVariable>(out var variable)) | |||
| if(item.Value.TryPickT0(out var variable, out var _)) | |||
| { | |||
| variable_dict[item.Key] = variable; | |||
| } | |||
| @@ -15,106 +15,14 @@ using Tensorflow.Graphs; | |||
| using System.Xml.Linq; | |||
| using System.Diagnostics; | |||
| using RestoreFunc = System.Func<object, object>; | |||
| using OneOf; | |||
| namespace Tensorflow.Checkpoint | |||
| { | |||
| public class Maybe<TA, TB> | |||
| { | |||
| private TA? _valueA = default(TA); | |||
| private TB? _valueB = default(TB); | |||
| private Type _type; | |||
| private bool _assignedTA; | |||
| public Maybe(TA value) | |||
| { | |||
| _valueA = value; | |||
| _type= typeof(TA); | |||
| _assignedTA = true; | |||
| } | |||
| public Maybe(TB value) | |||
| { | |||
| _valueB = value; | |||
| _type = typeof(TB); | |||
| _assignedTA = false; | |||
| } | |||
| public Type DataType => _type; | |||
| /// <summary> | |||
| /// Try to get the type T member of this instance. It returns true when TA or TB derive from T and is correspondingly assigned. | |||
| /// It returns | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <param name="res"></param> | |||
| /// <returns></returns> | |||
| public bool TryGet<T>(out T? res) | |||
| { | |||
| if(_valueA is T && _valueB is not T) | |||
| { | |||
| res = (T)(object)_valueA; | |||
| return _assignedTA; | |||
| } | |||
| else if(_valueA is not T && _valueB is T) | |||
| { | |||
| res = (T)(object)_valueB; | |||
| return !_assignedTA; | |||
| } | |||
| res = default(T); | |||
| return false; | |||
| } | |||
| public bool IsTypeOrDeriveFrom<T>() | |||
| { | |||
| if (_valueA is T && _valueB is not T) | |||
| { | |||
| return _assignedTA; | |||
| } | |||
| else if (_valueA is not T && _valueB is T) | |||
| { | |||
| return !_assignedTA; | |||
| } | |||
| else if (_valueA is T && _valueB is T) | |||
| { | |||
| return true; | |||
| } | |||
| else | |||
| { | |||
| return false; | |||
| } | |||
| } | |||
| public T GetValue<T>() | |||
| { | |||
| if (_valueA is T && _valueB is not T) | |||
| { | |||
| return (T)(object)_valueA; | |||
| } | |||
| else if (_valueA is not T && _valueB is T) | |||
| { | |||
| return (T)(object)_valueB; | |||
| } | |||
| else if (_valueA is T && _valueB is T) | |||
| { | |||
| throw new TypeError("The type is vague, this is always because TA and TB both derive from T."); | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError($"Expected {typeof(TA)} or {typeof(TB)}, but got typeof{typeof(T)}."); | |||
| } | |||
| } | |||
| public static implicit operator Maybe<TA, TB>(TA a) | |||
| { | |||
| return new Maybe<TA, TB>(a); | |||
| } | |||
| public static implicit operator Maybe<TA, TB>(TB b) | |||
| { | |||
| return new Maybe<TA, TB>(b); | |||
| } | |||
| } | |||
| internal class SingleDeviceSaver | |||
| { | |||
| private IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> _tensor_slice_dict; | |||
| public SingleDeviceSaver(IDictionary<string, IDictionary<string, Maybe<Tensor, SaveSpec>>> tensor_slice_dict) | |||
| private IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> _tensor_slice_dict; | |||
| public SingleDeviceSaver(IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_slice_dict) | |||
| { | |||
| _tensor_slice_dict = tensor_slice_dict; | |||
| } | |||
| @@ -122,15 +30,15 @@ namespace Tensorflow.Checkpoint | |||
| { | |||
| _tensor_slice_dict = tensor_slice_dict.ToDictionary( | |||
| x => x.Key, x => x.Value.ToDictionary( | |||
| y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value)) | |||
| as IDictionary<string, Maybe<Tensor, SaveSpec>>); | |||
| y => y.Key, y => OneOf<Tensor, SaveSpec>.FromT0(y.Value)) | |||
| as IDictionary<string, OneOf<Tensor, SaveSpec>>); | |||
| } | |||
| public SingleDeviceSaver(IDictionary<string, IDictionary<string, SaveSpec>> tensor_slice_dict) | |||
| { | |||
| _tensor_slice_dict = tensor_slice_dict.ToDictionary( | |||
| x => x.Key, x => x.Value.ToDictionary( | |||
| y => y.Key, y => new Maybe<Tensor, SaveSpec>(y.Value)) | |||
| as IDictionary<string, Maybe<Tensor, SaveSpec>>); | |||
| y => y.Key, y => OneOf<Tensor, SaveSpec>.FromT1(y.Value)) | |||
| as IDictionary<string, OneOf<Tensor, SaveSpec>>); | |||
| } | |||
| public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) | |||
| { | |||
| @@ -149,7 +57,7 @@ namespace Tensorflow.Checkpoint | |||
| { | |||
| var slice_spec = slice.Key; | |||
| var maybe_tensor = slice.Value; | |||
| if(maybe_tensor.TryGet<SaveSpec>(out var spec)) | |||
| if(maybe_tensor.TryPickT1(out var spec, out var tensor)) | |||
| { | |||
| var tensor_value = spec.tensor; | |||
| if (tensor_value is not null) | |||
| @@ -161,7 +69,6 @@ namespace Tensorflow.Checkpoint | |||
| } | |||
| else | |||
| { | |||
| var tensor = maybe_tensor.GetValue<Tensor>(); | |||
| tensor_names.Add(checkpoint_key); | |||
| tensors.Add(tensor); | |||
| slice_specs.Add(slice_spec); | |||
| @@ -193,7 +100,7 @@ namespace Tensorflow.Checkpoint | |||
| var slice_spec = slice.Key; | |||
| var maybe_tensor = slice.Value; | |||
| // TODO: deal with other types. Currently only `SaveSpec` is allowed. | |||
| if(maybe_tensor.TryGet<SaveSpec>(out var spec)) | |||
| if(maybe_tensor.TryPickT1(out var spec, out var tensor)) | |||
| { | |||
| tensor_dtypes.Add(spec.dtype); | |||
| slice_specs.Add(spec.slice_spec); | |||
| @@ -201,7 +108,6 @@ namespace Tensorflow.Checkpoint | |||
| } | |||
| else | |||
| { | |||
| var tensor = maybe_tensor.GetValue<Tensor>(); | |||
| tensor_dtypes.Add(tensor.dtype); | |||
| slice_specs.Add(slice_spec); | |||
| tensor_names.Add(checkpoint_key); | |||
| @@ -256,12 +162,12 @@ namespace Tensorflow.Checkpoint | |||
| /// <param name="serialized_tensors"> A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. </param> | |||
| /// <param name="registered_savers"></param> | |||
| /// <param name="call_with_mapped_capture"></param> | |||
| public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors, | |||
| public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors, | |||
| IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | |||
| { | |||
| _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); | |||
| _restore_fn_to_keys = new Dictionary<RestoreFunc, IList<(string, string)>>(); | |||
| Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new(); | |||
| Dictionary<string, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> tensors_by_device= new(); | |||
| foreach(var pair in serialized_tensors) | |||
| { | |||
| @@ -276,9 +182,9 @@ namespace Tensorflow.Checkpoint | |||
| { | |||
| restore_fn = new RestoreFunc(x => | |||
| { | |||
| if(x is IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) | |||
| if(x is IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>) | |||
| { | |||
| return obj._restore_from_tensors(x as IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>); | |||
| return obj._restore_from_tensors(x as IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>); | |||
| } | |||
| throw new TypeError($"Expected `IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>` as input, got{x.GetType()}."); | |||
| }); | |||
| @@ -287,16 +193,7 @@ namespace Tensorflow.Checkpoint | |||
| foreach(var item in tensor_dict) | |||
| { | |||
| var checkpoint_key = item.Key; | |||
| IDictionary<string, Tensor> spec_to_tensor; | |||
| if(item.Value.TryGet<Tensor>(out var t)) | |||
| { | |||
| spec_to_tensor = new Dictionary<string, Tensor>(); | |||
| spec_to_tensor[""] = t; | |||
| } | |||
| else | |||
| { | |||
| spec_to_tensor = item.Value.GetValue<IDictionary<string, Tensor>>(); | |||
| } | |||
| var spec_to_tensor = item.Value; | |||
| foreach(var spec in spec_to_tensor) | |||
| { | |||
| @@ -311,12 +208,20 @@ namespace Tensorflow.Checkpoint | |||
| _keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn; | |||
| _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); | |||
| // skip the process of device name because lack of API. | |||
| var host_device = tensor.Device; | |||
| var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, Tensor>>()); | |||
| string host_device; | |||
| if (tensor.IsT0) | |||
| { | |||
| host_device = tensor.AsT0.Device; | |||
| } | |||
| else | |||
| { | |||
| host_device = tensor.AsT1.device; | |||
| } | |||
| host_device = saveable_object_util.set_cpu0(host_device); | |||
| var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>()); | |||
| if (!internal_dict.ContainsKey(checkpoint_key)) | |||
| { | |||
| internal_dict[checkpoint_key] = new Dictionary<string, Tensor>(); | |||
| internal_dict[checkpoint_key] = new Dictionary<string, OneOf<Tensor, SaveSpec>>(); | |||
| } | |||
| internal_dict[checkpoint_key][slice_spec] = tensor; | |||
| } | |||
| @@ -412,7 +317,7 @@ namespace Tensorflow.Checkpoint | |||
| IDictionary<string, Operation> restore_func() | |||
| { | |||
| Dictionary<RestoreFunc, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new(); | |||
| Dictionary<RestoreFunc, IDictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new(); | |||
| Dictionary<RestoreFunc, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); | |||
| Dictionary<string, Operation> restore_ops = new(); | |||
| @@ -433,29 +338,29 @@ namespace Tensorflow.Checkpoint | |||
| var slice_spec = item.Key; | |||
| var tensor = item.Value; | |||
| var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | |||
| var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>()); | |||
| var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>()); | |||
| if (!string.IsNullOrEmpty(slice_spec)) | |||
| { | |||
| if (!internal_dict.ContainsKey(checkpoint_key)) | |||
| { | |||
| Dictionary<string, Tensor> dict = new(); | |||
| dict[slice_spec] = tensor; | |||
| internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict); | |||
| internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT1(dict); | |||
| } | |||
| else | |||
| { | |||
| internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor; | |||
| internal_dict[checkpoint_key].AsT1[slice_spec] = tensor; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor); | |||
| internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT0(tensor); | |||
| } | |||
| restore_fn_input_count[restore_fn]--; | |||
| if (restore_fn_input_count[restore_fn] == 0) | |||
| { | |||
| Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||
| Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||
| foreach (var input in restore_fn_inputs[restore_fn]) | |||
| { | |||
| restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | |||
| @@ -538,7 +443,7 @@ namespace Tensorflow.Checkpoint | |||
| public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) | |||
| { | |||
| Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||
| Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new(); | |||
| foreach (var saveable in saveables) | |||
| { | |||
| var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable }); | |||
| @@ -1,7 +1,9 @@ | |||
| using System; | |||
| using OneOf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Security; | |||
| using System.Text; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| @@ -49,7 +51,7 @@ public class CheckpointPosition | |||
| { | |||
| _checkpoint.AllTrackables.Add(trackable); | |||
| _checkpoint.MatchedProtoIds.Add(_proto_id); | |||
| if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment)) | |||
| if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment) && current_assignment is not null) | |||
| { | |||
| // skip the `logging.warning`. | |||
| return false; | |||
| @@ -61,13 +63,13 @@ public class CheckpointPosition | |||
| } | |||
| } | |||
| public (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables() | |||
| public (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables() | |||
| { | |||
| // skip the registered_saver | |||
| if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0) | |||
| { | |||
| return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(), | |||
| return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>(), | |||
| new List<CheckpointPosition>(), null); | |||
| } | |||
| @@ -75,7 +77,7 @@ public class CheckpointPosition | |||
| List<Operation> existing_restore_ops; | |||
| List<CheckpointPosition> positions = new(); | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables; | |||
| Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> named_saveables; | |||
| if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | |||
| { | |||
| (existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories); | |||
| @@ -109,8 +111,8 @@ public class CheckpointPosition | |||
| /// Creates a saveable using the _serialize_to_tensor method. | |||
| /// </summary> | |||
| /// <param name="saveable_factories"></param> | |||
| private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable( | |||
| IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
| private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable( | |||
| IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
| { | |||
| string suffix = SaveableCompat.get_saveable_name(this.Trackable); | |||
| suffix = suffix ?? ""; | |||
| @@ -124,23 +126,23 @@ public class CheckpointPosition | |||
| var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name); | |||
| // skip the cache. | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> dict = new(); | |||
| Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> dict = new(); | |||
| dict[saveable_name] = saveable; | |||
| return (new List<Operation>(), dict); | |||
| } | |||
| private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name( | |||
| IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
| private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name( | |||
| IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| if(ObjectProto.Attributes is null) | |||
| { | |||
| return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>()); | |||
| return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>()); | |||
| } | |||
| List<Operation> existing_restore_ops = new(); | |||
| HashSet<string> created_compat_names = new(); | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables = new(); | |||
| Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> named_saveables = new(); | |||
| foreach (var serialized_tensor in ObjectProto.Attributes) | |||
| { | |||
| Operation existing_op; | |||
| @@ -172,12 +174,12 @@ public class CheckpointPosition | |||
| _checkpoint.UnusedAttributes.SetDefault(_proto_id, new List<string>()).Add(serialized_tensor.Name); | |||
| continue; | |||
| } | |||
| named_saveables[serialized_tensor.CheckpointKey] = saveable; | |||
| named_saveables[serialized_tensor.CheckpointKey] = saveable.Value; | |||
| } | |||
| return (existing_restore_ops, named_saveables); | |||
| } | |||
| private Maybe<BaseResourceVariable, MySaveableObject> _get_saveable_from_factory(IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories, | |||
| private OneOf<BaseResourceVariable, MySaveableObject>? _get_saveable_from_factory(IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> saveable_factories, | |||
| TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet<string> created_compat_names) | |||
| { | |||
| var expected_factory_name = serialized_tensor.Name; | |||
| @@ -221,7 +223,7 @@ public class CheckpointPosition | |||
| Queue<(CheckpointPosition, Trackable)> visit_queue = new(); | |||
| visit_queue.Enqueue((this, this.Trackable)); | |||
| List<Operation> restore_ops = new(); | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables = new(); | |||
| Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>> tensor_saveables = new(); | |||
| List<CheckpointPosition> positions = new(); | |||
| CheckpointPosition current_position = null; | |||
| @@ -306,7 +308,7 @@ public class CheckpointPosition | |||
| } | |||
| } | |||
| private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore() | |||
| private (List<Operation>, Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore() | |||
| { | |||
| var trackable = this.Trackable; | |||
| trackable._maybe_initialize_trackable(); | |||
| @@ -318,7 +320,7 @@ public class CheckpointPosition | |||
| } | |||
| else | |||
| { | |||
| return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(), | |||
| return (new List<Operation>(), new Dictionary<string, OneOf<BaseResourceVariable, MySaveableObject>>(), | |||
| new List<CheckpointPosition>(), null); | |||
| } | |||
| } | |||
| @@ -14,9 +14,11 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Extensions; | |||
| namespace Tensorflow.Contexts | |||
| { | |||
| @@ -25,12 +27,93 @@ namespace Tensorflow.Contexts | |||
| /// </summary> | |||
| public sealed partial class Context | |||
| { | |||
| public ConfigProto Config { get; set; } = new ConfigProto | |||
| protected Device.PhysicalDevice[] _physical_devices; | |||
| protected Dictionary<Device.PhysicalDevice, int> _physical_device_to_index; | |||
| ConfigProto _config; | |||
| public ConfigProto Config | |||
| { | |||
| GpuOptions = new GPUOptions | |||
| get | |||
| { | |||
| _initialize_physical_devices(); | |||
| var config = new ConfigProto(); | |||
| if(_config is not null) | |||
| { | |||
| config.MergeFrom(_config); | |||
| } | |||
| config.LogDevicePlacement = _log_device_placement; | |||
| config.DeviceCount["CPU"] = 0; | |||
| config.DeviceCount["GPU"] = 0; | |||
| foreach(var dev in _physical_devices) | |||
| { | |||
| if (config.DeviceCount.ContainsKey(dev.DeviceType)) | |||
| { | |||
| config.DeviceCount[dev.DeviceType] += 1; | |||
| } | |||
| else | |||
| { | |||
| config.DeviceCount[dev.DeviceType] = 1; | |||
| } | |||
| } | |||
| var gpu_options = _compute_gpu_options(); | |||
| config.GpuOptions = GPUOptions.Parser.ParseFrom(gpu_options.ToByteArray()); | |||
| return config; | |||
| } | |||
| set | |||
| { | |||
| _config = value; | |||
| } | |||
| } | |||
| protected void _initialize_physical_devices(bool reinitialize = false) | |||
| { | |||
| if(!reinitialize && _physical_devices is not null) | |||
| { | |||
| return; | |||
| } | |||
| var devs = list_physical_devices(); | |||
| _physical_devices = devs.Select(d => new Device.PhysicalDevice() | |||
| { | |||
| DeviceName = d.DeviceName, | |||
| DeviceType = d.DeviceType | |||
| }).ToArray(); | |||
| _physical_device_to_index = _physical_devices.Select((p, i) => new KeyValuePair<Device.PhysicalDevice, int>(p, i)) | |||
| .ToDictionary(x => x.Key, x => x.Value); | |||
| _import_config(); | |||
| } | |||
| protected void _import_config() | |||
| { | |||
| if(_config is null) | |||
| { | |||
| return; | |||
| } | |||
| if(!_config.DeviceCount.TryGetValue("CPU", out var num_cpus)) | |||
| { | |||
| num_cpus = 1; | |||
| } | |||
| if(num_cpus != 1) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| } | |||
| }; | |||
| var gpus = _physical_devices.Where(d => d.DeviceType == "GPU"); | |||
| if(gpus.Count() == 0) | |||
| { | |||
| return; | |||
| } | |||
| if(!_config.DeviceCount.TryGetValue("GPU", out var gpu_count)) | |||
| { | |||
| gpu_count = 0; | |||
| } | |||
| // TODO(Rinne): implement it. | |||
| } | |||
| ConfigProto MergeConfig() | |||
| { | |||
| @@ -111,6 +111,14 @@ namespace Tensorflow.Contexts | |||
| return results.ToArray(); | |||
| } | |||
| public bool is_custom_device(string device_name) | |||
| { | |||
| return false; | |||
| // TODO(Rinne): After tf2.11 TFE_IsCustomDevice has been added to C APIs. | |||
| //ensure_initialized(); | |||
| //return c_api.TFE_IsCustomDevice(_handle, device_name); | |||
| } | |||
| public EagerDeviceContext device(string name) | |||
| { | |||
| return new EagerDeviceContext(this, name); | |||
| @@ -37,7 +37,26 @@ namespace Tensorflow.Contexts | |||
| public string ScopeName { get; set; } = ""; | |||
| bool initialized = false; | |||
| ContextSwitchStack context_switches; | |||
| public FunctionCallOptions FunctionCallOptions { get; } | |||
| protected FunctionCallOptions _function_call_options; | |||
| public FunctionCallOptions FunctionCallOptions | |||
| { | |||
| get | |||
| { | |||
| if(_function_call_options is null) | |||
| { | |||
| var config = Config; | |||
| _function_call_options = new FunctionCallOptions() | |||
| { | |||
| Config = config | |||
| }; | |||
| } | |||
| return _function_call_options; | |||
| } | |||
| set | |||
| { | |||
| _function_call_options = value; | |||
| } | |||
| } | |||
| SafeContextHandle _handle; | |||
| @@ -122,6 +141,11 @@ namespace Tensorflow.Contexts | |||
| name : | |||
| "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | |||
| public string anonymous_name() | |||
| { | |||
| return "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | |||
| } | |||
| public void graph_mode(bool isFunc = false) | |||
| => context_switches.Push(false, isFunc); | |||
| @@ -158,6 +182,37 @@ namespace Tensorflow.Contexts | |||
| return has_graph_arg; | |||
| } | |||
| public bool has_function(string name) | |||
| { | |||
| ensure_initialized(); | |||
| return c_api.TFE_ContextHasFunction(_handle, name); | |||
| } | |||
| public void add_function(SafeFuncGraphHandle fn) | |||
| { | |||
| ensure_initialized(); | |||
| Status status = new(); | |||
| c_api.TFE_ContextAddFunction(_handle, fn, status); | |||
| status.Check(true); | |||
| } | |||
| public void remove_function(string name) | |||
| { | |||
| ensure_initialized(); | |||
| Status status = new(); | |||
| c_api.TFE_ContextRemoveFunction(_handle, name, status); | |||
| status.Check(true); | |||
| } | |||
| public void add_function_def(FunctionDef fdef) | |||
| { | |||
| ensure_initialized(); | |||
| var fdef_string = fdef.ToByteArray(); | |||
| Status status = new Status(); | |||
| c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, (ulong)fdef_string.Length, status); | |||
| status.Check(true); | |||
| } | |||
| public void restore_mode() | |||
| { | |||
| context_switches.Pop(); | |||
| @@ -2,6 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Google.Protobuf; | |||
| using Protobuf.Text; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Contexts | |||
| @@ -9,10 +10,11 @@ namespace Tensorflow.Contexts | |||
| public class FunctionCallOptions | |||
| { | |||
| public ConfigProto Config { get; set; } | |||
| public string ExecutorType { get; set; } | |||
| public string config_proto_serialized() | |||
| public ByteString config_proto_serialized() | |||
| { | |||
| return Config.ToByteString().ToStringUtf8(); | |||
| return Config.ToByteString(); | |||
| } | |||
| } | |||
| } | |||
| @@ -12,18 +12,36 @@ namespace Tensorflow.Eager | |||
| return HasGradientTape(); | |||
| } | |||
| private bool ShouldRecord(Tensor[] inputs) | |||
| public int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors) | |||
| { | |||
| bool should_record = false; | |||
| foreach (var tape in tf.GetTapeSet()) | |||
| var tape_set = tf.GetTapeSet(); | |||
| var input_ids = MakeTensorIDList(tensors); | |||
| var input_dtypes = MakeTensorDtypeList(tensors); | |||
| bool some_tape_watching = false; | |||
| if (tape_set is not null && tape_set.Count > 0) | |||
| { | |||
| if (tape.ShouldRecord(inputs)) | |||
| foreach (var tape in tape_set) | |||
| { | |||
| should_record = true; | |||
| break; | |||
| if (tape.ShouldRecord(input_ids, input_dtypes)) | |||
| { | |||
| if (tape.Persistent || some_tape_watching) | |||
| { | |||
| return gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER; | |||
| } | |||
| some_tape_watching = true; | |||
| } | |||
| } | |||
| } | |||
| return should_record; | |||
| // skip the forward_accumulators. | |||
| if (some_tape_watching) | |||
| { | |||
| return gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER; | |||
| } | |||
| else | |||
| { | |||
| return gradients_util.POSSIBLE_GRADIENT_TYPES_NONE; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -13,7 +13,17 @@ namespace Tensorflow.Eager | |||
| Tensor[] results, | |||
| BackwardFunction backwardFunction = null) | |||
| { | |||
| bool should_record = ShouldRecord(inputs); | |||
| var input_ids = MakeTensorIDList(inputs); | |||
| var input_dtypes = MakeTensorDtypeList(inputs); | |||
| bool should_record = false; | |||
| foreach (var tape in tf.GetTapeSet()) | |||
| { | |||
| if (tape.ShouldRecord(input_ids, input_dtypes)) | |||
| { | |||
| should_record = true; | |||
| break; | |||
| } | |||
| } | |||
| if (!should_record) | |||
| { | |||
| @@ -59,7 +69,7 @@ namespace Tensorflow.Eager | |||
| op_inputs = inputs;*/ | |||
| backwardFunction = backwardFunction ?? GetGradientFunction(op_name, inputs, attrs, results); | |||
| TapeSetRecordOperation(op_name, inputs, results, backwardFunction); | |||
| TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes, backwardFunction); | |||
| return true; | |||
| } | |||
| @@ -129,10 +139,5 @@ namespace Tensorflow.Eager | |||
| { | |||
| return HasGradientTape(); | |||
| } | |||
| TF_DataType[] MakeTensorDtypeList(Tensor[] tensors) | |||
| { | |||
| return tensors.Select(x => x.dtype).ToArray(); | |||
| } | |||
| } | |||
| } | |||
| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Functions; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Eager | |||
| @@ -358,7 +358,7 @@ namespace Tensorflow.Eager | |||
| break; | |||
| case TF_AttrType.TF_ATTR_FUNC: | |||
| if (value is ConcreteFunction func) | |||
| c_api.TFE_OpSetAttrFunctionName(op, key, func.Name, func.Name.Length); | |||
| c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length); | |||
| else | |||
| throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); | |||
| break; | |||
| @@ -1,6 +1,8 @@ | |||
| using System; | |||
| using OneOf.Types; | |||
| using System; | |||
| using Tensorflow.Gradients; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Eager | |||
| { | |||
| @@ -9,40 +11,183 @@ namespace Tensorflow.Eager | |||
| /// </summary> | |||
| public partial class EagerRunner | |||
| { | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="tape"></param> | |||
| /// <param name="target"></param> | |||
| /// <param name="sources"></param> | |||
| /// <param name="output_gradients"></param> | |||
| /// <param name="unconnected_gradients">determines the value returned if the target and | |||
| /// sources are unconnected.When 'none' the value returned is None wheras when | |||
| /// 'zero' a zero tensor in the same shape as the sources is returned.</param> | |||
| /// <returns></returns> | |||
| /// <exception cref="RuntimeError"></exception> | |||
| public Tensor[] TFE_TapeGradient(ITape tape, | |||
| Tensor[] target, | |||
| Tensor[] sources, | |||
| Tensor[] output_gradients) | |||
| List<Tensor> output_gradients, | |||
| Tensor[] sources_raw, | |||
| string unconnected_gradients) | |||
| { | |||
| var target_vec = target; | |||
| var sources_vec = sources; | |||
| var sources_set = sources_vec; | |||
| if (!tape.Persistent) | |||
| { | |||
| var tape_set = tf.GetTapeSet(); | |||
| if (tape_set.Contains(tape)) | |||
| { | |||
| throw new RuntimeError("gradient() cannot be invoked within the " + | |||
| "GradientTape context (i.e., while operations are being " + | |||
| "recorded). Either move the call to gradient() to be " + | |||
| "outside the 'with tf.GradientTape' block, or " + | |||
| "use a persistent tape: " + | |||
| "'with tf.GradientTape(persistent=true)'"); | |||
| } | |||
| } | |||
| var target_vec = MakeTensorIDList(target); | |||
| var sources_vec = MakeTensorIDList(sources); | |||
| HashSet<long> sources_set = new HashSet<long>(sources_vec); | |||
| var source_tensors_that_are_targets = new UnorderedMap<long, TapeTensor>(); | |||
| int len = target.Length; | |||
| for(int i = 0; i < len; i++) | |||
| { | |||
| var target_id = target_vec[i]; | |||
| if (sources_set.Contains(target_id)) | |||
| { | |||
| var tensor = target[i]; | |||
| source_tensors_that_are_targets[target_id] = TapeTensorFromTensor(tensor); | |||
| } | |||
| } | |||
| List<Tensor> outgrad_vec = new(); | |||
| if(output_gradients is not null) | |||
| { | |||
| outgrad_vec = output_gradients.ToList(); | |||
| } | |||
| var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false); | |||
| var seq_array = target; | |||
| var source_tensors_that_are_targets = new UnorderedMap<Tensor, TapeTensor>(); | |||
| for (int i = 0; i < target.Length; ++i) | |||
| bool unconnected_gradients_zero = unconnected_gradients == "zero"; | |||
| Tensor[] sources_obj = null; | |||
| if (unconnected_gradients_zero) | |||
| { | |||
| source_tensors_that_are_targets.Add(target_vec[i], new TapeTensor(seq_array[i])); | |||
| sources_obj = MakeTensorList(sources_raw); | |||
| } | |||
| if (output_gradients != null) | |||
| if (result.Length > 0) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| for(int i = 0; i < result.Length; i++) | |||
| { | |||
| if (result[i] is null && unconnected_gradients_zero) | |||
| { | |||
| var dtype = sources_obj[i].dtype; | |||
| result[i] = new TapeTensor(sources_vec[i], dtype, sources_obj[i]).ZerosLike(); | |||
| } | |||
| } | |||
| } | |||
| else | |||
| return result; | |||
| } | |||
| Tensor[] MakeTensorList(IEnumerable<Tensor> tensors) | |||
| { | |||
| return tensors.ToArray(); | |||
| } | |||
| long[] MakeTensorIDList(Tensor[] tensors) | |||
| { | |||
| int len = tensors.Length; | |||
| long[] ids = new long[len]; | |||
| for(int i = 0; i < len; i++) | |||
| { | |||
| var tensor = tensors[i]; | |||
| ids[i] = tensor.Id; | |||
| } | |||
| return ids; | |||
| } | |||
| TF_DataType[] MakeTensorDtypeList(Tensor[] tensors) | |||
| { | |||
| int len = tensors.Length; | |||
| TF_DataType[] dtypes = new TF_DataType[len]; | |||
| for (int i = 0; i < len; i++) | |||
| { | |||
| output_gradients = new Tensor[0]; | |||
| var tensor = tensors[i]; | |||
| dtypes[i] = tensor.dtype; | |||
| } | |||
| return dtypes; | |||
| } | |||
| var outgrad_vec = MakeTensorList(output_gradients); | |||
| TapeTensor TapeTensorFromTensor(Tensor tensor) | |||
| { | |||
| long id = tensor.Id; | |||
| var dtype = tensor.dtype; | |||
| if (tensor is EagerTensor) | |||
| { | |||
| var handle = tensor.EagerTensorHandle; | |||
| if (DTypeNeedsHandleData(dtype)) | |||
| { | |||
| return new TapeTensor(id, c_api.TFE_TensorHandleDataType(handle), tensor); | |||
| } | |||
| Status status = new(); | |||
| int num_dims = c_api.TFE_TensorHandleNumDims(handle, status); | |||
| long[] dims = new long[num_dims]; | |||
| for(int i = 0; i < num_dims; i++) | |||
| { | |||
| dims[i] = c_api.TFE_TensorHandleDim(handle, i, status); | |||
| } | |||
| Shape tensor_shape = new(dims); | |||
| if(status.Code != TF_Code.TF_OK) | |||
| { | |||
| return new TapeTensor(id, TF_DataType.DtInvalid, Shape.Null); | |||
| } | |||
| else | |||
| { | |||
| return new TapeTensor(id, dtype, tensor_shape); | |||
| } | |||
| } | |||
| var shape_tuple = tensor.shape.dims; | |||
| if(ListContainNone(shape_tuple) || DTypeNeedsHandleData(dtype)) | |||
| { | |||
| return new TapeTensor(id, dtype, tensor); | |||
| } | |||
| long[] l = new long[shape_tuple.Length]; | |||
| for(int i = 0; i < shape_tuple.Length; i++) | |||
| { | |||
| if (shape_tuple[i] < 0) | |||
| { | |||
| l[i] = 0; | |||
| } | |||
| else | |||
| { | |||
| l[i] = shape_tuple[i]; | |||
| } | |||
| } | |||
| return new TapeTensor(id, dtype, new Shape(l)); | |||
| } | |||
| return tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec); | |||
| bool DTypeNeedsHandleData(TF_DataType dtype) | |||
| { | |||
| return dtype == dtypes.variant || dtype == dtypes.resource; | |||
| } | |||
| Tensor[] MakeTensorList(Tensor[] tensors) | |||
| bool ListContainNone(long[] list) | |||
| { | |||
| return tensors; | |||
| int len = list.Length; | |||
| if(len == 0) | |||
| { | |||
| return true; | |||
| } | |||
| for(int i = 0; i < len; i++) | |||
| { | |||
| if (list[i] == -1) | |||
| { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| } | |||
| } | |||
| @@ -7,8 +7,9 @@ namespace Tensorflow.Eager | |||
| public partial class EagerRunner | |||
| { | |||
| void TapeSetRecordBackprop(string op_type, | |||
| Tensor[] input_tensors, | |||
| TapeTensor[] output_tensors, | |||
| TapeTensor[] output_info, | |||
| long[] input_ids, | |||
| TF_DataType[] input_detyps, | |||
| BackwardFunction backward_function) | |||
| { | |||
| if (!CouldBackprop()) | |||
| @@ -18,7 +19,7 @@ namespace Tensorflow.Eager | |||
| foreach (var tape in tf.GetTapeSet()) | |||
| { | |||
| tape.RecordOperation(op_type, input_tensors, output_tensors, backward_function); | |||
| tape.RecordOperation(op_type, output_info, input_ids, input_detyps, backward_function); | |||
| } | |||
| } | |||
| } | |||
| @@ -10,18 +10,28 @@ namespace Tensorflow.Eager | |||
| public bool TapeSetRecordOperation(string op_type, | |||
| Tensor[] input_tensors, | |||
| Tensor[] output_tensors, | |||
| long[] input_ids, | |||
| TF_DataType[] input_dtypes, | |||
| BackwardFunction backward_function) | |||
| { | |||
| var output_info = output_tensors.Select(x => new TapeTensor(x)).ToArray(); | |||
| var output_info = output_tensors.Select(t => TapeTensorFromTensor(t)).ToArray(); | |||
| if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info, | |||
| backward_function)) | |||
| return false; | |||
| TapeSetRecordBackprop(op_type, input_tensors, output_info, | |||
| TapeSetRecordBackprop(op_type, output_info, input_ids, input_dtypes, | |||
| backward_function); | |||
| return true; | |||
| } | |||
| public void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors, | |||
| Tensor[] input_tensors, BackwardFunction backward_function) | |||
| { | |||
| var input_ids = MakeTensorIDList(input_tensors); | |||
| var input_dtypes = MakeTensorDtypeList(input_tensors); | |||
| TapeSetRecordOperation(op_type, input_tensors, output_tensors, input_ids, input_dtypes, | |||
| backward_function); | |||
| } | |||
| } | |||
| } | |||
| @@ -29,7 +29,14 @@ namespace Tensorflow.Eager | |||
| Tensor[] TFE_TapeGradient(ITape tape, | |||
| Tensor[] target, | |||
| Tensor[] sources, | |||
| Tensor[] output_gradients); | |||
| List<Tensor> output_gradients, | |||
| Tensor[] sources_raw, | |||
| string unconnected_gradients); | |||
| void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors, | |||
| Tensor[] input_tensors, BackwardFunction backward_function); | |||
| int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors); | |||
| bool RecordGradient(string op_name, | |||
| Tensor[] inputs, | |||
| @@ -0,0 +1,53 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow.Eager | |||
| { | |||
| internal static class backprop_util | |||
| { | |||
| // TODO: add quantized_dtypes (after being supported). | |||
| private static HashSet<TF_DataType> _trainable_dtypes = new HashSet<TF_DataType>(new TF_DataType[] | |||
| { | |||
| dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128, | |||
| dtypes.resource, dtypes.variant, TF_DataType.TF_BFLOAT16 | |||
| }); | |||
| public static bool IsTrainable(Tensor tensor) | |||
| { | |||
| var dtype = _DTypeFromTensor(tensor); | |||
| return _trainable_dtypes.Contains(dtype); | |||
| } | |||
| public static bool IsTrainable(TF_DataType dtype) | |||
| { | |||
| return _trainable_dtypes.Contains(dtype); | |||
| } | |||
| private static TF_DataType _DTypeFromTensor(Tensor tensor) | |||
| { | |||
| var dtype = tensor.dtype; | |||
| if(dtype.as_base_dtype() == TF_DataType.TF_VARIANT) | |||
| { | |||
| CppShapeInferenceResult.Types.HandleData handle_data; | |||
| if (tensor is EagerTensor) | |||
| { | |||
| handle_data = tensor.HandleData; | |||
| } | |||
| else | |||
| { | |||
| handle_data = handle_data_util.get_resource_handle_data(tensor); | |||
| } | |||
| if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null && | |||
| handle_data.ShapeAndType.Count > 0) | |||
| { | |||
| var first_type = handle_data.ShapeAndType[0].Dtype; | |||
| if(first_type != DataType.DtInvalid && handle_data.ShapeAndType.All(x => x.Dtype == first_type)) | |||
| { | |||
| return first_type.as_tf_dtype(); | |||
| } | |||
| } | |||
| } | |||
| return dtype; | |||
| } | |||
| } | |||
| } | |||
| @@ -30,6 +30,9 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, byte[] serialized_function_def, ulong size, SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy); | |||
| @@ -277,7 +280,7 @@ namespace Tensorflow | |||
| public static extern void TFE_OpSetAttrIntList(SafeEagerOpHandle op, string attr_name, long[] values, int num_values); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_OpSetAttrValueProto(SafeEagerOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status); | |||
| public static extern void TFE_OpSetAttrValueProto(IntPtr op, string attr_name, IntPtr proto, ulong proto_len, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// | |||
| @@ -480,5 +483,8 @@ namespace Tensorflow | |||
| IntPtr[] target, int target_size, | |||
| IntPtr[] sources, int source_size, | |||
| IntPtr[] outputs, int output_size); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern bool TFE_IsCustomDevice(SafeContextHandle ctx, string device_name); | |||
| } | |||
| } | |||
| @@ -18,6 +18,10 @@ namespace Tensorflow.Eager | |||
| var types = v.Select(t => t.dtype.as_datatype_enum()); | |||
| return (types.ToArray(), v.ToArray()); | |||
| } | |||
| public static Tensor[] executes(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) | |||
| { | |||
| return quick_execute(op_name, num_outputs, inputs, attrs, ctx, name); | |||
| } | |||
| public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) | |||
| { | |||
| string device_name = ctx.DeviceName; | |||
| @@ -0,0 +1,13 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Eager | |||
| { | |||
| public class TangentInfo | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| public object Indices { get; set; } | |||
| public object Tangents { get; set; } | |||
| } | |||
| } | |||
| @@ -6,8 +6,11 @@ | |||
| public class DenseSpec : TypeSpec | |||
| { | |||
| protected Shape _shape; | |||
| public Shape shape => _shape; | |||
| public Shape shape | |||
| { | |||
| get { return _shape; } | |||
| set { _shape = value; } | |||
| } | |||
| protected TF_DataType _dtype; | |||
| public TF_DataType dtype => _dtype; | |||
| @@ -1,6 +0,0 @@ | |||
| namespace Tensorflow.Framework.Models | |||
| { | |||
| class ScopedTFFunction | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,22 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Framework | |||
| { | |||
| internal class ScopedTFFunction | |||
| { | |||
| SafeFuncGraphHandle _handle; | |||
| string _name; | |||
| public ScopedTFFunction(SafeFuncGraphHandle func, string name) | |||
| { | |||
| _handle = func; | |||
| _name = name; | |||
| } | |||
| public SafeFuncGraphHandle Get() | |||
| { | |||
| return _handle; | |||
| } | |||
| } | |||
| } | |||
| @@ -111,7 +111,17 @@ namespace Tensorflow | |||
| public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions(); | |||
| public static Buffer tf_buffer(byte[] data) => new Buffer(data); | |||
| public static Buffer tf_buffer(byte[] data = null) | |||
| { | |||
| if(data is not null) | |||
| { | |||
| return new Buffer(data); ; | |||
| } | |||
| else | |||
| { | |||
| return new Buffer(); | |||
| } | |||
| } | |||
| public static IEnumerable<Operation> new_tf_operations(Graph graph) | |||
| { | |||
| @@ -0,0 +1,297 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Security.Cryptography; | |||
| using System.Text; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Common.Extensions; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||
| namespace Tensorflow.Framework | |||
| { | |||
| public class function_def_lib | |||
| { | |||
| // TODO(Rinne): process signatures and structured outputs. | |||
| public static FuncGraph function_def_to_graph(FunctionDef fdef, object? structured_input_signature, | |||
| object? structured_outputs, List<TensorShapeProto> input_shapes = null) | |||
| { | |||
| var func_graph = new FuncGraph(fdef.Signature.Name); | |||
| if(input_shapes is null) | |||
| { | |||
| if(fdef.Attr.TryGetValue("_input_shapes", out var input_shapes_attr)) | |||
| { | |||
| var raw_input_shapes = input_shapes_attr.List.Shape; | |||
| input_shapes = new List<TensorShapeProto>(); | |||
| foreach(var (input_shape, arg_def) in raw_input_shapes.Zip(fdef.Signature.InputArg, (x, y) => (x, y))) | |||
| { | |||
| if(arg_def.Type == DataType.DtResource && arg_def.HandleData is not null && arg_def.HandleData.Count > 0) | |||
| { | |||
| input_shapes.Add(null); | |||
| } | |||
| else | |||
| { | |||
| input_shapes.Add(input_shape); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| var (graph_def, nested_to_flat_tensor_name) = function_def_to_graph_def(fdef, input_shapes); | |||
| func_graph.as_default(); | |||
| importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false); | |||
| var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]); | |||
| func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||
| var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]); | |||
| func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||
| // TODO(Rinne): func_graph.ControlOutputs | |||
| _set_handle_data(func_graph, fdef); | |||
| foreach(var node in graph_def.Node) | |||
| { | |||
| if(node.Attr.TryGetValue("_output_shapes", out var output_shapes)) | |||
| { | |||
| var op = func_graph.get_operation_by_name(node.Name); | |||
| foreach(var (output_index, shape) in enumerate(output_shapes.List.Shape.Take(op.outputs.Length))) | |||
| { | |||
| op.outputs[output_index].shape = new Shape(shape); | |||
| } | |||
| } | |||
| } | |||
| Dictionary<long, string> output_names = new(); | |||
| foreach(var (ret_arg_def, tensor_name) in zip(fdef.Signature.OutputArg, output_tensor_names)) | |||
| { | |||
| output_names[ops.tensor_id(func_graph.get_tensor_by_name(tensor_name))] = ret_arg_def.Name; | |||
| } | |||
| func_graph._output_names = output_names; | |||
| func_graph.Exit(); | |||
| return func_graph; | |||
| } | |||
| public static (GraphDef, Dictionary<string, string>) function_def_to_graph_def(FunctionDef fdef, List<TensorShapeProto> input_shapes) | |||
| { | |||
| var graph_def = new GraphDef() | |||
| { | |||
| Versions = new VersionDef() | |||
| { | |||
| Producer = versions.GRAPH_DEF_VERSION, | |||
| MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER | |||
| } | |||
| }; | |||
| var default_graph = ops.get_default_graph(); | |||
| if(input_shapes is not null && input_shapes.Count > 0 && input_shapes.Count != fdef.Signature.InputArg.Count) | |||
| { | |||
| throw new ValueError($"Length of `input_shapes` must match the number " + | |||
| $"of `input_arg`s in `fdef`. Got {input_shapes.Count} `input_shapes` and " + | |||
| $"{fdef.Signature.InputArg.Count} `input_arg`s."); | |||
| } | |||
| foreach(var (i, arg_def) in enumerate(fdef.Signature.InputArg)) | |||
| { | |||
| NodeDef node_def = new(); | |||
| node_def.Name = arg_def.Name; | |||
| node_def.Op = "Placeholder"; | |||
| node_def.Attr["dtype"] = new AttrValue() | |||
| { | |||
| Type = arg_def.Type | |||
| }; | |||
| if(input_shapes is not null && input_shapes.Count > 0 && input_shapes[i] is not null) | |||
| { | |||
| var input_shape = input_shapes[i]; | |||
| // skip the condition that input_shape is not `TensorShapeProto`. | |||
| AttrValue shape = new AttrValue() | |||
| { | |||
| Shape = new TensorShapeProto() | |||
| }; | |||
| shape.Shape = new TensorShapeProto(input_shape); | |||
| node_def.Attr["shape"] = shape; | |||
| } | |||
| if (!fdef.ArgAttr.ContainsKey((uint)i)) | |||
| { | |||
| fdef.ArgAttr[(uint)i] = new FunctionDef.Types.ArgAttrs(); | |||
| } | |||
| var arg_attrs = fdef.ArgAttr[(uint)i].Attr; | |||
| foreach(var k in arg_attrs.Keys) | |||
| { | |||
| if(k == "_output_shapes") | |||
| { | |||
| if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.List) | |||
| { | |||
| node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].List.Shape[0]); | |||
| } | |||
| else if (arg_attrs[k].ValueCase == AttrValue.ValueOneofCase.Shape) | |||
| { | |||
| node_def.Attr["shape"].Shape = new TensorShapeProto(arg_attrs[k].Shape); | |||
| } | |||
| } | |||
| else if (k.StartsWith("_")) | |||
| { | |||
| if (!node_def.Attr.ContainsKey(k)) | |||
| { | |||
| node_def.Attr[k] = new AttrValue(); | |||
| } | |||
| node_def.Attr[k] = new AttrValue(arg_attrs[k]); | |||
| } | |||
| } | |||
| graph_def.Node.Add(node_def); | |||
| } | |||
| graph_def.Node.AddRange(fdef.NodeDef); | |||
| Dictionary<string, string> nested_to_flat_tensor_name = new(); | |||
| foreach(var arg_def in fdef.Signature.InputArg) | |||
| { | |||
| nested_to_flat_tensor_name[arg_def.Name] = $"{arg_def.Name}:0"; | |||
| string control_name = "^" + arg_def.Name; | |||
| nested_to_flat_tensor_name[control_name] = control_name; | |||
| } | |||
| foreach(var node_def in fdef.NodeDef) | |||
| { | |||
| var graph = default_graph; | |||
| while (true) | |||
| { | |||
| if(graph is null) | |||
| { | |||
| break; | |||
| } | |||
| var f = graph.Functions.GetOrDefault(node_def.Op, null); | |||
| if(f is not null && graph.OuterGraph is null) | |||
| { | |||
| break; | |||
| } | |||
| graph = graph.OuterGraph; | |||
| } | |||
| var op_def = default_graph.GetOpDef(node_def.Op); | |||
| foreach(var attr in op_def.Attr) | |||
| { | |||
| if(attr.Type == "func") | |||
| { | |||
| var fname = node_def.Attr[attr.Name].Func.Name; | |||
| if (!is_function(fname)) | |||
| { | |||
| throw new ValueError($"Function {fname} was not found. Please make sure " + | |||
| $"the FunctionDef `fdef` is correct."); | |||
| } | |||
| } | |||
| else if(attr.Type == "list(func)") | |||
| { | |||
| foreach(var fn in node_def.Attr[attr.Name].List.Func) | |||
| { | |||
| var fname = fn.Name; | |||
| if (!is_function(fname)) | |||
| { | |||
| throw new ValueError($"Function {fname} was not found. Please make " + | |||
| $"sure the FunctionDef `fdef` is correct."); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| int flattened_index = 0; | |||
| foreach(var arg_def in op_def.OutputArg) | |||
| { | |||
| var num_args = _get_num_args(arg_def, node_def); | |||
| for(int i = 0; i < num_args; i++) | |||
| { | |||
| var nested_name = $"{node_def.Name}:{arg_def.Name}:{i}"; | |||
| var flat_name = $"{node_def.Name}:{flattened_index}"; | |||
| nested_to_flat_tensor_name[nested_name] = flat_name; | |||
| flattened_index++; | |||
| } | |||
| } | |||
| string control_name = "^" + node_def.Name; | |||
| nested_to_flat_tensor_name[control_name] = control_name; | |||
| } | |||
| foreach(var node_def in graph_def.Node) | |||
| { | |||
| for(int i = 0; i < node_def.Input.Count; i++) | |||
| { | |||
| node_def.Input[i] = nested_to_flat_tensor_name[node_def.Input[i]]; | |||
| } | |||
| } | |||
| return (graph_def, nested_to_flat_tensor_name); | |||
| } | |||
| private static void _set_handle_data(FuncGraph func_graph, FunctionDef fdef) | |||
| { | |||
| foreach(var (tensor, arg_def) in zip(func_graph.Inputs, fdef.Signature.InputArg).Concat(zip(func_graph.Outputs, fdef.Signature.OutputArg))) | |||
| { | |||
| if(arg_def.HandleData is not null && arg_def.HandleData.Count > 0) | |||
| { | |||
| tensor.shape = Shape.Scalar; | |||
| var shape_and_type = arg_def.HandleData[0]; | |||
| var handle_data = new HandleData(); | |||
| handle_data.IsSet = true; | |||
| handle_data.ShapeAndType.Add(new HandleShapeAndType() | |||
| { | |||
| Shape = shape_and_type.Shape, | |||
| Dtype = shape_and_type.Dtype | |||
| }); | |||
| resource_variable_ops._set_handle_shapes_and_types(tensor, handle_data, true); | |||
| } | |||
| } | |||
| } | |||
| private static long _get_num_args(OpDef.Types.ArgDef arg_def, NodeDef node_def) | |||
| { | |||
| if (!string.IsNullOrEmpty(arg_def.NumberAttr)) | |||
| { | |||
| return node_def.Attr[arg_def.NumberAttr].I; | |||
| } | |||
| else if(!string.IsNullOrEmpty(arg_def.TypeListAttr)) | |||
| { | |||
| return node_def.Attr[arg_def.TypeListAttr].List.Type.Count; | |||
| } | |||
| else if(arg_def.TypeAttr is not null || arg_def.Type != DataType.DtInvalid) | |||
| { | |||
| return 1; | |||
| } | |||
| else | |||
| { | |||
| throw new ValueError($"Invalid arg_def:\n\n{arg_def}. Please make sure the " + | |||
| $"FunctionDef `fdef` is correct."); | |||
| } | |||
| } | |||
| public static bool is_function(string fname) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| return tf.Context.has_function(fname); | |||
| } | |||
| else | |||
| { | |||
| var graph = ops.get_default_graph(); | |||
| while(graph is not null) | |||
| { | |||
| if (graph.IsFunction(fname)) | |||
| { | |||
| return true; | |||
| } | |||
| if(graph.OuterGraph is not null) | |||
| { | |||
| graph = graph.OuterGraph; | |||
| } | |||
| else | |||
| { | |||
| return false; | |||
| } | |||
| } | |||
| } | |||
| throw new ValueError("Unexpected behavior happened in runtime, please submit an issue to " + | |||
| "https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| } | |||
| } | |||
| } | |||
| @@ -17,6 +17,7 @@ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.OpDef.Types; | |||
| @@ -25,9 +26,14 @@ namespace Tensorflow | |||
| { | |||
| public class importer | |||
| { | |||
| public static ITensorOrOperation[] import_graph_def_for_function(GraphDef graph_def, string name = null) | |||
| { | |||
| return import_graph_def(graph_def, validate_colocation_constraints: false, name: name); | |||
| } | |||
| public static ITensorOrOperation[] import_graph_def(GraphDef graph_def, | |||
| Dictionary<string, Tensor> input_map = null, | |||
| string[] return_elements = null, | |||
| bool validate_colocation_constraints = true, | |||
| string name = null, | |||
| OpList producer_op_list = null) | |||
| { | |||
| @@ -60,7 +66,7 @@ namespace Tensorflow | |||
| var scoped_options = c_api_util.ScopedTFImportGraphDefOptions(); | |||
| var status = new Status(); | |||
| _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | |||
| _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements, validate_colocation_constraints ); | |||
| // need to create a class ImportGraphDefWithResults with IDisposal | |||
| results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status)); | |||
| status.Check(true); | |||
| @@ -107,21 +113,36 @@ namespace Tensorflow | |||
| foreach (var new_op in graph._add_new_tf_operations()) | |||
| { | |||
| var original_device = new_op.Device; | |||
| new_op._set_device(original_device); | |||
| } | |||
| } | |||
| public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options, | |||
| string prefix, | |||
| Dictionary<string, Tensor> input_map, | |||
| string[] return_elements) | |||
| string[] return_elements, | |||
| bool validate_colocation_constraints) | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); | |||
| c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1); | |||
| c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Options, true); | |||
| foreach (var input in input_map) | |||
| { | |||
| var (src_name, src_index) = _ParseTensorName(input.Key); | |||
| c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_index, input.Value._as_tf_output()); | |||
| var input_src = tf.compat.as_str(input.Key); | |||
| var input_dst = input.Value; | |||
| if (input_src.StartsWith("^")) | |||
| { | |||
| var src_name = tf.compat.as_str(input_src.Substring(1)); | |||
| var dst_op = input_dst._as_tf_output().oper; | |||
| c_api.TF_ImportGraphDefOptionsRemapControlDependency(options.Options, src_name, dst_op); | |||
| } | |||
| else | |||
| { | |||
| var (src_name, src_index) = _ParseTensorName(input.Key); | |||
| src_name = tf.compat.as_str(src_name); | |||
| var dst_output = input_dst._as_tf_output(); | |||
| c_api.TF_ImportGraphDefOptionsAddInputMapping(options.Options, src_name, src_index, dst_output); | |||
| } | |||
| } | |||
| if (return_elements == null) | |||
| @@ -132,15 +153,16 @@ namespace Tensorflow | |||
| if (name.Contains(":")) | |||
| { | |||
| var (op_name, index) = _ParseTensorName(name); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index); | |||
| op_name = tf.compat.as_str(op_name); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Options, op_name, index); | |||
| } | |||
| else | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Options, tf.compat.as_str(name)); | |||
| } | |||
| } | |||
| // c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options, validate_colocation_constraints); | |||
| c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options.Options, validate_colocation_constraints); | |||
| } | |||
| private static (string, int) _ParseTensorName(string tensor_name) | |||
| @@ -173,6 +195,14 @@ namespace Tensorflow | |||
| return graph_def; | |||
| } | |||
| private static GraphDef _ProcessGraphDefParam(GraphDef graph_def) | |||
| { | |||
| var old_graph_def = graph_def; | |||
| graph_def = new GraphDef(old_graph_def); | |||
| return graph_def; | |||
| } | |||
| private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def) | |||
| { | |||
| foreach (var attr_def in op_def.Attr) | |||
| @@ -240,6 +270,35 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| private static void _RemoveDefaultAttrs(OpList producer_op_list, GraphDef graph_def) | |||
| { | |||
| var producer_op_dict = producer_op_list.Op.ToDictionary(x => x.Name, x => x); | |||
| foreach (var node in graph_def.Node) | |||
| { | |||
| // Remove any default attr values that aren't in op_def. | |||
| if (producer_op_dict.ContainsKey(node.Op)) | |||
| { | |||
| var op_def = op_def_registry.GetOpDef(node.Op); | |||
| if(op_def is null) | |||
| { | |||
| continue; | |||
| } | |||
| var producer_op_def = producer_op_dict[node.Op]; | |||
| foreach (var key in node.Attr.Keys) | |||
| { | |||
| if (_FindAttrInOpDef(key, op_def) is null) | |||
| { | |||
| var attr_def = _FindAttrInOpDef(key, producer_op_def); | |||
| if (attr_def != null && attr_def.DefaultValue != null && | |||
| node.Attr[key] == attr_def.DefaultValue) | |||
| node.Attr[key].ClearValue(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| private static AttrDef _FindAttrInOpDef(string name, OpDef op_def) | |||
| { | |||
| return op_def.Attr.FirstOrDefault(x => x.Name == name); | |||
| @@ -0,0 +1,12 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Framework | |||
| { | |||
| public class versions | |||
| { | |||
| public static int GRAPH_DEF_VERSION = 1286; | |||
| public static int GRAPH_DEF_VERSION_MIN_CONSUMER = 0; | |||
| } | |||
| } | |||
| @@ -1,9 +1,13 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Gradients; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Functions | |||
| @@ -13,29 +17,46 @@ namespace Tensorflow.Functions | |||
| /// </summary> | |||
| public class ConcreteFunction: Trackable | |||
| { | |||
| protected IEnumerable<Tensor> _captured_inputs; | |||
| protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; | |||
| protected Dictionary<string, AttrValue> _attrs; | |||
| protected FunctionSpec _function_spec; | |||
| protected FunctionSpec _pre_initialized_function_spec = null; | |||
| protected EagerDefinedFunction _inference_function; | |||
| protected Dictionary<string, TapeGradientFunctions> _tape_functions_cache = new(); | |||
| internal FuncGraph func_graph; | |||
| internal ForwardBackwardCall forward_backward; | |||
| public Tensor[] Inputs => func_graph.Inputs; | |||
| public Tensor[] CapturedInputs => func_graph.external_captures; | |||
| public string Name => func_graph?.FuncName; | |||
| public string Name => _delayed_rewrite_functions.Forward().Name; | |||
| public Tensor[] Outputs; | |||
| public Tensor[] Outputs => func_graph.Outputs; | |||
| public Type ReturnType; | |||
| public TensorSpec[] OutputStructure; | |||
| public IEnumerable<string> ArgKeywords { get; set; } | |||
| public long NumPositionArgs { get; set; } | |||
| public FunctionDef FunctionDef => _delayed_rewrite_functions.Forward().Definition; | |||
| public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs; | |||
| public IEnumerable<IVariableV1> Variables => func_graph.Variables; | |||
| public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables; | |||
| public ConcreteFunction(string name) | |||
| { | |||
| func_graph = new FuncGraph(name); | |||
| _captured_inputs = func_graph.external_captures; | |||
| _attrs= new Dictionary<string, AttrValue>(); | |||
| _set_infer_function(); | |||
| } | |||
| public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null) | |||
| public ConcreteFunction(FuncGraph graph, Dictionary<string, AttrValue> attrs = null) | |||
| { | |||
| func_graph = graph; | |||
| _captured_inputs = func_graph.external_captures; | |||
| ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); | |||
| //ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); | |||
| _attrs = attrs; | |||
| _set_infer_function(); | |||
| } | |||
| public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | |||
| @@ -53,6 +74,9 @@ namespace Tensorflow.Functions | |||
| new[] { output }, | |||
| null); | |||
| func_graph.Exit(); | |||
| _captured_inputs = func_graph.external_captures; | |||
| _attrs = new Dictionary<string, AttrValue>(); | |||
| _set_infer_function(); | |||
| } | |||
| public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | |||
| @@ -73,6 +97,9 @@ namespace Tensorflow.Functions | |||
| new[] { output.variant_tensor }, | |||
| null); | |||
| func_graph.Exit(); | |||
| _captured_inputs = func_graph.external_captures; | |||
| _attrs = new Dictionary<string, AttrValue>(); | |||
| _set_infer_function(); | |||
| } | |||
| /*public ConcreteFunction(Func<Tensors, Tensors> func, | |||
| @@ -130,39 +157,56 @@ namespace Tensorflow.Functions | |||
| { | |||
| var executing_eagerly = tf.Context.executing_eagerly(); | |||
| var default_graph = ops.get_default_graph(); | |||
| // TODO(Rinne): deal with `default_graph.building_function` | |||
| var tempvv = func_graph.Variables; | |||
| if(tf.GetTapeSet().Count > 0 || default_graph is FuncGraph) | |||
| { | |||
| foreach(var v in this.func_graph.Variables) | |||
| { | |||
| resource_variable_ops.variable_accessed(v); | |||
| } | |||
| } | |||
| var tensor_inputs = new Tensors(); | |||
| foreach (var (i, arg) in enumerate(args)) | |||
| { | |||
| tensor_inputs.Add(arg); | |||
| // If we're graph building, shape inference is on. | |||
| if (!executing_eagerly) | |||
| { | |||
| } | |||
| } | |||
| tensor_inputs.AddRange(captured_inputs); | |||
| if (!executing_eagerly) | |||
| { | |||
| // TODO(Rinne): add the check | |||
| } | |||
| tensor_inputs.AddRange(captured_inputs); | |||
| args = tensor_inputs.ToArray(); | |||
| var possible_gradient_type = tf.Runner.MustRecordGradient() ? 1 : 0; | |||
| var possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args); | |||
| // No tape is watching; skip to running the function. | |||
| if (possible_gradient_type == 0 && executing_eagerly) | |||
| if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE && executing_eagerly) | |||
| { | |||
| var attrs = new object[] | |||
| { | |||
| "executor_type", "", | |||
| "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||
| }; | |||
| return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); | |||
| return _build_call_outputs(_inference_function.Call(args)); | |||
| } | |||
| if (forward_backward == null) | |||
| forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); | |||
| forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); | |||
| var (forward_function, args_with_tangents) = forward_backward.Forward(); | |||
| Tensors flat_outputs = null; | |||
| if (executing_eagerly) | |||
| { | |||
| flat_outputs = forward_function.Call(args_with_tangents); | |||
| } | |||
| else | |||
| { | |||
| tf_with(default_graph._override_gradient_function(new Dictionary<string, Func<Operation, object[], Tensor[]>>(){ | |||
| { "PartitionedCall", _get_gradient_function() }, { "StatefulPartitionedCall", _get_gradient_function() } | |||
| }), _ => | |||
| { | |||
| flat_outputs = forward_function.Call(args_with_tangents); | |||
| }); | |||
| } | |||
| forward_backward.Record(flat_outputs); | |||
| return flat_outputs; | |||
| return _build_call_outputs(flat_outputs); | |||
| } | |||
| public void AddTograph(Graph? g = null) | |||
| @@ -171,13 +215,99 @@ namespace Tensorflow.Functions | |||
| { | |||
| g = ops.get_default_graph(); | |||
| } | |||
| // TODO(Rinne); complete it with `_delayed_rewrite_functions`. | |||
| _delayed_rewrite_functions.Forward().AddToGraph(g); | |||
| } | |||
| public void SetExternalCaptures(IEnumerable<Tensor> captures) | |||
| { | |||
| _captured_inputs = captures; | |||
| } | |||
| ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | |||
| { | |||
| var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||
| return new ForwardBackwardCall(functions, args, tape_watching: true); | |||
| TangentInfo input_tangents; | |||
| if (executing_eagerly) | |||
| { | |||
| // TODO(Rinne): check if it needs to be implemented. | |||
| input_tangents = new TangentInfo(); | |||
| } | |||
| else | |||
| { | |||
| input_tangents = new TangentInfo(); | |||
| } | |||
| if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER) | |||
| { | |||
| if(input_tangents.Indices is not null || executing_eagerly) | |||
| { | |||
| string cache_key = "first_order"; | |||
| if(!_tape_functions_cache.TryGetValue(cache_key, out var functions)) | |||
| { | |||
| functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||
| _tape_functions_cache[cache_key] = functions; | |||
| } | |||
| return new ForwardBackwardCall(functions, args, tape_watching: true); | |||
| } | |||
| else | |||
| { | |||
| return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: true); | |||
| } | |||
| } | |||
| else if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| // TODO(Rinne): add arg "input_tagents" for ForwardBackwardCall. | |||
| return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false); | |||
| } | |||
| internal void set_variables(IEnumerable<IVariableV1> variables) | |||
| { | |||
| func_graph.Variables = variables; | |||
| } | |||
| internal void _set_infer_function() | |||
| { | |||
| _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); | |||
| _inference_function = _delayed_rewrite_functions.Forward(); | |||
| } | |||
| internal void _set_function_spec(FunctionSpec spec) | |||
| { | |||
| _function_spec = null; | |||
| _pre_initialized_function_spec = spec; | |||
| _initialize_function_spec(); | |||
| } | |||
| internal void _initialize_function_spec() | |||
| { | |||
| if(_pre_initialized_function_spec is null) | |||
| { | |||
| return; | |||
| } | |||
| Debug.Assert(_function_spec is null, "already initialized"); | |||
| var spec = _pre_initialized_function_spec; | |||
| //var args = spec.Fullargspec.DictValue.Fields["args"]; | |||
| // TODO(Rinne): self.structured_input_signature | |||
| _function_spec = new FunctionSpec() | |||
| { | |||
| Fullargspec = spec.Fullargspec, | |||
| IsMethod = spec.IsMethod, | |||
| InputSignature = spec.InputSignature | |||
| }; | |||
| } | |||
| internal Func<Operation, object[], Tensor[]> _get_gradient_function() | |||
| { | |||
| return _delayed_rewrite_functions._rewrite_forward_and_call_backward; | |||
| } | |||
| private Tensors _build_call_outputs(Tensors result) | |||
| { | |||
| // TODO(Rinne): deal with `func_graph.structured_outputs` | |||
| return result; | |||
| } | |||
| public override string ToString() | |||
| @@ -1,50 +1,232 @@ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Util; | |||
| using Tensorflow.Common.Extensions; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Framework; | |||
| using System.Buffers; | |||
| using Tensorflow.Gradients; | |||
| namespace Tensorflow.Functions | |||
| { | |||
| public class EagerDefinedFunction | |||
| public class EagerDefinedFunction: IDisposable | |||
| { | |||
| public int _num_outputs; | |||
| public string Name => _func_graph.FuncName; | |||
| FuncGraph _graph; | |||
| FunctionDef _definition; | |||
| OpDef _signature; | |||
| string _name; | |||
| internal ScopedTFFunction _c_func; | |||
| internal Tensor[] _func_graph_outputs; | |||
| internal string _grad_func_name; | |||
| internal Func<Operation, Tensor[], Tensor[]> csharp_grad_func; | |||
| internal EagerDefinedFunction _grad_func; | |||
| internal bool _registered_on_context = false; | |||
| public string Name => _name; | |||
| public DataType[] OutputTypes { get; protected set; } | |||
| public Shape[] OutputShapes { get; protected set; } | |||
| public FunctionDef Definition | |||
| { | |||
| get | |||
| { | |||
| if(_definition is null) | |||
| { | |||
| _definition = _get_definition(); | |||
| } | |||
| return _definition; | |||
| } | |||
| } | |||
| FuncGraph _func_graph; | |||
| public EagerDefinedFunction(string name, FuncGraph graph, | |||
| public OpDef Signature | |||
| { | |||
| get | |||
| { | |||
| if( _signature is null) | |||
| { | |||
| _signature = Definition.Signature; | |||
| } | |||
| return _signature; | |||
| } | |||
| } | |||
| public unsafe EagerDefinedFunction(string name, FuncGraph graph, | |||
| Tensors inputs, Tensors outputs, | |||
| Dictionary<string, string> attrs) | |||
| Dictionary<string, AttrValue> attrs) | |||
| { | |||
| _num_outputs = outputs.Length; | |||
| var input_ops = inputs.Select(x => x.op).ToArray(); | |||
| var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op)) | |||
| .Select(x => x as Operation).ToArray(); | |||
| var output_names = new string[0]; | |||
| var graph_output_names = graph._output_names; | |||
| string[] output_names; | |||
| if(graph_output_names is not null && outputs.All(t => graph_output_names.ContainsKey(ops.tensor_id(t)))) | |||
| { | |||
| output_names = outputs.Select(t => graph_output_names[ops.tensor_id(t)]).ToArray(); | |||
| if(output_names.Distinct().Count() != output_names.Length) | |||
| { | |||
| output_names = new string[0]; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| output_names = new string[0]; | |||
| } | |||
| _func_graph = new FuncGraph(graph, name, attrs); | |||
| _func_graph.ToGraph(operations, inputs, outputs, output_names); | |||
| Status status = new Status(); | |||
| var fn = c_api.TF_GraphToFunction(graph.c_graph, | |||
| name, | |||
| false, | |||
| operations.Length, | |||
| operations.Length == 0 ? new IntPtr[0] : operations.Select(x => (IntPtr)x).ToArray(), | |||
| inputs.Length, | |||
| inputs.Select(t => t._as_tf_output()).ToArray(), | |||
| outputs.Length, | |||
| outputs.Select(t => t._as_tf_output()).ToArray(), | |||
| output_names.Length != outputs.Length ? null : output_names, | |||
| IntPtr.Zero, // warning: the control output hasbben totally ignored. | |||
| null, | |||
| status); | |||
| status.Check(true); | |||
| _c_func = new ScopedTFFunction(fn, name); | |||
| foreach(var (attr_name, attr_value) in attrs) | |||
| { | |||
| var serialized = attr_value.ToByteArray(); | |||
| c_api.TF_FunctionSetAttrValueProto(fn, attr_name, serialized, serialized.Length, status); | |||
| status.Check(true); | |||
| } | |||
| var signature = _get_definition().Signature; | |||
| _name = signature.Name; | |||
| tf_with(ops.init_scope(), s => | |||
| { | |||
| tf.Context.add_function(fn); | |||
| _registered_on_context = true; | |||
| }); | |||
| _num_outputs = signature.OutputArg.Count; | |||
| OutputTypes = signature.OutputArg.Select(x => x.Type).ToArray(); | |||
| OutputShapes = outputs.Select(x => x.shape).ToArray(); | |||
| _func_graph_outputs = new List<Tensor>(outputs).ToArray(); | |||
| csharp_grad_func = null; | |||
| _graph = graph; | |||
| } | |||
| public Tensors Call(Tensors args) | |||
| public unsafe Tensors Call(Tensors args) | |||
| { | |||
| // TODO(Rinne): Add arg `CancellationManager`. | |||
| // TODO(Rinne): Check the arg length. | |||
| var function_call_options = tf.Context.FunctionCallOptions; | |||
| string config = ""; // TODO(Rinne): revise it. The following code should work but not, for unclear reasons. | |||
| //if (function_call_options.config_proto_serialized().Length == 0) | |||
| //{ | |||
| // config = function_utils.get_disabled_rewriter_config().ToStringUtf8(); | |||
| //} | |||
| //else | |||
| //{ | |||
| // config = function_call_options.config_proto_serialized().ToStringUtf8(); | |||
| //} | |||
| string executor_type = function_call_options.ExecutorType ?? ""; | |||
| var executing_eagerly = tf.Context.executing_eagerly(); | |||
| var attrs = new object[] | |||
| { | |||
| "executor_type", "", | |||
| "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||
| "executor_type", executor_type, | |||
| "config_proto", config | |||
| }; | |||
| var results = tf.Runner.TFE_Execute(tf.Context, | |||
| tf.Context.DeviceName, | |||
| _func_graph.FuncName, | |||
| args, | |||
| attrs, | |||
| _num_outputs); | |||
| Tensor[] outputs; | |||
| if (executing_eagerly) | |||
| { | |||
| outputs = execute.executes( | |||
| Signature.Name, | |||
| _num_outputs, | |||
| args, | |||
| attrs, | |||
| tf.Context); | |||
| } | |||
| else | |||
| { | |||
| if(tf.GetTapeSet().Count == 0) | |||
| { | |||
| outputs = functional_ops.partitioned_call(args, this, OutputTypes, | |||
| executing_eagerly, config, ""); | |||
| } | |||
| else | |||
| { | |||
| var tape = tf.GetTapeSet().Peek(); | |||
| tape.StopRecord(); | |||
| outputs = functional_ops.partitioned_call(args, this, OutputTypes, | |||
| executing_eagerly, config, ""); | |||
| tape.StartRecord(); | |||
| } | |||
| } | |||
| foreach(var (i, func_graph_output) in enumerate(_func_graph_outputs)) | |||
| { | |||
| handle_data_util.copy_handle_data(func_graph_output, outputs[i]); | |||
| } | |||
| if (executing_eagerly) | |||
| { | |||
| return outputs; | |||
| } | |||
| else | |||
| { | |||
| foreach(var (i, shape) in enumerate(OutputShapes)) | |||
| { | |||
| outputs[i].shape = shape; | |||
| } | |||
| return outputs; | |||
| } | |||
| } | |||
| public void AddToGraph(Graph g = null) | |||
| { | |||
| if(g is null && tf.Context.executing_eagerly()) | |||
| { | |||
| var ctx = tf.Context; | |||
| if (!ctx.has_function(this.Name)) | |||
| { | |||
| ctx.add_function_def(Definition); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| if (!g.IsFunction(Name)) | |||
| { | |||
| g.AddFunction(this); | |||
| } | |||
| foreach(var f in _graph.Functions.Values) | |||
| { | |||
| if (!g.IsFunction(f.Name)) | |||
| { | |||
| g.AddFunction(f); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return results; | |||
| private FunctionDef _get_definition() | |||
| { | |||
| var buffer = c_api_util.tf_buffer(); | |||
| Status status = new(); | |||
| c_api.TF_FunctionToFunctionDef(_c_func.Get(), buffer, status); | |||
| status.Check(true); | |||
| var proto_data = c_api.TF_GetBuffer(buffer); | |||
| return FunctionDef.Parser.ParseFrom(proto_data.AsSpan<byte>()); | |||
| } | |||
| public void Dispose() | |||
| { | |||
| tf.Context.remove_function(Name); | |||
| } | |||
| } | |||
| } | |||
| @@ -14,12 +14,11 @@ namespace Tensorflow.Functions | |||
| } | |||
| public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) | |||
| public override (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | |||
| ForwardAndBackwardFunctions(Tensors inference_args) | |||
| { | |||
| var outputs = _func_graph.Outputs; | |||
| (_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs) | |||
| = BuildFunctionsForOutputs(outputs, inference_args); | |||
| return _forward; | |||
| var outputs = _func_graph.Outputs.Take(_num_inference_outputs).ToArray(); | |||
| return BuildFunctionsForOutputs(outputs, inference_args); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,23 +1,84 @@ | |||
| using System; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow | |||
| { | |||
| public class Function: Trackable | |||
| public class Function: Trackable, IGenericFunction | |||
| { | |||
| #pragma warning disable CS0169 // The field 'Function._handle' is never used | |||
| private IntPtr _handle; | |||
| #pragma warning restore CS0169 // The field 'Function._handle' is never used | |||
| protected Func<Tensor[], Tensor[]> _csharp_function; | |||
| protected ConcreteFunction _concrete_variable_creation_fn; | |||
| protected bool _autograph; | |||
| protected TracingCompiler _variable_creation_fn; | |||
| public string Name { get; set; } | |||
| public Function() | |||
| public Function(Func<Tensor[], Tensor[]> csharp_function, | |||
| string name, bool auto_graph = true) | |||
| { | |||
| _csharp_function = csharp_function; | |||
| Name = name; | |||
| _autograph = auto_graph; | |||
| } | |||
| public virtual Tensors Apply(Tensors inputs) | |||
| { | |||
| if (_run_functions_eagerly()) | |||
| { | |||
| return _csharp_function(inputs); | |||
| } | |||
| var result = _call(inputs); | |||
| return result; | |||
| } | |||
| public ConcreteFunction get_concrete_function(params Tensor[] args) | |||
| { | |||
| return _get_concrete_function_garbage_collected(args); | |||
| } | |||
| public Function(string name) | |||
| protected virtual Tensors _call(Tensors inputs) | |||
| { | |||
| Name = name; | |||
| if(_variable_creation_fn is not null) | |||
| { | |||
| return _variable_creation_fn.Apply(inputs); | |||
| } | |||
| _initialize(inputs); | |||
| return _concrete_variable_creation_fn.CallFlat(inputs, | |||
| _concrete_variable_creation_fn.CapturedInputs); | |||
| } | |||
| protected TracingCompiler _compiler(Func<Tensor[], Tensor[]> fn) | |||
| { | |||
| var name = nameof(fn); | |||
| return new TracingCompiler(fn, name, autograph: _autograph); | |||
| } | |||
| protected virtual bool _run_functions_eagerly() | |||
| { | |||
| return false; | |||
| } | |||
| protected ConcreteFunction _get_concrete_function_garbage_collected(Tensor[] args) | |||
| { | |||
| if(_variable_creation_fn is null) | |||
| { | |||
| _initialize(args); | |||
| // TODO(Rinne): _initialize_uninitialized_variables | |||
| } | |||
| var concrete = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args); | |||
| return concrete; | |||
| } | |||
| private void _initialize(Tensor[] args) | |||
| { | |||
| _variable_creation_fn = _compiler(_csharp_function); | |||
| _variable_creation_fn._name = this.Name; | |||
| _concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,12 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Functions | |||
| { | |||
| public interface IGenericFunction | |||
| { | |||
| Tensors Apply(Tensors args); | |||
| ConcreteFunction get_concrete_function(params Tensor[] args); | |||
| } | |||
| } | |||
| @@ -3,8 +3,10 @@ using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Gradients; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.NumPy; | |||
| using Tensorflow.Operations; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.tensorflow; | |||
| @@ -15,17 +17,21 @@ namespace Tensorflow.Functions | |||
| /// </summary> | |||
| public abstract class TapeGradientFunctions | |||
| { | |||
| string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; | |||
| string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; | |||
| string _FORWARD_PREFIX = "__forward_"; | |||
| string _BACKWARD_PREFIX = "__backward_"; | |||
| string _INFERENCE_PREFIX = "__inference_"; | |||
| protected string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; | |||
| protected string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; | |||
| protected string _FORWARD_PREFIX = "__forward_"; | |||
| protected string _BACKWARD_PREFIX = "__backward_"; | |||
| protected string _INFERENCE_PREFIX = "__inference_"; | |||
| protected FuncGraph _func_graph; | |||
| protected EagerDefinedFunction _forward; | |||
| protected FuncGraph _forward_graph; | |||
| protected List<int> _forwardprop_input_indices; | |||
| protected List<int> _forwardprop_output_indices; | |||
| protected int _num_forwardprop_outputs; | |||
| protected int _num_inference_outputs; | |||
| protected int _num_outputs; | |||
| protected int _num_trainable_inference_outputs; | |||
| protected ConcreteFunction _backward; | |||
| BackwardFunction _backward_function_wrapper; | |||
| @@ -33,11 +39,25 @@ namespace Tensorflow.Functions | |||
| bool need_gradients_for_jvps) | |||
| { | |||
| _func_graph = func_graph; | |||
| _forward_graph = null; | |||
| _forward = null; | |||
| _backward = null; | |||
| _num_outputs = func_graph.Outputs.Length; | |||
| _forwardprop_output_indices = null; | |||
| _num_forwardprop_outputs = 0; | |||
| _num_inference_outputs = func_graph.Outputs.Length; | |||
| _num_trainable_inference_outputs = func_graph.Outputs.Where(t => backprop_util.IsTrainable(t)).Count(); | |||
| } | |||
| public EagerDefinedFunction Forward(Tensors inference_args) | |||
| public virtual EagerDefinedFunction Forward(Tensors inference_args, Tensors input_tangents = null) | |||
| { | |||
| return ForwardAndBackwardFunctions(inference_args); | |||
| // TODO(Rinne): add input_tangents arg. | |||
| if(_forward is null) | |||
| { | |||
| (_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs) | |||
| = ForwardAndBackwardFunctions(inference_args); | |||
| } | |||
| return _forward; | |||
| } | |||
| /// <summary> | |||
| @@ -45,11 +65,16 @@ namespace Tensorflow.Functions | |||
| /// </summary> | |||
| /// <param name="flat_outputs"></param> | |||
| /// <param name="inference_args"></param> | |||
| public void Record(Tensors flat_outputs, Tensors inference_args) | |||
| public virtual void Record(Tensors flat_outputs, Tensors inference_args) | |||
| { | |||
| // TODO(Rinne): add arg `input_tagents`. | |||
| var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); | |||
| tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record, | |||
| getBackwardFunction: backward_function); | |||
| if(_forwardprop_output_indices is not null && _forwardprop_output_indices.Count > 0) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| throw new NotImplementedException(); | |||
| } | |||
| tf.Runner.TFE_TapeSetRecordOperation(_forward.Signature.Name, to_record, inference_args, backward_function); | |||
| } | |||
| /// <summary> | |||
| @@ -61,66 +86,95 @@ namespace Tensorflow.Functions | |||
| /// <returns></returns> | |||
| (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) | |||
| { | |||
| var capture_mapping = zip(forward_graph.Outputs.Select(t => ops.tensor_id(t)), outputs) | |||
| .ToDictionary(x => x.Item1, x => x.Item2); | |||
| var captured_inputs = backward.CapturedInputs; | |||
| var remapped_captures = captured_inputs.Select(c => | |||
| { | |||
| if (capture_mapping.TryGetValue(ops.tensor_id(c), out var value)) | |||
| { | |||
| return value; | |||
| } | |||
| else | |||
| { | |||
| return c; | |||
| } | |||
| }).ToArray(); | |||
| if(remapped_captures.Where(t => t is not EagerTensor).Any(t => t.graph == forward_graph)) | |||
| { | |||
| var incorrect_mapping = remapped_captures.Where(t => t is not EagerTensor && t.graph != forward_graph); | |||
| throw new RuntimeError($"Failed to map all backward graph captures to " + | |||
| $"the forward graph. Incorrectly mapped: {string.Join(", ", incorrect_mapping)}"); | |||
| } | |||
| Dictionary<int, Tensor> variant_zeros_like = new Dictionary<int, Tensor>(); | |||
| var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length; | |||
| var recorded_outputs = new Tensors(); | |||
| var trainable_recorded_outputs = 0; | |||
| foreach (var (output_index, output) in enumerate(outputs)) | |||
| int trainable_recorded_outputs = 0; | |||
| var skip_positions = new HashSet<int>(); | |||
| var relevant_outputs = outputs; | |||
| foreach (var (output_index, output) in enumerate(relevant_outputs)) | |||
| { | |||
| if (trainable_recorded_outputs < backward_function_inputs) | |||
| recorded_outputs.Add(output); | |||
| if (gradients_util.IsTrainable(output)) | |||
| trainable_recorded_outputs += 1; | |||
| if (backprop_util.IsTrainable(output)) | |||
| trainable_recorded_outputs++; | |||
| else | |||
| skip_positions.Add(output_index); | |||
| if (output.dtype == dtypes.variant) | |||
| variant_zeros_like[output_index] = default_gradient.zeros_like(output); | |||
| } | |||
| if(_backward_function_wrapper == null) | |||
| _backward_function_wrapper = (args, unneeded_gradients) => | |||
| { | |||
| var capture_mapping = new Dictionary<long, Tensor>(); | |||
| foreach (var (i, output) in enumerate(outputs)) | |||
| capture_mapping[forward_graph.Outputs[i].Id] = output; | |||
| var remapped_captures = new Tensors(); | |||
| foreach (var capture in backward.CapturedInputs) | |||
| { | |||
| if (capture_mapping.ContainsKey(capture.Id)) | |||
| remapped_captures.Add(capture_mapping[capture.Id]); | |||
| } | |||
| var skip_positions = new List<int>(); | |||
| foreach (var (output_index, output) in enumerate(outputs)) | |||
| if(backward.Outputs is null || backward.Outputs.Length == 0) | |||
| { | |||
| if (!gradients_util.IsTrainable(output)) | |||
| skip_positions.Add(output_index); | |||
| return backward.FlatStructuredOutputs; | |||
| } | |||
| _backward_function_wrapper = (args, unneeded_gradients) => | |||
| var processed_args = new Tensors(); | |||
| int input_index = 0; | |||
| foreach (var (output_index, arg) in enumerate(args)) | |||
| { | |||
| var processed_args = new Tensors(); | |||
| var input_index = 0; | |||
| foreach (var (output_index, arg) in enumerate(args)) | |||
| if (skip_positions.Contains(output_index)) | |||
| continue; | |||
| if (arg is null) | |||
| { | |||
| var input_placeholder = backward.Inputs[input_index]; | |||
| Tensor variant_arg; | |||
| if (input_placeholder.dtype == dtypes.variant) | |||
| { | |||
| variant_arg = variant_zeros_like[output_index]; | |||
| } | |||
| else | |||
| { | |||
| var (shape, type) = default_gradient.shape_and_dtype(input_placeholder); | |||
| variant_arg = array_ops.zeros(shape, type); | |||
| } | |||
| processed_args.Add(variant_arg); | |||
| } | |||
| else | |||
| { | |||
| if (skip_positions.Contains(output_index)) | |||
| continue; | |||
| if (arg == null) | |||
| throw new NotImplementedException(""); | |||
| processed_args.Add(arg); | |||
| input_index += 1; | |||
| if (input_index >= backward_function_inputs) | |||
| break; | |||
| } | |||
| input_index++; | |||
| if (input_index >= backward_function_inputs) | |||
| break; | |||
| } | |||
| tf.Logger.Debug($"Invoke backward function: {backward.Name}"); | |||
| var gradients = backward.CallFlat(processed_args, remapped_captures); | |||
| tf.Logger.Debug($"Invoke backward function: {backward.Name}"); | |||
| var gradients = backward.CallFlat(processed_args, remapped_captures); | |||
| foreach (var unneeded_gradient_index in unneeded_gradients) | |||
| { | |||
| var index = Convert.ToInt32(unneeded_gradient_index); | |||
| if (gradients.Length <= index) | |||
| gradients.Insert(index, null); | |||
| } | |||
| foreach (var unneeded_gradient_index in unneeded_gradients) | |||
| { | |||
| var index = Convert.ToInt32(unneeded_gradient_index); | |||
| if (gradients.Length <= index) | |||
| gradients.Insert(index, null); | |||
| } | |||
| return gradients; | |||
| }; | |||
| } | |||
| return gradients; | |||
| }; | |||
| return (_backward_function_wrapper, recorded_outputs); | |||
| } | |||
| @@ -132,51 +186,66 @@ namespace Tensorflow.Functions | |||
| var trainable_indices = new List<int>(); | |||
| foreach(var (index, output) in enumerate(outputs)) | |||
| { | |||
| if (gradients_util.IsTrainable(output)) | |||
| if (backprop_util.IsTrainable(output)) | |||
| { | |||
| trainable_outputs.Add(output); | |||
| trainable_indices.Add(index); | |||
| } | |||
| } | |||
| var gradients_wrt_outputs = new List<Tensor>(); | |||
| var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"); | |||
| var backwards_graph = new FuncGraph(monomorphic_function_utils._backward_name(_func_graph.Name)); | |||
| backwards_graph.as_default(); | |||
| var gradients_wrt_outputs = new List<Tensor>(); | |||
| foreach (var output in trainable_outputs) | |||
| gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); | |||
| { | |||
| var (gradient_shape, gradient_dtype) = default_gradient.shape_and_dtype(output); | |||
| var gradient_placeholder = tf.placeholder(gradient_dtype, gradient_shape); | |||
| gradients_wrt_outputs.Add(gradient_placeholder); | |||
| handle_data_util.copy_handle_data(output, gradient_placeholder); | |||
| } | |||
| // TODO(Rinne): with ops.device(None) | |||
| var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), | |||
| _func_graph.Inputs, | |||
| grad_ys: gradients_wrt_outputs.ToArray(), | |||
| src_graph: _func_graph); | |||
| _func_graph.Inputs, | |||
| grad_ys: gradients_wrt_outputs.ToArray(), | |||
| src_graph: _func_graph); | |||
| var captures_from_forward = backwards_graph.external_captures | |||
| .Where(x => x is not EagerTensor && x is not NDArray && x.graph == _func_graph) | |||
| .ToArray(); | |||
| HashSet<Tensor> existing_outputs = new(_func_graph.Outputs); | |||
| foreach(var capture in captures_from_forward) | |||
| { | |||
| if (!_func_graph.Outputs.Contains(capture)) | |||
| if (!existing_outputs.Contains(capture)) | |||
| { | |||
| existing_outputs.Add(capture); | |||
| _func_graph.Outputs.Add(capture); | |||
| } | |||
| } | |||
| backwards_graph.Exit(); | |||
| var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; | |||
| var backward_function_attr = new Dictionary<string, string>(); | |||
| backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||
| gradients_wrt_outputs.append(backwards_graph.internal_captures); | |||
| backwards_graph.Inputs = gradients_wrt_outputs; | |||
| backwards_graph.Outputs = gradients_wrt_inputs; | |||
| backwards_graph.Inputs = gradients_wrt_outputs.Concat(backwards_graph.internal_captures).ToArray(); | |||
| backwards_graph.Outputs.AddRange(gradients_wrt_inputs.Where(x => x is not null)); | |||
| var (wrapped_forward_function, wrapped_backward_function) = | |||
| monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph); | |||
| //var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; | |||
| //var backward_function_attr = new Dictionary<string, string>(); | |||
| //backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||
| var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); | |||
| //var backward_function = new ConcreteFunction(backwards_graph, | |||
| // monomorphic_function_utils._parse_func_attrs(backward_function_attr)); | |||
| var forward_function_attr = new Dictionary<string, string>(); | |||
| forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; | |||
| var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, | |||
| _func_graph.Inputs, _func_graph.Outputs, forward_function_attr); | |||
| //var forward_function_attr = new Dictionary<string, string>(); | |||
| //forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; | |||
| //var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, | |||
| // _func_graph.Inputs, _func_graph.Outputs, | |||
| // monomorphic_function_utils._parse_func_attrs(forward_function_attr)); | |||
| return (forward_function, _func_graph, backward_function, null, 0); | |||
| return (wrapped_forward_function, _func_graph, wrapped_backward_function, null, 0); | |||
| } | |||
| public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) | |||
| public virtual (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | |||
| ForwardAndBackwardFunctions(Tensors inference_args) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| @@ -0,0 +1,84 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Security.Cryptography.X509Certificates; | |||
| using System.Text; | |||
| using Tensorflow.Graphs; | |||
| namespace Tensorflow.Functions | |||
| { | |||
| public class TracingCompiler | |||
| { | |||
| Func<Tensor[], Tensor[]> _csharp_function; | |||
| //FunctionSpec _function_spec; | |||
| internal string _name; | |||
| bool _autograph; | |||
| Dictionary<string, ConcreteFunction> _function_cache; | |||
| Dictionary<string, AttrValue> _function_attributes; | |||
| int _tracing_count; | |||
| public TracingCompiler(Func<Tensor[], Tensor[]> csharp_function, string name, object? input_signatures = null, | |||
| Dictionary<string, AttrValue> attributes = null, bool autograph = true, object? autograph_options = null, | |||
| bool reduce_retracing = false, bool capture_by_value = false) | |||
| { | |||
| _csharp_function = csharp_function; | |||
| bool pure_function = attributes is not null && attributes.Count > 0 && attributes.ContainsKey(monomorphic_function_utils.IMPLEMENTS_ATTRIBUTE_NAME); | |||
| _name = name; | |||
| _autograph = autograph; | |||
| _function_attributes = attributes ?? new Dictionary<string, AttrValue>(); | |||
| _function_cache = new Dictionary<string, ConcreteFunction>(); | |||
| _tracing_count = 0; | |||
| } | |||
| public Tensor[] Apply(Tensor[] inputs) | |||
| { | |||
| // TODO(Rinne): add lock here. | |||
| var (concrete_function, filtered_flat_args) = _maybe_define_function(inputs); | |||
| return concrete_function.CallFlat(filtered_flat_args, concrete_function.CapturedInputs); | |||
| } | |||
| internal ConcreteFunction _get_concrete_function_internal_garbage_collected(Tensor[] args) | |||
| { | |||
| var (concrete_function, _) = _maybe_define_concrete_function(args); | |||
| return concrete_function; | |||
| } | |||
| private (ConcreteFunction, Tensor[]) _maybe_define_concrete_function(Tensor[] args) | |||
| { | |||
| return _maybe_define_function(args); | |||
| } | |||
| private (ConcreteFunction, Tensor[]) _maybe_define_function(Tensor[] args) | |||
| { | |||
| var lookup_func_key = make_cache_key(args); | |||
| if(_function_cache.TryGetValue(lookup_func_key, out var concrete_function)) | |||
| { | |||
| return (concrete_function, args); | |||
| } | |||
| concrete_function = _create_concrete_function(args); | |||
| _function_cache[lookup_func_key] = concrete_function; | |||
| return (concrete_function, args); | |||
| } | |||
| private ConcreteFunction _create_concrete_function(Tensor[] args) | |||
| { | |||
| _tracing_count++; | |||
| int arglen = args.Length; | |||
| var concrete_function = new ConcreteFunction(FuncGraph.func_graph_from_func( | |||
| _name, x => _csharp_function(x.Where(y => y is Tensor).Select(y => (Tensor)y).ToArray()), | |||
| args, new Dictionary<string, object>(), autograph: _autograph | |||
| ), _function_attributes); | |||
| return concrete_function; | |||
| } | |||
| private static string make_cache_key(Tensor[] inputs) | |||
| { | |||
| //string res = ""; | |||
| //foreach (var input in inputs) | |||
| //{ | |||
| // res += $"{input.name}_{input.Id}"; | |||
| //} | |||
| return inputs.Length.ToString(); | |||
| } | |||
| } | |||
| } | |||
| @@ -16,6 +16,7 @@ | |||
| using System; | |||
| using System.Runtime.InteropServices; | |||
| using Tensorflow.Functions; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -54,6 +55,9 @@ namespace Tensorflow | |||
| public static extern IntPtr TF_FunctionName(SafeFuncGraphHandle func); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, IntPtr grad, SafeStatusHandle status); | |||
| public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, SafeFuncGraphHandle grad, SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_GraphGetFunctions(SafeGraphHandle g, IntPtr[] funcs, int max_func, SafeStatusHandle status); | |||
| } | |||
| } | |||
| @@ -0,0 +1,50 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow.Functions | |||
| { | |||
| internal static class composite_tensor_utils | |||
| { | |||
| public static List<object> flatten_with_variables(object inputs) | |||
| { | |||
| List<object> flat_inputs = new(); | |||
| foreach(var value in nest.flatten(inputs)) | |||
| { | |||
| if(value is CompositeTensor && !resource_variable_ops.is_resource_variable(value)) | |||
| { | |||
| throw new NotImplementedException("The composite tensor has not been fully supported."); | |||
| } | |||
| else | |||
| { | |||
| flat_inputs.Add(value); | |||
| } | |||
| } | |||
| return flat_inputs; | |||
| } | |||
| public static List<object> flatten_with_variables_or_variable_specs(object arg) | |||
| { | |||
| List<object> flat_inputs = new(); | |||
| foreach(var value in nest.flatten(arg)) | |||
| { | |||
| if(value is CompositeTensor && !resource_variable_ops.is_resource_variable(value)) | |||
| { | |||
| throw new NotImplementedException("The composite tensor has not been fully supported."); | |||
| } | |||
| // TODO(Rinne): deal with `VariableSpec`. | |||
| else if(value is TypeSpec type_spec && value is not TensorSpec) | |||
| { | |||
| throw new NotImplementedException("The TypeSpec has not been fully supported."); | |||
| } | |||
| else | |||
| { | |||
| flat_inputs.Add(value); | |||
| } | |||
| } | |||
| return flat_inputs; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,94 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Variables; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Functions | |||
| { | |||
| public static class function_saved_model_utils | |||
| { | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="concrete_function"></param> | |||
| /// <param name="inputs">a list tensors or other objects (such as variables) which | |||
| /// contain tensors that were originally captured by the function</param> | |||
| public static void restore_captures(ConcreteFunction concrete_function, IEnumerable<object> inputs) | |||
| { | |||
| var bound_inputs = inputs?.Select(obj => | |||
| { | |||
| if(obj is Tensor tensor) | |||
| { | |||
| return get_tensor_from_node(tensor); | |||
| } | |||
| else if(obj is IVariableV1 variable) | |||
| { | |||
| return get_tensor_from_node(variable); | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError("Encountered an type error, please submit an issue to " + | |||
| "https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| } | |||
| }); | |||
| var bound_variables = inputs.Where(obj => obj is IVariableV1).Select(x => (IVariableV1)x); | |||
| List<Tensor> captured_inputs_list = new(); | |||
| concrete_function.set_variables(bound_variables); | |||
| if (bound_inputs is not null) | |||
| { | |||
| foreach(var (bound_input, internal_capture) in zip(bound_inputs, concrete_function.Inputs.Skip(concrete_function.Inputs.Length - bound_inputs.Count()))) | |||
| { | |||
| if(hasattr(bound_input, "__tf_experimental_restore_capture__")) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| else | |||
| { | |||
| captured_inputs_list.Add(bound_input); | |||
| concrete_function.func_graph.replace_capture(bound_input, internal_capture); | |||
| if(internal_capture.dtype == dtypes.resource) | |||
| { | |||
| if (resource_variable_ops.is_resource_variable(bound_input)) | |||
| { | |||
| handle_data_util.copy_handle_data(bound_input.Handle, internal_capture); | |||
| } | |||
| else | |||
| { | |||
| handle_data_util.copy_handle_data(bound_input, internal_capture); | |||
| } | |||
| } | |||
| concrete_function.func_graph.capture(bound_input); | |||
| } | |||
| } | |||
| } | |||
| if(captured_inputs_list.Any(inp => inp is null)) | |||
| { | |||
| // TODO(Rinne): add warnings. | |||
| } | |||
| concrete_function.SetExternalCaptures(captured_inputs_list); | |||
| } | |||
| public static Tensor get_tensor_from_node(Tensor node) | |||
| { | |||
| return node; | |||
| } | |||
| public static Tensor get_tensor_from_node(IVariableV1 node) | |||
| { | |||
| if (resource_variable_ops.is_resource_variable(node)) | |||
| { | |||
| return node.Handle; | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError("Encountered an type error, please submit an issue to " + | |||
| "https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,282 @@ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Gradients; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Common.Extensions; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Framework; | |||
| using static Tensorflow.Binding; | |||
| using System.Diagnostics; | |||
| namespace Tensorflow.Functions | |||
| { | |||
| internal static class monomorphic_function_utils | |||
| { | |||
| internal static string _FORWARD_PREFIX = "__forward_"; | |||
| internal static string _BACKWARD_PREFIX = "__backward_"; | |||
| internal static string _INFERENCE_PREFIX = "__inference_"; | |||
| internal static string IMPLEMENTS_ATTRIBUTE_NAME = "_implements"; | |||
| internal static string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; | |||
| internal static string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; | |||
| public static string _inference_name(string name) | |||
| { | |||
| return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; | |||
| } | |||
| public static string _forward_name(string name) | |||
| { | |||
| return $"{_FORWARD_PREFIX}{name}_{ops.uid()}"; | |||
| } | |||
| public static string _backward_name(string name) | |||
| { | |||
| return $"{_BACKWARD_PREFIX}{name}_{ops.uid()}"; | |||
| } | |||
| public static (EagerDefinedFunction, ConcreteFunction) _create_forward_backward_with_graph(Dictionary<string, AttrValue> attrs, | |||
| FuncGraph forward_graph, FuncGraph backwards_graph) | |||
| { | |||
| string forward_function_name = _forward_name(forward_graph.Name); | |||
| Dictionary<string, AttrValue> common_attributes; | |||
| if(attrs is null) | |||
| { | |||
| common_attributes = new Dictionary<string, AttrValue>(); | |||
| } | |||
| else | |||
| { | |||
| common_attributes = new Dictionary<string, AttrValue>(attrs); | |||
| } | |||
| if (common_attributes.ContainsKey(IMPLEMENTS_ATTRIBUTE_NAME)) | |||
| { | |||
| common_attributes.Remove(IMPLEMENTS_ATTRIBUTE_NAME); | |||
| } | |||
| var backward_function_attr = _parse_func_attrs(new Dictionary<string, object>() | |||
| { | |||
| {FORWARD_FUNCTION_ATTRIBUTE_NAME, forward_function_name } | |||
| }); | |||
| backward_function_attr.Update(common_attributes); | |||
| var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); | |||
| var forward_function_attr = _parse_func_attrs(new Dictionary<string, object>() | |||
| { | |||
| {BACKWARD_FUNCTION_ATTRIBUTE_NAME, backward_function.Name } | |||
| }); | |||
| forward_function_attr.Update(common_attributes); | |||
| var forward_function = new EagerDefinedFunction(forward_function_name, forward_graph, | |||
| forward_graph.Inputs, forward_graph.Outputs, forward_function_attr); | |||
| return (forward_function, backward_function); | |||
| } | |||
| public static Dictionary<string, AttrValue> _parse_func_attrs(Dictionary<string, object> attributes) | |||
| { | |||
| Dictionary<string, AttrValue> attrs = new(); | |||
| foreach(var item in attributes) | |||
| { | |||
| var key = item.Key; | |||
| var value = item.Value; | |||
| if (value is AttrValue attr_value) | |||
| { | |||
| attrs[key] = attr_value; | |||
| } | |||
| else if (value is bool b) | |||
| { | |||
| attrs[key] = new AttrValue() { B = b }; | |||
| } | |||
| else if (value is int i) | |||
| { | |||
| attrs[key] = new AttrValue() { I = i }; | |||
| } | |||
| else if (value is float f) | |||
| { | |||
| attrs[key] = new AttrValue() { F = f }; | |||
| } | |||
| else if(value is string s) | |||
| { | |||
| attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(s) }; | |||
| } | |||
| else if (value is byte[] bytes) | |||
| { | |||
| attrs[key] = new AttrValue() { S = ByteString.CopyFrom(bytes) }; | |||
| } | |||
| else | |||
| { | |||
| throw new ValueError($"Attribute {key} must be bool, int, float, string, or " + | |||
| $"AttrValue. Got {value.GetType()}."); | |||
| } | |||
| } | |||
| return attrs; | |||
| } | |||
| public static Dictionary<string, AttrValue> _parse_func_attrs(Dictionary<string, string> attributes) | |||
| { | |||
| Dictionary<string, AttrValue> attrs = new(); | |||
| foreach (var item in attributes) | |||
| { | |||
| var key = item.Key; | |||
| var value = item.Value; | |||
| attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(value) }; | |||
| } | |||
| return attrs; | |||
| } | |||
| } | |||
| public class DelayedRewriteGradientFunctions : TapeGradientFunctions | |||
| { | |||
| EagerDefinedFunction _inference_function; | |||
| Dictionary<string, AttrValue> _attrs; | |||
| int _num_inference_outputs; | |||
| Dictionary<int, (EagerDefinedFunction, ConcreteFunction)> _cached_function_pairs = new(); | |||
| public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary<string, AttrValue> attrs) | |||
| : base(func_graph, false) | |||
| { | |||
| _func_graph = func_graph; | |||
| _inference_function = new EagerDefinedFunction(monomorphic_function_utils._inference_name(_func_graph.Name), | |||
| _func_graph, _func_graph.Inputs, _func_graph.Outputs, attrs); | |||
| _attrs = attrs; | |||
| _num_inference_outputs = _func_graph.Outputs.Length; | |||
| } | |||
| public override EagerDefinedFunction Forward(Tensors inference_args = null, Tensors input_tangents = null) | |||
| { | |||
| if (input_tangents is not null) | |||
| { | |||
| throw new InvalidArgumentError($"unexpectedly got forwardprop information in " + | |||
| $"a class that does not support forwardprop."); | |||
| } | |||
| return _inference_function; | |||
| } | |||
| public override void Record(Tensors flat_outputs, Tensors inference_args) | |||
| { | |||
| var (backward_function, to_record) = _backward(flat_outputs); | |||
| foreach(var tape in tf.GetTapeSet()) | |||
| { | |||
| tape.RecordOperation(_inference_function.Signature.Name, to_record, | |||
| inference_args, backward_function); | |||
| } | |||
| } | |||
| public (EagerDefinedFunction, ConcreteFunction) forward_backward(int num_doutputs = -2) | |||
| { | |||
| if(num_doutputs == -2) | |||
| { | |||
| num_doutputs = _num_inference_outputs; | |||
| } | |||
| if(_cached_function_pairs.TryGetValue(num_doutputs, out var target)) | |||
| { | |||
| return target; | |||
| } | |||
| var (forward, backward) = _construct_forward_backward(num_doutputs); | |||
| _cached_function_pairs[num_doutputs] = (forward, backward); | |||
| return (forward, backward); | |||
| } | |||
| private (BackwardFunction, Tensors) _backward(Tensors outputs) | |||
| { | |||
| Tensor[] backward_function(Tensor[] args, long[] unneeded_gradients) | |||
| { | |||
| var call_op = outputs[0].op; | |||
| return _rewrite_forward_and_call_backward(call_op, args); | |||
| } | |||
| return (backward_function, outputs); | |||
| } | |||
| internal Tensor[] _rewrite_forward_and_call_backward(Operation op, params object[] doutputs) | |||
| { | |||
| var (forward_function, backward_function) = forward_backward(doutputs.Length); | |||
| if(backward_function.Outputs is null || backward_function.Outputs.Length == 0) | |||
| { | |||
| return backward_function.FlatStructuredOutputs; | |||
| } | |||
| forward_function.AddToGraph(op.graph); | |||
| op._set_func_attr("f", forward_function.Name); | |||
| op._set_type_list_attr("Tout", forward_function.OutputTypes); | |||
| op._add_outputs(forward_function.OutputTypes.Select(x => x.as_tf_dtype()). | |||
| Skip(op.outputs.Length).ToArray(), forward_function.OutputShapes.Skip(op.outputs.Length).ToArray() | |||
| ); | |||
| for(int i = 0; i < op.outputs.Length; i++) | |||
| { | |||
| var func_graph_output = forward_function._func_graph_outputs[i]; | |||
| handle_data_util.copy_handle_data(func_graph_output, op.outputs[i]); | |||
| } | |||
| var capture_mapping = zip(_func_graph.Outputs.Select(t => ops.tensor_id(t)), op.outputs). | |||
| ToDictionary(x => x.Item1, x => x.Item2); | |||
| var remapped_captures = backward_function.CapturedInputs.Select( | |||
| x => capture_mapping.GetOrDefault(ops.tensor_id(x), x) | |||
| ); | |||
| List<Tensor> cleaned_doutputs = new(); | |||
| foreach(var (doutput, placeholder) in zip(doutputs, _func_graph.Outputs)) | |||
| { | |||
| if (backprop_util.IsTrainable(placeholder)) | |||
| { | |||
| if(doutput is IndexedSlices) | |||
| { | |||
| cleaned_doutputs.Add(ops.convert_to_tensor(doutput)); | |||
| } | |||
| else if(doutput is null) | |||
| { | |||
| cleaned_doutputs.Add(default_gradient.zeros_like(placeholder)); | |||
| } | |||
| else if(doutput is Tensor tensor) | |||
| { | |||
| cleaned_doutputs.Add(tensor); | |||
| } | |||
| else | |||
| { | |||
| throw new ValueError($"Unsupported type {doutput.GetType()} in function _rewrite_forward_and_call_backward"); | |||
| } | |||
| } | |||
| } | |||
| return backward_function.CallFlat(cleaned_doutputs.ToArray(), remapped_captures.ToArray()); | |||
| } | |||
| private (EagerDefinedFunction, ConcreteFunction) _construct_forward_backward(int num_doutputs) | |||
| { | |||
| var trainable_outputs = _func_graph.Outputs.Take(num_doutputs).Where(x => backprop_util.IsTrainable(x)); | |||
| List<TensorSpec> signature = new(); | |||
| foreach(var t in trainable_outputs) | |||
| { | |||
| var (shape, dtype) = default_gradient.shape_and_dtype(t); | |||
| signature.Add(new TensorSpec(shape, dtype)); | |||
| } | |||
| Tensor[] _backprop_function(Tensor[] grad_ys) | |||
| { | |||
| return gradients_util._GradientsHelper(trainable_outputs.ToArray(), _func_graph.Inputs, | |||
| grad_ys, src_graph: _func_graph); | |||
| } | |||
| _func_graph.as_default(); | |||
| FuncGraph backwards_graph = new(monomorphic_function_utils._backward_name(_func_graph.Name)); | |||
| FuncGraph.func_graph_from_func(backwards_graph.Name, x => _backprop_function(x.Select(y => | |||
| { | |||
| Debug.Assert(y is Tensor); | |||
| return (Tensor)y; | |||
| }).ToArray()), new object[0], new Dictionary<string, object>(), signature.ToArray(), backwards_graph); | |||
| var backwards_graph_captures = backwards_graph.external_captures; | |||
| var captures_from_forward = backwards_graph_captures.Where(c => c is not EagerTensor && c.graph == _func_graph); | |||
| HashSet<Tensor> existing_outputs = new HashSet<Tensor>(_func_graph.Outputs); | |||
| foreach(var capture in captures_from_forward) | |||
| { | |||
| if (!existing_outputs.Contains(capture)) | |||
| { | |||
| existing_outputs.Add(capture); | |||
| _func_graph.Outputs.Add(capture); | |||
| } | |||
| } | |||
| var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph( | |||
| _attrs, _func_graph, backwards_graph); | |||
| _func_graph.Exit(); | |||
| return (forward_function, backward_function); | |||
| } | |||
| } | |||
| } | |||
| @@ -9,7 +9,7 @@ namespace Tensorflow.Gradients | |||
| /// Map from tensor to how many references still exist for this tensor in | |||
| /// the tape. | |||
| /// </summary> | |||
| public UnorderedMap<Tensor, long> tensor_usage_counts { get; set; } | |||
| public UnorderedMap<long, long> tensor_usage_counts { get; set; } | |||
| /// <summary> | |||
| /// Maps from op ID to how many output tensors of this op still need to have | |||
| /// their gradients computed. | |||
| @@ -19,7 +19,7 @@ namespace Tensorflow.Gradients | |||
| public BackpropInitialState() | |||
| { | |||
| op_tape = new OpTape(); | |||
| tensor_usage_counts = new UnorderedMap<Tensor, long>(); | |||
| tensor_usage_counts = new UnorderedMap<long, long>(); | |||
| op_missing_tensor = new UnorderedMap<long, long>(); | |||
| } | |||
| } | |||
| @@ -67,40 +67,59 @@ namespace Tensorflow.Gradients | |||
| /// <param name="target"></param> | |||
| /// <param name="source"></param> | |||
| /// <returns></returns> | |||
| public Tensor gradient(Tensor target, Tensor source) | |||
| public Tensor gradient(Tensor target, Tensor source, List<Tensor> output_gradients = null, | |||
| string unconnected_gradients = null) | |||
| { | |||
| if(_tape is null) | |||
| { | |||
| throw new RuntimeError("A non-persistent GradientTape can only be used to " + | |||
| "compute one set of gradients (or jacobians)."); | |||
| } | |||
| ITape tape = stop_recording(); | |||
| var results = tf.Runner.TFE_TapeGradient(tape, | |||
| new[] { target }, | |||
| new[] { source }, | |||
| null); | |||
| output_gradients, | |||
| new[] { source }, | |||
| unconnected_gradients); | |||
| return results[0]; | |||
| } | |||
| public Tensor gradient(Tensor target, ResourceVariable source) | |||
| public Tensor gradient(Tensor target, ResourceVariable source, List<Tensor> output_gradients = null, | |||
| string unconnected_gradients = null) | |||
| { | |||
| var results = gradient(target, new List<IVariableV1> { source }); | |||
| var results = gradient(target, new List<IVariableV1> { source }, output_gradients, unconnected_gradients); | |||
| return results[0]; | |||
| } | |||
| public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources) | |||
| public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources, List<Tensor> output_gradients = null, | |||
| string unconnected_gradients = null) | |||
| { | |||
| var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 }); | |||
| var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 }, output_gradients, unconnected_gradients); | |||
| return (results[0], results[1]); | |||
| } | |||
| public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources) | |||
| public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources, List<Tensor> output_gradients = null, | |||
| string unconnected_gradients = null) | |||
| { | |||
| if (_tape is null) | |||
| { | |||
| throw new RuntimeError("A non-persistent GradientTape can only be used to " + | |||
| "compute one set of gradients (or jacobians)."); | |||
| } | |||
| var tape = stop_recording(); | |||
| var results = tf.Runner.TFE_TapeGradient(tape, | |||
| new[] { target }, | |||
| sources.Select(x => x.Handle).ToArray(), | |||
| null); | |||
| output_gradients, | |||
| sources.Select(x => x.Handle).ToArray(), | |||
| unconnected_gradients); | |||
| if (!tape.Persistent) | |||
| { | |||
| @@ -6,24 +6,31 @@ namespace Tensorflow.Gradients | |||
| public interface ITape | |||
| { | |||
| void SetTapeId(int id); | |||
| bool ShouldRecord(Tensor[] tensors); | |||
| bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes); | |||
| void StartRecord(); | |||
| void StopRecord(); | |||
| bool Persistent { get; } | |||
| void RecordOperation(string op_type, | |||
| Tensor[] input_tensors, | |||
| TapeTensor[] output_tensors, | |||
| long[] input_tensor_id, | |||
| TF_DataType[] input_dtypes, | |||
| BackwardFunction backward_function); | |||
| void VariableAccessed(ResourceVariable variable); | |||
| void RecordOperation(string op_type, | |||
| Tensor[] outputs, | |||
| Tensor[] inputs, | |||
| BackwardFunction backward_function); | |||
| void VariableAccessed(IVariableV1 variable); | |||
| void Watch(Tensor x); | |||
| ResourceVariable[] WatchedVariables(); | |||
| IVariableV1[] WatchedVariables(); | |||
| Tensor[] ComputeGradient(Tensor[] target_tensor_ids, | |||
| Tensor[] source_tensor_ids, | |||
| UnorderedMap<Tensor, TapeTensor> sources_that_are_targets, | |||
| Tensor[] output_gradients); | |||
| Tensor[] ComputeGradient(long[] target_tensor_ids, | |||
| long[] source_tensor_ids, | |||
| UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||
| List<Tensor> output_gradients, | |||
| bool build_default_zeros_grads); | |||
| } | |||
| } | |||
| @@ -9,9 +9,9 @@ namespace Tensorflow.Gradients | |||
| { | |||
| public string op_type { get; set; } | |||
| public TapeTensor[] output_tensor_info { get; set; } | |||
| public Tensor[] input_tensor_id { get; set; } | |||
| public long[] input_tensor_id { get; set; } | |||
| public BackwardFunction backward_function { get; set; } | |||
| public override string ToString() | |||
| => $"{op_type}, inputs: {string.Join(",", input_tensor_id.Select(x => x.Id))}"; | |||
| => $"{op_type}, inputs: {string.Join(",", input_tensor_id)}"; | |||
| } | |||
| } | |||
| @@ -2,235 +2,246 @@ | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Gradients | |||
| { | |||
| public partial class Tape | |||
| { | |||
| // int kMinAggregateCount = 4; | |||
| // int kMinAggregateBytes = 128 * 1024 * 1024; | |||
| static readonly int kMinAggregateCount = 4; | |||
| static readonly int kMinAggregateBytes = 128 * 1024 * 1024; | |||
| private static UnorderedMap<string, UnorderedSet<int>> _functionsAcceptingNoneForIndicesMap; | |||
| public Tensor[] ComputeGradient(Tensor[] target_tensor_ids, | |||
| Tensor[] source_tensor_ids, | |||
| UnorderedMap<Tensor, TapeTensor> sources_that_are_targets, | |||
| Tensor[] output_gradients) | |||
| static Tape() | |||
| { | |||
| var sources_set = new UnorderedSet<Tensor>(source_tensor_ids); | |||
| // var gradients_size = new UnorderedMap<Tensor, long>(); | |||
| var functionsAcceptingNoneForIndicesMap = FunctionsAcceptingNoneForIndicesMap(); | |||
| var state = PrepareBackprop( | |||
| target_tensor_ids, tensor_tape_, op_tape_, sources_set, _persistent); | |||
| var op_stack = InitialStack(state.op_tape, state.op_missing_tensor); | |||
| var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, | |||
| output_gradients, | |||
| tensor_tape_, | |||
| state.op_tape); | |||
| _functionsAcceptingNoneForIndicesMap = new(); | |||
| _functionsAcceptingNoneForIndicesMap.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||
| _functionsAcceptingNoneForIndicesMap.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||
| _functionsAcceptingNoneForIndicesMap.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 })); | |||
| } | |||
| while (!op_stack.empty()) | |||
| public Tensor[] ComputeGradient(long[] target_tensor_ids, | |||
| long[] source_tensor_ids, | |||
| UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||
| List<Tensor> output_gradients, | |||
| bool build_default_zeros_grads) | |||
| { | |||
| UnorderedSet<long> sources_set = new(source_tensor_ids); | |||
| BackpropInitialState state = PrepareBackprop(target_tensor_ids, tensor_tape_, op_tape_, sources_set, Persistent); | |||
| var op_stack = InitialStack(state.op_tape, state.op_missing_tensor); | |||
| var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, output_gradients, tensor_tape_, state.op_tape); | |||
| UnorderedMap<long, long> gradients_size = new(); | |||
| while(op_stack.Count > 0) | |||
| { | |||
| var op = op_stack.Dequeue(); | |||
| if (!state.op_tape.find(op, out var trace)) | |||
| long op = op_stack.Dequeue(); | |||
| if(!state.op_tape.TryGetValue(op, out var op_it)) | |||
| { | |||
| continue; | |||
| // Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}"); | |||
| } | |||
| var trace = op_it; | |||
| state.op_tape.erase(op); | |||
| var out_gradients = new List<Tensor>(trace.output_tensor_info.Length); | |||
| var unneeded_gradients = new List<long>(); | |||
| for (int i = 0; i < trace.input_tensor_id.Length; i++) | |||
| List<Tensor> out_gradients = new(); | |||
| List<long> unneeded_gradients = new(); | |||
| for(int i = 0, end = trace.input_tensor_id.Length; i < end; i++) | |||
| { | |||
| var in_tensor_id = trace.input_tensor_id[i]; | |||
| if (!tensor_tape_.find(in_tensor_id) && | |||
| !sources_set.find(in_tensor_id)) | |||
| long in_tensor_id = trace.input_tensor_id[i]; | |||
| if(!tensor_tape_.find(in_tensor_id) && !sources_set.find(in_tensor_id)) | |||
| { | |||
| unneeded_gradients.Add(i); | |||
| } | |||
| } | |||
| bool any_gradient_nonzero = false; | |||
| var zero_indices = new List<int>(); | |||
| for (int i = 0; i < trace.output_tensor_info.Length; ++i) | |||
| List<int> zero_indices = new(); | |||
| for(int i = 0, end = trace.output_tensor_info.Length; i < end; i++) | |||
| { | |||
| var id = trace.output_tensor_info[i].GetTensor(); | |||
| if (!gradients.find(id, out var grad_it)) | |||
| long id = trace.output_tensor_info[i].GetID(); | |||
| if(!gradients.TryGetValue(id, out var grad_it)) | |||
| { | |||
| if (functionsAcceptingNoneForIndicesMap.find(trace.op_type, out var func_name_it) && | |||
| func_name_it.find(i)) | |||
| out_gradients.Add(null); | |||
| if (build_default_zeros_grads) | |||
| { | |||
| out_gradients.Add(null); | |||
| } | |||
| else | |||
| { | |||
| out_gradients.Add(null); | |||
| zero_indices.Add(i); | |||
| if(!_functionsAcceptingNoneForIndicesMap.TryGetValue(trace.op_type, out var func_name_it) || | |||
| !func_name_it.find(i)) | |||
| { | |||
| zero_indices.Add(i); | |||
| } | |||
| } | |||
| } | |||
| else | |||
| { | |||
| any_gradient_nonzero = true; | |||
| var new_gradients = grad_it.Count == 1 ? | |||
| grad_it[0] : | |||
| gen_math_ops.add_n(grad_it.ToArray()); // vspace.AggregateGradients | |||
| Tensor new_gradients; | |||
| if (grad_it.Count == 1) | |||
| { | |||
| new_gradients = grad_it[0]; | |||
| } | |||
| else | |||
| { | |||
| new_gradients = AggregateGradients(grad_it); | |||
| } | |||
| if (!sources_set.find(id)) | |||
| { | |||
| gradients.Remove(id); | |||
| } | |||
| else | |||
| { | |||
| // grad_it.Clear(); | |||
| // grad_it.Add(new_gradients); | |||
| // vspace.MarkAsResult(new_gradients); | |||
| grad_it.Clear(); | |||
| grad_it.Add(new_gradients); | |||
| // MarkAsResult | |||
| } | |||
| out_gradients.Add(new_gradients); | |||
| } | |||
| } | |||
| Tensor[] in_gradients; | |||
| Tensor[] in_gradients = new Tensor[0]; | |||
| if (any_gradient_nonzero) | |||
| { | |||
| // foreach (var i in zero_indices) | |||
| // out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); | |||
| in_gradients = trace.backward_function(out_gradients.ToArray(), unneeded_gradients.ToArray()); | |||
| if (in_gradients.Length != trace.input_tensor_id.Length && in_gradients.Length + unneeded_gradients.Count != trace.input_tensor_id.Length) | |||
| throw new RuntimeError($"Recorded operation '{trace.op_type}' returned too few gradients. Expected {trace.input_tensor_id.Length} but received {in_gradients.Count()}"); | |||
| if (!_persistent) | |||
| foreach(var i in zero_indices) | |||
| { | |||
| // trace.backward_function_deleter(trace.backward_function); | |||
| trace.backward_function = null; | |||
| out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); | |||
| } | |||
| in_gradients = CallBackwardFunction(trace.backward_function, unneeded_gradients, out_gradients); | |||
| } | |||
| else | |||
| { | |||
| in_gradients = new Tensor[trace.input_tensor_id.Length]; | |||
| out_gradients.Clear(); | |||
| } | |||
| bool skip_unneeded_id = trace.input_tensor_id.Length > in_gradients.Length; | |||
| for (int i = 0, k = 0; i < in_gradients.Length && k < trace.input_tensor_id.Count(); ++i, ++k) | |||
| for(int i = 0, end = in_gradients.Length; i < end; i++) | |||
| { | |||
| if (skip_unneeded_id && unneeded_gradients.Contains(k)) ++k; | |||
| var id = trace.input_tensor_id[k]; | |||
| if (in_gradients[i] != null) | |||
| long id = trace.input_tensor_id[i]; | |||
| if (in_gradients[i] is not null) | |||
| { | |||
| var unaggregated_grads = gradients[id]; | |||
| var unaggregated_grads = gradients.SetDefault(id, new List<Tensor>()); | |||
| unaggregated_grads.Add(in_gradients[i]); | |||
| /*if (unaggregated_grads.Count > kMinAggregateCount) | |||
| if(unaggregated_grads.Count > kMinAggregateCount) | |||
| { | |||
| if (!gradients_size.find(id, out var size)) | |||
| if(!gradients_size.TryGetValue(id, out var size)) | |||
| { | |||
| size = (long)unaggregated_grads[0].size; | |||
| size = NumElements(unaggregated_grads[0]); | |||
| gradients_size.emplace(id, size); | |||
| } | |||
| if (unaggregated_grads.Count * size * 4 > kMinAggregateBytes) | |||
| if(unaggregated_grads.Count * size * 4 > kMinAggregateBytes) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| Tensor grad = AggregateGradients(unaggregated_grads); | |||
| unaggregated_grads.Clear(); | |||
| unaggregated_grads.Add(grad); | |||
| } | |||
| }*/ | |||
| } | |||
| } | |||
| if (!state.tensor_usage_counts.find(id)) | |||
| if(!state.tensor_usage_counts.find(id)) | |||
| { | |||
| continue; | |||
| } | |||
| state.tensor_usage_counts[id]--; | |||
| if (state.tensor_usage_counts[id] > 0) | |||
| if(state.tensor_usage_counts[id] > 0) | |||
| { | |||
| continue; | |||
| if (!tensor_tape_.find(id, out var tape_it)) | |||
| } | |||
| if (!tensor_tape_.TryGetValue(id, out var tape_it)) | |||
| { | |||
| if (gradients.find(id, out var grad_it)) | |||
| if (gradients.find(id)) | |||
| { | |||
| // foreach (var g in grad_it) | |||
| // DeleteGradient(g); | |||
| gradients.erase(id); | |||
| } | |||
| continue; | |||
| } | |||
| var op_id = tape_it; | |||
| if (op_id == -1) | |||
| long op_id = tape_it; | |||
| if(op_id == -1) | |||
| { | |||
| continue; | |||
| if (state.op_missing_tensor.find(op_id, out var missing_it)) | |||
| } | |||
| if(state.op_missing_tensor.find(op_id)) | |||
| { | |||
| state.op_missing_tensor[op_id]--; | |||
| if (state.op_missing_tensor[op_id] == 0) | |||
| if(state.op_missing_tensor[op_id] == 0) | |||
| { | |||
| op_stack.Enqueue(op_id); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (state.op_tape.Count > 0) | |||
| if(state.op_tape.Count > 0) | |||
| { | |||
| throw new RuntimeError("Invalid tape state."); | |||
| var result = new Tensor[source_tensor_ids.Length]; | |||
| var j = 0; | |||
| foreach (var id in source_tensor_ids) | |||
| } | |||
| Tensor[] result = new Tensor[source_tensor_ids.Length]; | |||
| for(int i = 0; i < source_tensor_ids.Length; i++) | |||
| { | |||
| if (gradients.find(id, out var grad_it)) | |||
| long tensor_id = source_tensor_ids[i]; | |||
| if(!gradients.TryGetValue(tensor_id, out var grad_it)) | |||
| { | |||
| if (grad_it.Count > 1) | |||
| result[j] = gen_math_ops.add_n(grad_it.ToArray()); | |||
| else | |||
| result[j] = grad_it[0]; | |||
| result[i] = null; | |||
| } | |||
| else | |||
| { | |||
| if(grad_it.Count > 1) | |||
| { | |||
| Tensor grad = AggregateGradients(grad_it); | |||
| grad_it.Clear(); | |||
| grad_it.Add(grad); | |||
| } | |||
| result[i] = grad_it[0]; | |||
| } | |||
| j++; | |||
| } | |||
| return result; | |||
| } | |||
| UnorderedMap<string, UnorderedSet<int>> FunctionsAcceptingNoneForIndicesMap() | |||
| { | |||
| var m = new UnorderedMap<string, UnorderedSet<int>>(); | |||
| m.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||
| m.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||
| m.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 })); | |||
| return m; | |||
| return _functionsAcceptingNoneForIndicesMap; | |||
| } | |||
| UnorderedMapEnumerable<Tensor, List<Tensor>> InitialGradients(Tensor[] target_tensor_ids, | |||
| UnorderedMap<Tensor, TapeTensor> sources_that_are_targets, | |||
| Tensor[] output_gradients, | |||
| UnorderedMap<long, List<Tensor>> InitialGradients(long[] target_tensor_ids, | |||
| UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||
| List<Tensor> output_gradients, | |||
| TensorTape tensor_tape, | |||
| OpTape op_tape) | |||
| { | |||
| var result = new UnorderedMapEnumerable<Tensor, List<Tensor>>(); | |||
| for (int i = 0; i < target_tensor_ids.Length; ++i) | |||
| var result = new UnorderedMap<long, List<Tensor>>(); | |||
| for(int i = 0, end = target_tensor_ids.Length; i < end; i++) | |||
| { | |||
| var id = target_tensor_ids[i]; | |||
| if (output_gradients.Length == 0 || output_gradients[i] == null) | |||
| long id = target_tensor_ids[i]; | |||
| if( output_gradients is null ||output_gradients.Count == 0 || output_gradients[i] is null) | |||
| { | |||
| if (tensor_tape.find(id, out var tensor_id) && tensor_id != null) | |||
| if(tensor_tape.TryGetValue(id, out var tensor_it) && tensor_it != -1) | |||
| { | |||
| if (!op_tape.find(tensor_tape[id], out var op_it)) | |||
| if(!op_tape.TryGetValue(tensor_it, out var op_it)) | |||
| { | |||
| throw new RuntimeError("Internal state of the gradient tape is invalid: " + | |||
| "failed to find operation producing a tensor"); | |||
| "failed to find operation producing a tensor."); | |||
| } | |||
| bool found = false; | |||
| for (int j = 0; j < op_it.output_tensor_info.Length; ++j) | |||
| for(int j = 0; j < op_it.output_tensor_info.Length; j++) | |||
| { | |||
| if (op_it.output_tensor_info[j].GetTensor() == id) | |||
| if (op_it.output_tensor_info[j].GetID() == id) | |||
| { | |||
| found = true; | |||
| var ones = op_it.output_tensor_info[j].OnesLike(); | |||
| result[id].Add(ones); | |||
| Tensor ones_like = BuildOnesLike(op_it.output_tensor_info[j]); | |||
| result.SetDefault(id, new List<Tensor>()).Add(ones_like); | |||
| break; | |||
| } | |||
| } | |||
| if (!found) | |||
| { | |||
| throw new ValueError("Internal state of the gradient tape is invalid: " + | |||
| "none of operations outputs match expected tensor"); | |||
| throw new RuntimeError("Internal state of the gradient tape is invalid: " + | |||
| "none of operations outputs match expected tensor."); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| if (sources_that_are_targets.find(id, out var source_tensor)) | |||
| result[id].Add(source_tensor.OnesLike()); | |||
| if(sources_that_are_targets.TryGetValue(id, out var source_tensor)) | |||
| { | |||
| Tensor ones_like = BuildOnesLike(source_tensor); | |||
| result.SetDefault(id, new List<Tensor>()).Add(ones_like); | |||
| } | |||
| } | |||
| } | |||
| else | |||
| { | |||
| result[id].Add(output_gradients[i]); | |||
| result.SetDefault(id, new List<Tensor>()).Add(output_gradients[i]); | |||
| } | |||
| } | |||
| @@ -248,5 +259,26 @@ namespace Tensorflow.Gradients | |||
| } | |||
| return result; | |||
| } | |||
| Tensor BuildOnesLike(TapeTensor t) | |||
| { | |||
| return t.OnesLike(); | |||
| } | |||
| Tensor AggregateGradients(List<Tensor> gradient_tensors) | |||
| { | |||
| if(gradient_tensors.Count == 0) | |||
| { | |||
| return gradient_tensors[0]; | |||
| } | |||
| return tf.add_n(gradient_tensors.ToArray()); | |||
| } | |||
| void DeleteGradient(Tensor gradient) | |||
| { | |||
| // Do not do anything here. Because GC will collect it when it has no reference. | |||
| } | |||
| long NumElements(Tensor tensor) => 1; | |||
| } | |||
| } | |||
| @@ -5,63 +5,62 @@ namespace Tensorflow.Gradients | |||
| { | |||
| public partial class Tape | |||
| { | |||
| public BackpropInitialState PrepareBackprop(Tensor[] target, | |||
| public BackpropInitialState PrepareBackprop(long[] target, | |||
| TensorTape tensor_tape, | |||
| OpTape op_tape, | |||
| UnorderedSet<Tensor> sources_set, | |||
| UnorderedSet<long> sources_set, | |||
| bool persistent_tape) | |||
| { | |||
| Stack<long> tensor_stack = new Stack<long>(); | |||
| foreach(var t in target) | |||
| { | |||
| tensor_stack.Push(t); | |||
| } | |||
| BackpropInitialState result = new BackpropInitialState(); | |||
| var tensor_stack = new Queue<Tensor>(target); | |||
| while (tensor_stack.Count > 0) | |||
| while(tensor_stack.Count > 0) | |||
| { | |||
| var tensor_id = tensor_stack.Dequeue(); | |||
| if (!tensor_tape.find(tensor_id, out var op_id)) | |||
| long tensor_id = tensor_stack.Pop(); | |||
| if(!tensor_tape.TryGetValue(tensor_id, out var op_id)) | |||
| { | |||
| continue; | |||
| if (op_id == -1 || | |||
| !op_tape.find(op_id, out var op_it) || | |||
| result.op_tape.find(op_id, out var result_op_it)) | |||
| } | |||
| if(op_id == -1 || !op_tape.TryGetValue(op_id, out var op_it) | |||
| || result.op_tape.find(op_id)) | |||
| { | |||
| continue; | |||
| } | |||
| result.op_tape.emplace(op_id, op_it); | |||
| foreach (var it in op_it.input_tensor_id) | |||
| foreach(var it in op_it.input_tensor_id) | |||
| { | |||
| if (result.tensor_usage_counts.find(it)) | |||
| if(result.tensor_usage_counts.find(it)) | |||
| { | |||
| result.tensor_usage_counts[it]++; | |||
| } | |||
| else | |||
| { | |||
| result.tensor_usage_counts[it] = 1; | |||
| if (tensor_tape.find(it)) | |||
| tensor_stack.Enqueue(it); | |||
| { | |||
| tensor_stack.Push(it); | |||
| } | |||
| } | |||
| } | |||
| if (!persistent_tape) | |||
| op_tape.Remove(op_id); | |||
| { | |||
| op_tape.erase(op_id); | |||
| } | |||
| } | |||
| foreach (var pair in result.tensor_usage_counts) | |||
| foreach(var pair in result.tensor_usage_counts) | |||
| { | |||
| if (tensor_tape.find(pair.Key, out var it) && it != -1) | |||
| result.op_missing_tensor[it] += 1; | |||
| if(tensor_tape.TryGetValue(pair.Key, out var it) && it != -1) | |||
| { | |||
| result.op_missing_tensor[it]++; | |||
| } | |||
| } | |||
| if (!persistent_tape) | |||
| { | |||
| // Call destructors for all unneeded gradient functions and | |||
| // clear the op_tape. We can clear the tape because ownership of | |||
| // backward functions that will be used for gradient computation | |||
| // has been transferred to `result`. | |||
| /*for (const auto&op_pair : *op_tape) { | |||
| op_pair.second.backward_function_deleter( | |||
| op_pair.second.backward_function); | |||
| }*/ | |||
| op_tape.Clear(); | |||
| } | |||
| return result; | |||
| } | |||
| } | |||
| @@ -8,34 +8,45 @@ namespace Tensorflow.Gradients | |||
| public partial class Tape | |||
| { | |||
| long next_op_id_ = 0; | |||
| UnorderedMap<Tensor, long> tensor_usage_; | |||
| UnorderedMap<long, long> tensor_usage_; | |||
| public void RecordOperation(string op_type, | |||
| Tensor[] input_tensors, | |||
| TapeTensor[] output_tensors, | |||
| long[] input_tensor_id, | |||
| TF_DataType[] input_dtypes, | |||
| BackwardFunction backward_function) | |||
| { | |||
| if (!ShouldRecord(input_tensors)) | |||
| if (!ShouldRecord(input_tensor_id, input_dtypes)) | |||
| return; | |||
| var op_id = next_op_id_++; | |||
| foreach (var i in input_tensors) | |||
| foreach (var i in input_tensor_id) | |||
| { | |||
| tensor_usage_[i]++; | |||
| } | |||
| long op_id = next_op_id_++; | |||
| foreach (var o in output_tensors) | |||
| { | |||
| tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}"); | |||
| tensor_tape_[o.GetTensor()] = op_id; | |||
| tensor_usage_[o.GetTensor()] = 1; | |||
| tensor_tape_[o.GetID()] = op_id; | |||
| tensor_usage_[o.GetID()] = 1; | |||
| } | |||
| op_tape_[op_id] = new OpTapeEntry | |||
| { | |||
| op_type = op_type, | |||
| output_tensor_info = output_tensors, | |||
| input_tensor_id = input_tensors, | |||
| output_tensor_info = output_tensors.ToArray(), | |||
| input_tensor_id = input_tensor_id.ToArray(), | |||
| backward_function = backward_function | |||
| }; | |||
| } | |||
| public void RecordOperation(string op_type, | |||
| Tensor[] outputs, | |||
| Tensor[] inputs, | |||
| BackwardFunction backward_function) | |||
| { | |||
| tf.Runner.TFE_TapeSetRecordOperation(op_type, outputs, inputs, backward_function); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,5 +1,6 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| @@ -29,7 +30,7 @@ namespace Tensorflow.Gradients | |||
| _created_eagerly = tf.Context.executing_eagerly(); | |||
| tensor_tape_ = new TensorTape(); | |||
| op_tape_ = new OpTape(); | |||
| tensor_usage_ = new UnorderedMap<Tensor, long>(); | |||
| tensor_usage_ = new UnorderedMap<long, long>(); | |||
| if(_created_eagerly) | |||
| tf.Context.start_step(); | |||
| // nesting_id = ++tape_nesting_id_counter; | |||
| @@ -42,29 +43,28 @@ namespace Tensorflow.Gradients | |||
| public void Watch(Tensor x) | |||
| { | |||
| tf.Logger.Debug($"Watch tensor id={x.Id}, name={x.name}"); | |||
| tensor_tape_.emplace(x, -1); | |||
| tensor_tape_.emplace(x.Id, -1); | |||
| } | |||
| public bool ShouldRecord(Tensor[] tensors) | |||
| public bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes) | |||
| { | |||
| var dtypes = tensors.Select(x => x.dtype).ToArray(); | |||
| for (int i = 0; i < tensors.Length; ++i) | |||
| Debug.Assert(tensor_ids.Length == tensor_dtypes.Length); | |||
| for (int i = 0; i < tensor_ids.Length; ++i) | |||
| { | |||
| if (tensor_tape_.find(tensors[i])) | |||
| if (tensor_tape_.find(tensor_ids[i]) && IsDtypeTrainable(tensor_dtypes[i])) | |||
| { | |||
| if (IsDtypeTrainable(dtypes[i])) | |||
| return true; | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| public void VariableAccessed(ResourceVariable variable) | |||
| public void VariableAccessed(IVariableV1 variable) | |||
| { | |||
| Watch(variable.Handle); | |||
| } | |||
| public ResourceVariable[] WatchedVariables() | |||
| public IVariableV1[] WatchedVariables() | |||
| { | |||
| return null; | |||
| } | |||
| @@ -1,27 +1,63 @@ | |||
| using static Tensorflow.Binding; | |||
| using OneOf; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Gradients | |||
| { | |||
| public class TapeTensor | |||
| { | |||
| Tensor tensor; | |||
| long id => tensor.Id; | |||
| TF_DataType dtype => tensor.dtype; | |||
| Shape shape => tensor.shape; | |||
| internal Tensor tensor; | |||
| internal long id; | |||
| internal TF_DataType dtype; | |||
| internal OneOf<Shape, Tensor> shape; | |||
| public TapeTensor(long id, TF_DataType dtype, Shape shape) | |||
| { | |||
| this.id = id; | |||
| this.dtype = dtype; | |||
| this.shape = shape; | |||
| } | |||
| public TapeTensor(long id, TF_DataType dtype, Tensor shape) | |||
| { | |||
| this.id = id; | |||
| this.dtype = dtype; | |||
| this.shape = shape; | |||
| } | |||
| public TapeTensor(Tensor tensor) | |||
| { | |||
| this.id = tensor.Id; | |||
| this.dtype = tensor.dtype; | |||
| this.shape = tensor.shape; | |||
| this.tensor = tensor; | |||
| } | |||
| public long GetID() => tensor.Id; | |||
| public Tensor GetTensor() => tensor; | |||
| public long GetID() => id; | |||
| public Tensor ZerosLike() | |||
| => tf.zeros(shape: shape, dtype: dtype); | |||
| { | |||
| if(dtype == dtypes.resource) | |||
| { | |||
| return null; | |||
| } | |||
| if(shape.Index == 1) | |||
| { | |||
| return tf.zeros_like(shape.AsT1); | |||
| } | |||
| return tf.zeros(shape.AsT0, dtype); | |||
| } | |||
| public Tensor OnesLike() | |||
| => tf.ones(shape: shape, dtype: dtype); | |||
| { | |||
| if (shape.Index == 1) | |||
| { | |||
| return tf.ones_like(shape.AsT1); | |||
| } | |||
| return tf.ones(shape.AsT0, dtype); | |||
| } | |||
| //public Tensor OnesLike() | |||
| // => tf.ones(shape: shape, dtype: dtype); | |||
| public override string ToString() | |||
| => $"{id}, {shape}, {dtype.as_numpy_name()}"; | |||
| @@ -7,7 +7,7 @@ namespace Tensorflow.Gradients | |||
| /// produced this tensor. A value of -1 means that the tensor was directly | |||
| /// watched and not the result of any operation in the tape. | |||
| /// </summary> | |||
| public class TensorTape : UnorderedMap<Tensor, long> | |||
| public class TensorTape : UnorderedMap<long, long> | |||
| { | |||
| } | |||
| @@ -0,0 +1,14 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Gradients | |||
| { | |||
| public class custom_gradient | |||
| { | |||
| public static string generate_name() | |||
| { | |||
| return $"CustomGradient-{ops.uid()}"; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,52 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Gradients | |||
| { | |||
| internal static class default_gradient | |||
| { | |||
| public static (Shape, TF_DataType) shape_and_dtype(Tensor t) | |||
| { | |||
| if(t.dtype == dtypes.resource) | |||
| { | |||
| var handle_data = resource_variable_ops.get_eager_safe_handle_data(t); | |||
| if(handle_data is null || !handle_data.IsSet || handle_data.ShapeAndType.Count != 1) | |||
| { | |||
| throw new ValueError($"Internal error: Tried to take gradients (or similar) " + | |||
| $"of a variable without handle data:\n{t}"); | |||
| } | |||
| return (new Shape(handle_data.ShapeAndType[0].Shape), handle_data.ShapeAndType[0].Dtype.as_tf_dtype()); | |||
| } | |||
| return (t.shape, t.dtype); | |||
| } | |||
| public static Tensor zeros_like(Tensor t) | |||
| { | |||
| if(t.dtype == dtypes.resource) | |||
| { | |||
| var (shape, dtype) = shape_and_dtype(t); | |||
| return array_ops.zeros(shape, dtype); | |||
| } | |||
| else | |||
| { | |||
| return array_ops.zeros_like(t); | |||
| } | |||
| } | |||
| public static TF_DataType get_zeros_dtype(Tensor t) | |||
| { | |||
| if(t.dtype == dtypes.resource) | |||
| { | |||
| var handle_data = resource_variable_ops.get_eager_safe_handle_data(t); | |||
| if(handle_data is null || !handle_data.IsSet || handle_data.ShapeAndType.Count != 1) | |||
| { | |||
| throw new ValueError($"Internal error: Tried to take gradients (or similar) " + | |||
| $"of a variable without handle data:\n{t}"); | |||
| } | |||
| return handle_data.ShapeAndType[0].Dtype.as_tf_dtype(); | |||
| } | |||
| return t.dtype; | |||
| } | |||
| } | |||
| } | |||
| @@ -14,10 +14,15 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Gradients; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Operations.ControlFlows; | |||
| using static Tensorflow.Binding; | |||
| @@ -25,6 +30,11 @@ namespace Tensorflow | |||
| { | |||
| public class gradients_util | |||
| { | |||
| // Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are | |||
| // unfortunately too slow to use here. | |||
| public static int POSSIBLE_GRADIENT_TYPES_NONE = 0; | |||
| public static int POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1; | |||
| public static int POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2; | |||
| public static Tensor[] _GradientsHelper(Tensor[] ys, | |||
| Tensor[] xs, | |||
| Tensor[] grad_ys = null, | |||
| @@ -143,7 +153,7 @@ namespace Tensorflow | |||
| Tensor[] in_grads = null; | |||
| Func<Operation, Tensor[], Tensor[]> grad_fn = null; | |||
| var is_partitioned_call = _IsPartitionedCall(op); | |||
| var is_func_call = false; | |||
| var is_func_call = src_graph.IsFunction(op.type) || is_partitioned_call; | |||
| var has_out_grads = out_grads.Exists(x => x != null); | |||
| if (has_out_grads && !stop_ops.Contains(op)) | |||
| { | |||
| @@ -157,14 +167,41 @@ namespace Tensorflow | |||
| { | |||
| if (is_func_call) | |||
| { | |||
| EagerDefinedFunction func_call = null; | |||
| if (is_partitioned_call) | |||
| { | |||
| var func_attr = op.get_attr("f"); | |||
| Debug.Assert(func_attr is NameAttrList); | |||
| var func_name = ((NameAttrList)func_attr).Name; | |||
| func_call = src_graph._get_function(func_name); | |||
| if(func_call is null && src_graph.OuterGraph is not null) | |||
| { | |||
| var graph = src_graph.OuterGraph; | |||
| while(graph is not null) | |||
| { | |||
| func_call = graph._get_function(func_name); | |||
| if(func_call is not null) | |||
| { | |||
| break; | |||
| } | |||
| if(graph.OuterGraph is not null) | |||
| { | |||
| graph = graph.OuterGraph; | |||
| } | |||
| else | |||
| { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| else | |||
| { | |||
| func_call = src_graph._get_function(op.type); | |||
| } | |||
| // skip the following codes: | |||
| // `func_call = getattr(op, "__defun", func_call)` | |||
| grad_fn = func_call.csharp_grad_func; | |||
| } | |||
| else | |||
| { | |||
| @@ -208,6 +245,8 @@ namespace Tensorflow | |||
| } | |||
| else | |||
| { | |||
| in_grads = _MaybeCompile(grad_scope, op, out_grads.Where(x => x != null).Select(x => x[0]).ToArray(), | |||
| null, (x, y) => _SymGrad(x, y)); | |||
| throw new NotImplementedException("lambda: _SymGrad(op, out_grads)"); | |||
| } | |||
| _VerifyGeneratedGradients(in_grads, op); | |||
| @@ -663,6 +702,11 @@ namespace Tensorflow | |||
| dtypes.resource, dtypes.variant}.Contains(dtype); | |||
| } | |||
| public static int PossibleTapeGradientTypes(Tensor[] tensors) | |||
| { | |||
| return tf.Runner.TFE_TapeSetPossibleGradientTypes(tensors); | |||
| } | |||
| /// <summary> | |||
| /// Return true if op has real gradient. | |||
| /// </summary> | |||
| @@ -683,7 +727,7 @@ namespace Tensorflow | |||
| private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn) | |||
| { | |||
| scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope; | |||
| // scope = scope.TrimEnd('/').Replace('/', '_'); | |||
| return grad_fn(op, out_grads); | |||
| } | |||
| @@ -696,5 +740,28 @@ namespace Tensorflow | |||
| throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " + | |||
| $"inputs {op.inputs._inputs.Count()}"); | |||
| } | |||
| private static Tensor[] _SymGrad(Operation op, Tensor[] out_grads) | |||
| { | |||
| var f_in = ((Tensor[])op.inputs).Concat(out_grads).ToArray(); | |||
| var f_types = ((Tensor[])op.inputs).Select(x => default_gradient.get_zeros_dtype(x)).ToArray(); | |||
| NameAttrList f = new(); | |||
| if (_IsPartitionedCall(op)) | |||
| { | |||
| var func_attr = op.get_attr("f"); | |||
| Debug.Assert(func_attr is NameAttrList); | |||
| f.Name = ((NameAttrList)func_attr).Name; | |||
| } | |||
| else | |||
| { | |||
| f.Name = op.type; | |||
| } | |||
| foreach(var k in op.node_def.Attr.Keys) | |||
| { | |||
| f.Attr[k] = AttrValue.Parser.ParseFrom(op.node_def.Attr[k].ToByteArray()); | |||
| } | |||
| var in_grads = gen_functional_ops.symbolic_gradient(f_in, f_types, f); | |||
| return in_grads; | |||
| } | |||
| } | |||
| } | |||
| @@ -98,12 +98,23 @@ namespace Tensorflow | |||
| { | |||
| if (op.inputs == null) return null; | |||
| RegisterFromAssembly(); | |||
| var gradient_function = op._gradient_function; | |||
| if(gradient_function is null) | |||
| { | |||
| RegisterFromAssembly(); | |||
| if (!gradientFunctions.ContainsKey(op.type)) | |||
| throw new LookupError($"can't get graident function through get_gradient_function {op.type}"); | |||
| if (!gradientFunctions.ContainsKey(op.type)) | |||
| throw new LookupError($"can't get graident function through get_gradient_function {op.type}"); | |||
| return gradientFunctions[op.type]; | |||
| } | |||
| return gradientFunctions[op.type]; | |||
| Tensor[] wrapped_gradient_function(Operation operation, Tensor[] args) | |||
| { | |||
| return gradient_function(operation, args); | |||
| } | |||
| // TODO(Rinne): check if this needs to be registered. | |||
| return wrapped_gradient_function; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,6 +1,7 @@ | |||
| using MethodBoundaryAspect.Fody.Attributes; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Functions; | |||
| @@ -22,7 +23,7 @@ namespace Tensorflow.Graphs | |||
| public override void OnEntry(MethodExecutionArgs args) | |||
| { | |||
| // TODO: func_name can be cache in FullName + Args | |||
| func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{ops.uid_function()}"; | |||
| func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}"; | |||
| if (functions.ContainsKey(func_name)) | |||
| { | |||
| @@ -91,6 +92,7 @@ namespace Tensorflow.Graphs | |||
| // cache function. | |||
| function.ReturnType = args.ReturnValue.GetType(); | |||
| function._set_infer_function(); | |||
| functions[func_name] = function; | |||
| // run function | |||
| @@ -1,6 +1,15 @@ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Buffers; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Exceptions; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Graphs; | |||
| @@ -10,12 +19,66 @@ namespace Tensorflow.Graphs; | |||
| /// </summary> | |||
| public class FuncGraph : Graph, IDisposable | |||
| { | |||
| SafeFuncGraphHandle _func_graph_handle; | |||
| internal SafeFuncGraphHandle _func_graph_handle; | |||
| internal HashSet<Tensor> _resource_tensor_inputs; | |||
| internal HashSet<WeakReference<IVariableV1>> _watched_variables; | |||
| internal IEnumerable<WeakReference<IVariableV1>> _weak_variables; | |||
| internal object[] _structured_outputs; | |||
| internal Dictionary<long, string> _output_names; | |||
| public string FuncName => _graph_key; | |||
| public Tensors Inputs { get; set; } = new Tensors(); | |||
| public Tensors Outputs { get; set; } = new Tensors(); | |||
| public Dictionary<string, string> Attrs { get; set; } | |||
| public Tensors FlatStructuredOutputs | |||
| { | |||
| get | |||
| { | |||
| List<Tensor> res = new(); | |||
| foreach(var obj in _structured_outputs) | |||
| { | |||
| if(obj is Tensor tensor) | |||
| { | |||
| res.Add(tensor); | |||
| } | |||
| else if(obj is IEnumerable<Tensor> tensors) | |||
| { | |||
| res.AddRange(tensors); | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError("The structured outputs member should be tensor or tensors."); | |||
| } | |||
| } | |||
| return res; | |||
| } | |||
| } | |||
| public string Name { get; set; } | |||
| public IEnumerable<IVariableV1> Variables | |||
| { | |||
| get | |||
| { | |||
| return _weak_variables.Select(v => | |||
| { | |||
| if (v.TryGetTarget(out var target)) | |||
| { | |||
| return target; | |||
| } | |||
| else | |||
| { | |||
| throw new AssertionError("Called a function referencing variables which have been deleted. " + | |||
| "This likely means that function-local variables were created and " + | |||
| "not referenced elsewhere in the program. This is generally a " + | |||
| "mistake; consider storing variables in an object attribute on first call."); | |||
| } | |||
| }); | |||
| } | |||
| internal set | |||
| { | |||
| _weak_variables = value.Select(x => new WeakReference<IVariableV1>(x)); | |||
| } | |||
| } | |||
| public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable); | |||
| public Dictionary<string, AttrValue> Attrs { get; set; } | |||
| Dictionary<long, (Tensor, Tensor)> _captures | |||
| = new Dictionary<long, (Tensor, Tensor)>(); | |||
| @@ -39,31 +102,42 @@ public class FuncGraph : Graph, IDisposable | |||
| outer_graph = ops.get_default_graph(); | |||
| while (outer_graph.building_function) | |||
| outer_graph = outer_graph.OuterGraph; | |||
| _graph_key = name; | |||
| _graph_key = Name = name; | |||
| building_function = true; | |||
| _weak_variables = new List<WeakReference<IVariableV1>>(); | |||
| _resource_tensor_inputs = new HashSet<Tensor>(); | |||
| _watched_variables = new HashSet<WeakReference<IVariableV1>>(); | |||
| } | |||
| public FuncGraph(SafeGraphHandle handle, string name, Dictionary<string, string> attrs) : base() | |||
| public FuncGraph(SafeGraphHandle handle, string name, Dictionary<string, AttrValue> attrs) : base() | |||
| { | |||
| outer_graph = ops.get_default_graph(); | |||
| while (outer_graph.building_function) | |||
| outer_graph = outer_graph.OuterGraph; | |||
| _graph_key = name; | |||
| _graph_key = Name = name; | |||
| building_function = true; | |||
| Attrs = attrs; | |||
| // Will to test if FuncGraph has memory leak | |||
| // c_api.TF_DeleteGraph(_handle); | |||
| _handle = handle; | |||
| _weak_variables = new List<WeakReference<IVariableV1>>(); | |||
| _resource_tensor_inputs = new HashSet<Tensor>(); | |||
| _watched_variables = new HashSet<WeakReference<IVariableV1>>(); | |||
| } | |||
| public void ToGraph(Operation[] opers, | |||
| public void replace_capture(Tensor tensor, Tensor placeholder) | |||
| { | |||
| _captures[tensor.Id] = (tensor, placeholder); | |||
| } | |||
| public unsafe void ToGraph(Operation[] opers, | |||
| Tensor[] inputs, Tensor[] outputs, | |||
| string[] output_names) | |||
| { | |||
| var status = new Status(); | |||
| if (output_names != null && output_names.Length == 0) | |||
| if (output_names is null) | |||
| { | |||
| output_names = null; | |||
| output_names = new string[0]; | |||
| }; | |||
| _func_graph_handle = c_api.TF_GraphToFunction(_handle, | |||
| @@ -75,7 +149,7 @@ public class FuncGraph : Graph, IDisposable | |||
| inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
| outputs.Length, | |||
| outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
| output_names, | |||
| output_names.Length != outputs.Length ? null : output_names, | |||
| IntPtr.Zero, | |||
| null, | |||
| status); | |||
| @@ -141,6 +215,16 @@ public class FuncGraph : Graph, IDisposable | |||
| return tensor; | |||
| } | |||
| public void watch_variable(IVariableV1 v) | |||
| { | |||
| if (_resource_tensor_inputs.Contains(v.Handle)) | |||
| { | |||
| return; | |||
| } | |||
| _watched_variables.Add(new WeakReference<IVariableV1>(v)); | |||
| //this = this.outer_graph; | |||
| } | |||
| Tensor capture_eager_tensor(Tensor tensor, string name) | |||
| { | |||
| Tensor graph_const = null; | |||
| @@ -205,6 +289,19 @@ public class FuncGraph : Graph, IDisposable | |||
| Inputs.Add(placeholder); | |||
| } | |||
| Tensor pop_capture(Tensor tensor) | |||
| { | |||
| if(_captures.TryGetValue(tensor.Id, out var capture)) | |||
| { | |||
| _captures.Remove(tensor.Id); | |||
| return capture.Item2; | |||
| } | |||
| else | |||
| { | |||
| return null; | |||
| } | |||
| } | |||
| Tensor _create_substitute_placeholder(Tensor value, | |||
| string name = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| @@ -228,10 +325,7 @@ public class FuncGraph : Graph, IDisposable | |||
| foreach (var (_name, attr_value) in enumerate(Attrs)) | |||
| { | |||
| var serialized = new AttrValue | |||
| { | |||
| S = ByteString.CopyFromUtf8(attr_value) | |||
| }.ToByteArray(); | |||
| var serialized = attr_value.ToByteArray(); | |||
| c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status); | |||
| tf.Status.Check(true); | |||
| } | |||
| @@ -254,4 +348,261 @@ public class FuncGraph : Graph, IDisposable | |||
| { | |||
| c_api.TFE_ContextRemoveFunction(tf.Context, _graph_key, tf.Status); | |||
| } | |||
| public static FuncGraph func_graph_from_func(string name, Func<object[], object[]> func, | |||
| object[] args, Dictionary<string, object> kwargs, TensorSpec[] signature = null, | |||
| FuncGraph func_graph = null, bool autograph = false, object autograph_options = null, | |||
| bool add_control_dependencies = true, string[] arg_names = null, | |||
| Tensor op_return_value = null, bool capture_by_value = false, | |||
| bool acd_record_initial_resource_uses = false) | |||
| { | |||
| if(func_graph is null) | |||
| { | |||
| func_graph = new FuncGraph(name); | |||
| } | |||
| // TODO(Rinne): deal with control dependencies. | |||
| func_graph.as_default(); | |||
| var current_scope = variable_scope.get_variable_scope(); | |||
| var default_use_resource = current_scope.use_resource; | |||
| current_scope.use_resource = true; | |||
| if(signature is not null) | |||
| { | |||
| args = signature; | |||
| kwargs = new Dictionary<string, object>(); | |||
| } | |||
| var func_args = _get_defun_inputs_from_args(args, arg_names); | |||
| var func_kwargs = _get_defun_inputs_from_kwargs(kwargs); | |||
| if(func_kwargs is not null && func_kwargs.Count > 0) | |||
| { | |||
| throw new NotImplementedException("The keyword args has not been supported in `func_graph_from_func`."); | |||
| } | |||
| foreach(var arg in nest.flatten<object>(new object[] { func_args, func_kwargs })) | |||
| { | |||
| if(arg is Tensor tensor && tensor.dtype == dtypes.resource) | |||
| { | |||
| func_graph._resource_tensor_inputs.Add(tensor); | |||
| } | |||
| else if (arg is ResourceVariable variable) | |||
| { | |||
| func_graph._resource_tensor_inputs.Add(variable.Handle); | |||
| } | |||
| } | |||
| // skip the assignment of `func_graph.structured_input_signature`. | |||
| var flat_func_args = nest.flatten(func_args as object); | |||
| var flat_func_kwargs = nest.flatten(func_kwargs as object); | |||
| func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs) | |||
| .Where(x => x is Tensor).Select(x => (Tensor)x)); | |||
| //var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true); | |||
| //var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true); | |||
| Tensor convert(object x) | |||
| { | |||
| if (x is null) return null; | |||
| Tensor res = null; | |||
| if(op_return_value is not null && x is Operation) | |||
| { | |||
| tf_with(ops.control_dependencies(new object[] { x }), _ => | |||
| { | |||
| res = array_ops.identity(op_return_value); | |||
| }); | |||
| } | |||
| else if(x is not TensorArray) | |||
| { | |||
| Debug.Assert(x is Tensor); | |||
| res = ops.convert_to_tensor_or_composite(x as Tensor); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException($"The `TensorArray` is not supported here currently."); | |||
| } | |||
| if (add_control_dependencies) | |||
| { | |||
| // TODO(Rinne): `x = deps_ctx.mark_as_return(x)`. | |||
| } | |||
| return res; | |||
| } | |||
| if (autograph) | |||
| { | |||
| throw new NotImplementedException("The autograph of `func_graph_from_func` has not been supported."); | |||
| } | |||
| var func_outputs = func(func_args); | |||
| func_outputs = variable_utils.convert_variables_to_tensors(func_outputs); | |||
| func_outputs = func_outputs.Select(x => convert(x)).ToArray(); | |||
| // TODO(Rinne): `check_func_mutation`. | |||
| current_scope.use_resource = default_use_resource; | |||
| var graph_variables = func_graph._watched_variables.ToList(); | |||
| HashSet<IVariableV1> arg_variables = new HashSet<IVariableV1>(); | |||
| List<Tensor> inputs = new(); | |||
| foreach(var arg in composite_tensor_utils.flatten_with_variables(func_args)) | |||
| { | |||
| if(arg is BaseResourceVariable variable) | |||
| { | |||
| var resource_placeholder = func_graph.pop_capture(variable.Handle); | |||
| if(resource_placeholder is null) | |||
| { | |||
| continue; | |||
| } | |||
| Debug.Assert(variable is IVariableV1); | |||
| arg_variables.Add(variable as IVariableV1); | |||
| inputs.Add(resource_placeholder); | |||
| } | |||
| else if(arg is Tensor tensor) | |||
| { | |||
| inputs.Add(tensor); | |||
| } | |||
| } | |||
| var variables = graph_variables.Select(v => | |||
| { | |||
| if (v.TryGetTarget(out var target)) | |||
| { | |||
| return target; | |||
| } | |||
| else | |||
| { | |||
| return null; | |||
| } | |||
| }).Where(v => v is not null && !arg_variables.Contains(v)); | |||
| func_graph.Inputs = inputs.Concat(func_graph.internal_captures).ToArray(); | |||
| func_graph._structured_outputs = func_outputs; | |||
| func_graph.Outputs.AddRange(func_graph.FlatStructuredOutputs.Where(x => x is not null) | |||
| .Select(x => func_graph.capture(x))); | |||
| func_graph.Variables = variables; | |||
| func_graph.Exit(); | |||
| if (add_control_dependencies) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| } | |||
| return func_graph; | |||
| } | |||
| private static object[] _get_defun_inputs_from_args(object[] args, string[] names) | |||
| { | |||
| return _get_defun_inputs(args, names, args) as object[]; | |||
| } | |||
| private static Dictionary<string, object> _get_defun_inputs_from_kwargs(Dictionary<string, object> kwargs) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| Debug.Assert(kwargs is null || kwargs.Count == 0); | |||
| return kwargs; | |||
| //string[] names; | |||
| //object[] args; | |||
| //if(kwargs is not null && kwargs.Count > 0) | |||
| //{ | |||
| // var sorted_kwargs = kwargs.OrderBy(x => x.Key); | |||
| // names = sorted_kwargs.Select(x => x.Key).ToArray(); | |||
| // args = sorted_kwargs.Select(x => x.Value).ToArray(); | |||
| //} | |||
| //else | |||
| //{ | |||
| // names = new string[0]; | |||
| // args = new object[0]; | |||
| //} | |||
| //return _get_defun_inputs(args, names, kwargs) as Dictionary<string, object>; | |||
| } | |||
| private static object _get_defun_inputs(object[] args, string[] names, object structured_args) | |||
| { | |||
| List<object> function_inputs = new(); | |||
| if(names is null) | |||
| { | |||
| names = new string[args.Length]; | |||
| } | |||
| foreach(var (arg_value, name) in zip(args, names)) | |||
| { | |||
| foreach(var val in composite_tensor_utils.flatten_with_variables_or_variable_specs(arg_value)) | |||
| { | |||
| function_inputs.Add(_get_defun_input(val, name)); | |||
| } | |||
| } | |||
| return nest.pack_sequence_as(structured_args, nest.flatten<object>(function_inputs), true); | |||
| } | |||
| private static object _get_defun_input(object arg, string name) | |||
| { | |||
| var func_graph = ops.get_default_graph() as FuncGraph; | |||
| Debug.Assert(func_graph is not null); | |||
| if (arg is Tensor tensor) | |||
| { | |||
| Tensor placeholder; | |||
| try | |||
| { | |||
| placeholder = tf.placeholder(tensor.dtype, tensor.shape, name); | |||
| } | |||
| catch (ValueError) | |||
| { | |||
| // TODO(Rinne): Add warning here. | |||
| placeholder = tf.placeholder(tensor.dtype, tensor.shape); | |||
| } | |||
| handle_data_util.copy_handle_data(tensor, placeholder); | |||
| if (name is not null) | |||
| { | |||
| placeholder.op._set_attr("_user_specified_name", new AttrValue() | |||
| { | |||
| S = tf.compat.as_bytes(name) | |||
| }); | |||
| } | |||
| return placeholder; | |||
| } | |||
| else if (arg is TensorSpec spec) | |||
| { | |||
| string requested_name; | |||
| if (!string.IsNullOrEmpty(spec.name)) | |||
| { | |||
| requested_name = spec.name; | |||
| } | |||
| else | |||
| { | |||
| requested_name = name; | |||
| } | |||
| Tensor placeholder; | |||
| try | |||
| { | |||
| placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name); | |||
| } | |||
| catch (ValueError) | |||
| { | |||
| // TODO(Rinne): Add warning here. | |||
| placeholder = tf.placeholder(spec.dtype, spec.shape); | |||
| } | |||
| if (name is not null) | |||
| { | |||
| placeholder.op._set_attr("_user_specified_name", new AttrValue() | |||
| { | |||
| S = tf.compat.as_bytes(requested_name) | |||
| }); | |||
| } | |||
| return placeholder; | |||
| } | |||
| else if (arg is BaseResourceVariable variable) | |||
| { | |||
| var placeholder = func_graph.capture(variable.Handle, name); | |||
| placeholder.op._set_attr("_user_specified_name", new AttrValue() | |||
| { | |||
| S = tf.compat.as_bytes(name) | |||
| }); | |||
| return arg; | |||
| } | |||
| // TODO(Rinne): deal with `VariableSpec`. | |||
| else | |||
| { | |||
| return arg; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,4 +1,6 @@ | |||
| namespace Tensorflow | |||
| using Tensorflow.Graphs; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class Graph | |||
| { | |||
| @@ -6,5 +8,10 @@ | |||
| { | |||
| } | |||
| internal GraphOverrideGradientContext _override_gradient_function(Dictionary<string, Func<Operation, object[], Tensor[]>> gradient_function_map) | |||
| { | |||
| return new GraphOverrideGradientContext(this, gradient_function_map); | |||
| } | |||
| } | |||
| } | |||
| @@ -118,7 +118,7 @@ namespace Tensorflow | |||
| /// <param name="compute_device">(Optional.) If True, device functions will be executed | |||
| /// to compute the device property of the Operation.</param> | |||
| /// <returns>An `Operation` object.</returns> | |||
| public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true) | |||
| public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true, OperationDescription desc = null) | |||
| { | |||
| var ret = new Operation(c_op, this); | |||
| _add_op(ret); | |||
| @@ -19,6 +19,9 @@ using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Collections.Specialized; | |||
| using System.Linq; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Common.Extensions; | |||
| using Tensorflow.Graphs; | |||
| using static Tensorflow.Binding; | |||
| @@ -86,6 +89,13 @@ namespace Tensorflow | |||
| private int _next_id_counter; | |||
| private List<Operation> _unfetchable_ops = new List<Operation>(); | |||
| private List<Tensor> _unfeedable_tensors = new List<Tensor>(); | |||
| private Dictionary<string, EagerDefinedFunction> _functions = new(); | |||
| internal Dictionary<string, Func<Operation, object[], Tensor[]>> _gradient_function_map = new(); | |||
| private VersionDef _graph_def_versions = new VersionDef() | |||
| { | |||
| Producer = versions.GRAPH_DEF_VERSION, | |||
| MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER | |||
| }; | |||
| public string _name_stack = ""; | |||
| protected string _graph_key; | |||
| @@ -121,6 +131,8 @@ namespace Tensorflow | |||
| protected Graph outer_graph; | |||
| public Graph OuterGraph => outer_graph; | |||
| public Dictionary<string, EagerDefinedFunction> Functions => _functions; | |||
| public SafeGraphHandle c_graph => _handle; | |||
| public Graph() | |||
| { | |||
| @@ -147,6 +159,44 @@ namespace Tensorflow | |||
| return ops.set_default_graph(this); | |||
| } | |||
| public bool IsFunction(string name) | |||
| { | |||
| return _functions.ContainsKey(tf.compat.as_str(name)); | |||
| } | |||
| internal void AddFunction(EagerDefinedFunction function) | |||
| { | |||
| _check_not_finalized(); | |||
| var name = function.Name; | |||
| if(function._grad_func_name is not null && function.csharp_grad_func is not null) | |||
| { | |||
| throw new ValueError($"Gradient defined twice for function {name}"); | |||
| } | |||
| var c_graph = this.c_graph; | |||
| var func = function._c_func.Get(); | |||
| Status status = new(); | |||
| if (function._grad_func is not null) | |||
| { | |||
| var gradient = function._grad_func._c_func.Get(); | |||
| c_api.TF_GraphCopyFunction(c_graph, func, gradient, status); | |||
| status.Check(true); | |||
| } | |||
| else | |||
| { | |||
| c_api.TF_GraphCopyFunction(c_graph, func, new SafeFuncGraphHandle(IntPtr.Zero), status); | |||
| status.Check(true); | |||
| } | |||
| _functions[tf.compat.as_str(name)] = function; | |||
| if(_graph_def_versions.MinConsumer < 12) | |||
| { | |||
| _graph_def_versions.MinConsumer = 12; | |||
| } | |||
| } | |||
| private Tensor _as_graph_element(object obj) | |||
| { | |||
| if (obj is RefVariable var) | |||
| @@ -308,6 +358,9 @@ namespace Tensorflow | |||
| private void _create_op_helper(Operation op, bool compute_device = true) | |||
| { | |||
| // high priority | |||
| // TODO(Rinne): complete the implementation. | |||
| op._gradient_function = _gradient_function_map.GetOrDefault(op.type, null); | |||
| _record_op_seen_by_control_dependencies(op); | |||
| } | |||
| @@ -524,6 +577,11 @@ namespace Tensorflow | |||
| ops.pop_graph(); | |||
| } | |||
| internal EagerDefinedFunction _get_function(string name) | |||
| { | |||
| return _functions.GetOrDefault(name, null); | |||
| } | |||
| string debugString = string.Empty; | |||
| public override string ToString() | |||
| { | |||
| @@ -0,0 +1,37 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Text; | |||
| namespace Tensorflow.Graphs | |||
| { | |||
| internal class GraphOverrideGradientContext: ITensorFlowObject | |||
| { | |||
| Graph _graph; | |||
| Dictionary<string, Func<Operation, object[], Tensor[]>> _new_gradient_function_map; | |||
| public GraphOverrideGradientContext(Graph graph, | |||
| Dictionary<string, Func<Operation, object[], Tensor[]>> new_gradient_function_map) | |||
| { | |||
| _graph = graph; | |||
| _new_gradient_function_map = new_gradient_function_map; | |||
| } | |||
| [DebuggerStepThrough] | |||
| public void __enter__() | |||
| { | |||
| Debug.Assert(_graph._gradient_function_map.Count == 0); | |||
| _graph._gradient_function_map = _new_gradient_function_map; | |||
| } | |||
| [DebuggerStepThrough] | |||
| public void __exit__() | |||
| { | |||
| _graph._gradient_function_map = new Dictionary<string, Func<Operation, object[], Tensor[]>>(); | |||
| } | |||
| public void Dispose() | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -28,6 +28,8 @@ public sealed class ImportGraphDefOptions | |||
| _handle = c_api.TF_NewImportGraphDefOptions(); | |||
| } | |||
| public SafeImportGraphDefOptionsHandle Options => _handle; | |||
| public void AddReturnOutput(string name, int index) | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | |||
| @@ -185,6 +185,9 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsAddReturnOperation(SafeImportGraphDefOptionsHandle opts, string oper_name); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsSetValidateColocationConstraints(SafeImportGraphDefOptionsHandle options, bool validate_colocation_constraints); | |||
| /// <summary> | |||
| /// Add an output in `graph_def` to be returned via the `return_outputs` output | |||
| /// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input | |||
| @@ -246,7 +249,7 @@ namespace Tensorflow | |||
| /// <param name="ops">TF_ImportGraphDefOptions*</param> | |||
| /// <param name="uniquify_prefix">unsigned char</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, char uniquify_prefix); | |||
| public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(SafeImportGraphDefOptionsHandle ops, bool uniquify_prefix); | |||
| /// <summary> | |||
| /// Fetches the return operations requested via | |||
| @@ -308,7 +311,7 @@ namespace Tensorflow | |||
| /// <param name="types">const TF_DataType*</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphSetOutputHandleShapesAndTypes(IntPtr graph, TF_Output output, | |||
| public static extern void TF_GraphSetOutputHandleShapesAndTypes(SafeGraphHandle graph, TF_Output output, | |||
| int num_shapes_and_types, IntPtr[] shapes, int[] ranks, DataType[] types, | |||
| SafeStatusHandle status); | |||
| @@ -4,10 +4,10 @@ public interface IOptimizer | |||
| { | |||
| Tensor[] aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars); | |||
| Tensor[] clip_gradients(Tensor[] grads); | |||
| void apply_gradients((Tensor, ResourceVariable) grads_and_vars, | |||
| void apply_gradients((Tensor, IVariableV1) grads_and_vars, | |||
| string name = null, | |||
| bool experimental_aggregate_gradients = true); | |||
| void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars, | |||
| void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars, | |||
| string name = null, | |||
| bool experimental_aggregate_gradients = true); | |||
| } | |||
| @@ -20,6 +20,9 @@ using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| using Google.Protobuf; | |||
| using Google.Protobuf.WellKnownTypes; | |||
| using System.Diagnostics; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -47,6 +50,8 @@ namespace Tensorflow | |||
| private readonly Graph _graph; | |||
| internal Func<Operation, object[], Tensor[]> _gradient_function; | |||
| public string type => OpType; | |||
| public Graph graph => _graph; | |||
| @@ -61,7 +66,7 @@ namespace Tensorflow | |||
| public string Device => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||
| // OperationDescription _opDesc; | |||
| //private OperationDescription _op_desc; | |||
| public NodeDef node_def => GetNodeDef(); | |||
| @@ -216,21 +221,19 @@ namespace Tensorflow | |||
| var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | |||
| string oneof_value = x.ValueCase.ToString(); | |||
| if (string.IsNullOrEmpty(oneof_value)) | |||
| return null; | |||
| var oneof_value = x.ValueCase; | |||
| if (oneof_value == AttrValue.ValueOneofCase.None) | |||
| return new object[0]; | |||
| switch (oneof_value.ToLower()) | |||
| if(oneof_value == AttrValue.ValueOneofCase.List) | |||
| { | |||
| case "list": | |||
| throw new NotImplementedException($"Unsupported field type in {oneof_value}"); | |||
| case "type": | |||
| return x.Type; | |||
| case "s": | |||
| return x.S.ToStringUtf8(); | |||
| default: | |||
| return x.GetType().GetProperty(oneof_value).GetValue(x); | |||
| throw new NotImplementedException($"Unsupported field type in {oneof_value}"); | |||
| } | |||
| if(oneof_value == AttrValue.ValueOneofCase.Type) | |||
| { | |||
| return dtypes.as_tf_dtype(x.Type); | |||
| } | |||
| return ProtoUtils.GetSingleAttrValue(x, oneof_value); | |||
| } | |||
| public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) | |||
| @@ -238,6 +241,19 @@ namespace Tensorflow | |||
| return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); | |||
| } | |||
| [Obsolete("The implementation is not complete.")] | |||
| internal void _set_device_from_string(string device_str) | |||
| { | |||
| // TODO(Rinne): complete it with new C API `SetRequestedDevice`. | |||
| //c_api.TF_SetDevice(_handle, device_str); | |||
| } | |||
| [Obsolete("The implementation is not complete.")] | |||
| internal void _set_device(string device) | |||
| { | |||
| _set_device_from_string(device); | |||
| } | |||
| private NodeDef GetNodeDef() | |||
| { | |||
| var buffer = new Buffer(); | |||
| @@ -296,5 +312,60 @@ namespace Tensorflow | |||
| } | |||
| public NDArray numpy() => throw new NotImplementedException(""); | |||
| internal void _add_outputs(TF_DataType[] types, Shape[] shapes) | |||
| { | |||
| Debug.Assert(types.Length == shapes.Length); | |||
| int orig_num_outputs = this.outputs.Length; | |||
| var new_outputs = new List<Tensor>(_outputs); | |||
| // Since the `_outputs` is defined as `Array`, when we add new output, we | |||
| // have to create a new array, which brings some performance concerns. | |||
| // In the future maybe the type of `outputs` should be reconsidered. | |||
| for(int i = 0; i < types.Length; i++) | |||
| { | |||
| var t = new Tensor(this, orig_num_outputs + i, types[i]); | |||
| t.shape = shapes[i]; | |||
| new_outputs.Add(t); | |||
| } | |||
| _outputs = new_outputs.ToArray(); | |||
| } | |||
| internal void _set_func_attr(string attr_name, string func_name) | |||
| { | |||
| var func = new NameAttrList() { Name = func_name }; | |||
| _set_attr(attr_name, new AttrValue() { Func = func }); | |||
| } | |||
| internal void _set_type_list_attr(string attr_name, DataType[] types) | |||
| { | |||
| if(types is null || types.Length == 0) | |||
| { | |||
| return; | |||
| } | |||
| var type_list = new AttrValue.Types.ListValue(); | |||
| type_list.Type.AddRange(types); | |||
| _set_attr(attr_name, new AttrValue() { List = type_list }); | |||
| } | |||
| internal void _set_attr(string attr_name, AttrValue attr_value) | |||
| { | |||
| var buffer = new Buffer(attr_value.ToByteArray()); | |||
| try | |||
| { | |||
| _set_attr_with_buf(attr_name, buffer); | |||
| } | |||
| finally | |||
| { | |||
| buffer.Release(); | |||
| } | |||
| } | |||
| internal void _set_attr_with_buf(string attr_name, Buffer attr_buf) | |||
| { | |||
| Status status = new(); | |||
| c_api.TFC_SetAttr(graph, _handle, attr_name, attr_buf, status); | |||
| status.Check(true); | |||
| } | |||
| } | |||
| } | |||
| @@ -14,10 +14,14 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Google.Protobuf; | |||
| using Google.Protobuf.WellKnownTypes; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| @@ -25,6 +29,74 @@ namespace Tensorflow | |||
| { | |||
| public class functional_ops | |||
| { | |||
| public static Tensor[] partitioned_call(Tensors args, EagerDefinedFunction f, DataType[] tout, | |||
| bool executing_eagerly, string config, string executor_type) | |||
| { | |||
| if (tout is null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| if (config is null) | |||
| { | |||
| config = function_utils.get_disabled_rewriter_config().ToStringUtf8(); | |||
| } | |||
| if (executor_type is null) | |||
| { | |||
| executor_type = ""; | |||
| } | |||
| if (executing_eagerly) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| throw new NotImplementedException(); | |||
| } | |||
| var converted_args = args.Select(x => ops.convert_to_tensor(x)).ToArray(); | |||
| AttrValue tin_attr = new() | |||
| { | |||
| List = new AttrValue.Types.ListValue() | |||
| }; | |||
| tin_attr.List.Type.AddRange(args.Select(x => x.dtype.as_datatype_enum())); | |||
| AttrValue tout_attr = new() | |||
| { | |||
| List = new AttrValue.Types.ListValue() | |||
| }; | |||
| tout_attr.List.Type.AddRange(tout); | |||
| AttrValue func_attr = new() | |||
| { | |||
| Func = new NameAttrList() | |||
| }; | |||
| func_attr.Func.Name = f.Name; | |||
| AttrValue executor_type_attr = new AttrValue() | |||
| { | |||
| S = tf.compat.as_bytes(executor_type) | |||
| }; | |||
| AttrValue config_proto = new AttrValue() | |||
| { | |||
| S = ByteString.CopyFromUtf8(executor_type) | |||
| }; | |||
| var graph = ops.get_default_graph(); | |||
| f.AddToGraph(graph); | |||
| // TODO(Rinne): complete it with `f.stateful` | |||
| var op_name = "PartitionedCall"; | |||
| string xla_compile_attr = "_XlaMustCompile"; | |||
| Dictionary<string, AttrValue> op_attrs = new(); | |||
| op_attrs["Tin"] = tin_attr; | |||
| op_attrs["Tout"] = tout_attr; | |||
| op_attrs["f"] = func_attr; | |||
| op_attrs["config_proto"] = config_proto; | |||
| op_attrs["executor_type"] = executor_type_attr; | |||
| // TODO(Rinne): deal with `f.definition`. | |||
| var op = graph.create_op(op_name, args, tout.Select(x => x.as_tf_dtype()).ToArray(), | |||
| name: op_name, attrs: op_attrs); | |||
| var outputs = op.outputs; | |||
| // TODO(Rinne): deal with `f.graph`. | |||
| return outputs; | |||
| } | |||
| public static Tensor scan( | |||
| Func<Tensor, Tensor, Tensor> fn, | |||
| Tensor elems, | |||
| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Eager; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -210,7 +211,51 @@ namespace Tensorflow | |||
| /// <param name="name">A name for the operation (optional).</param> | |||
| /// <returns>A `Tensor`. Has the same type as `value`.</returns> | |||
| public static Tensor fill<T>(Tensor dims, T value, string name = null) | |||
| => tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value)); | |||
| { | |||
| var ctx = tf.Context; | |||
| if (ctx.executing_eagerly()) | |||
| { | |||
| try | |||
| { | |||
| var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Fill", name, dims, value)); | |||
| return _result[0]; | |||
| } | |||
| catch (Exception) | |||
| { | |||
| } | |||
| try | |||
| { | |||
| return fill_eager_fallback(dims, value as Tensor, name, ctx); | |||
| } | |||
| catch (Exception) | |||
| { | |||
| } | |||
| } | |||
| Dictionary<string, object> attrs = new Dictionary<string, object>(); | |||
| attrs["dims"] = dims; | |||
| attrs["value"] = value; | |||
| var result = tf.OpDefLib._apply_op_helper("Fill", name, attrs); | |||
| if (execute.must_record_gradient()) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| return result.output; | |||
| } | |||
| public static Tensor fill_eager_fallback(Tensor dims, Tensor value, string name, Context ctx) | |||
| { | |||
| object[] attrs = new object[] { "T", dims.dtype.as_datatype_enum(), "index_type", dims.dtype.as_datatype_enum() }; | |||
| var _result = execute.executes("Fill", 1, new Tensor[] { dims, value }, attrs, ctx, name); | |||
| if (execute.must_record_gradient()) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| return _result[0]; | |||
| } | |||
| //=> tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value)); | |||
| /// <summary> | |||
| /// Return the reduction indices for computing gradients of s0 op s1 with broadcast. | |||
| @@ -0,0 +1,128 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using System.Xml.Linq; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Functions; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| public class gen_functional_ops | |||
| { | |||
| public static Tensor[] partitioned_call(Tensors args, TF_DataType[] tout, EagerDefinedFunction f, | |||
| string config = "", string config_proto = "", string executor_type = "", string name = null) | |||
| { | |||
| var ctx = tf.Context; | |||
| if (ctx.executing_eagerly()) | |||
| { | |||
| try | |||
| { | |||
| return tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("PartitionedCall", name, | |||
| args, tout, f, config, config_proto, executor_type)); | |||
| } | |||
| catch (Exception) | |||
| { | |||
| } | |||
| } | |||
| if (config is null) | |||
| { | |||
| config = ""; | |||
| } | |||
| if (config_proto is null) | |||
| { | |||
| config_proto = ""; | |||
| } | |||
| if (executor_type is null) | |||
| { | |||
| executor_type = ""; | |||
| } | |||
| Dictionary<string, object> kwargs = new(); | |||
| kwargs["args"] = args; | |||
| kwargs["Tout"] = tout; | |||
| kwargs["f"] = f; | |||
| kwargs["config"] = config; | |||
| kwargs["config_proto"] = config_proto; | |||
| kwargs["executor_type"] = executor_type; | |||
| var output = tf.OpDefLib._apply_op_helper("PartitionedCall", | |||
| name, kwargs); | |||
| var result = output.outputs; | |||
| if (execute.must_record_gradient()) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| return result; | |||
| } | |||
| public static Tensor[] partitioned_call_eager_fallback(Tensors args, TF_DataType[] tout, EagerDefinedFunction f, | |||
| string config, string config_proto, string executor_type, string name, Context ctx) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| throw new NotImplementedException(); | |||
| if(config is null) | |||
| { | |||
| config = ""; | |||
| } | |||
| if(config_proto is null) | |||
| { | |||
| config_proto = ""; | |||
| } | |||
| if(executor_type is null) | |||
| { | |||
| executor_type = ""; | |||
| } | |||
| object[] attrs = new object[] | |||
| { | |||
| }; | |||
| } | |||
| public static Tensor[] symbolic_gradient(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name = null) | |||
| { | |||
| var ctx = tf.Context; | |||
| if (ctx.executing_eagerly()) | |||
| { | |||
| try | |||
| { | |||
| var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( | |||
| "SymbolicGradient", name, input, Tout, f)); | |||
| return _result; | |||
| } | |||
| catch (Exception) | |||
| { | |||
| } | |||
| try | |||
| { | |||
| return symbolic_gradient_eager_fallback(input, Tout, f, name, ctx); | |||
| } | |||
| catch (Exception) | |||
| { | |||
| } | |||
| } | |||
| var op = tf.OpDefLib._apply_op_helper("SymbolicGradient", name, new object[] { input, Tout, f }); | |||
| var result = op.outputs; | |||
| if (execute.must_record_gradient()) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| return result; | |||
| } | |||
| public static Tensor[] symbolic_gradient_eager_fallback(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name, Context ctx) | |||
| { | |||
| object[] attrs = new object[] { "Tin", input, "Tout", Tout, "f", f }; | |||
| var result = execute.executes("SymbolicGradient", Tout.Length, input, attrs, ctx, name); | |||
| if (execute.must_record_gradient()) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| return result; | |||
| } | |||
| } | |||
| } | |||
| @@ -10050,13 +10050,51 @@ namespace Tensorflow.Operations | |||
| /// </remarks> | |||
| public static Tensor ensure_shape(Tensor input, Shape shape, string name = "EnsureShape") | |||
| { | |||
| var ctx = tf.Context; | |||
| if (ctx.executing_eagerly()) | |||
| { | |||
| try | |||
| { | |||
| var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("EnsureShape", name, input, shape)); | |||
| return _result[0]; | |||
| } | |||
| catch (Exception) | |||
| { | |||
| } | |||
| try | |||
| { | |||
| return ensure_shape_eager_fallback(input, shape, name, ctx); | |||
| } | |||
| catch (Exception) | |||
| { | |||
| } | |||
| } | |||
| var dict = new Dictionary<string, object>(); | |||
| dict["input"] = input; | |||
| dict["shape"] = shape; | |||
| var op = tf.OpDefLib._apply_op_helper("EnsureShape", name: name, keywords: dict); | |||
| if (execute.must_record_gradient()) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| return op.output; | |||
| } | |||
| public static Tensor ensure_shape_eager_fallback(Tensor input, Shape shape, string name, Context ctx) | |||
| { | |||
| object[] attrs = new object[4] { "shape", shape, "T", input.dtype.as_datatype_enum() }; | |||
| var _result = execute.executes("EnsureShape", 1, new Tensor[] { input }, | |||
| attrs, ctx, name); | |||
| if (execute.must_record_gradient()) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| return _result[0]; | |||
| } | |||
| /// <summary> | |||
| /// Creates or finds a child frame, and makes <c>data</c> available to the child frame. | |||
| /// </summary> | |||
| @@ -0,0 +1,60 @@ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Eager; | |||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| public static class handle_data_util | |||
| { | |||
| public static void copy_handle_data(Tensor source_t, Tensor target_t) | |||
| { | |||
| if(target_t.dtype == dtypes.resource || target_t.dtype == dtypes.variant) | |||
| { | |||
| HandleData handle_data; | |||
| if(source_t is EagerTensor) | |||
| { | |||
| handle_data = source_t.HandleData; | |||
| } | |||
| else | |||
| { | |||
| handle_data = ops.get_resource_handle_data(source_t); | |||
| } | |||
| if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null | |||
| && handle_data.ShapeAndType.Count > 0) | |||
| { | |||
| set_handle_data(target_t, handle_data); | |||
| } | |||
| } | |||
| } | |||
| public static HandleData create_handle_data(Shape shape, TF_DataType dtype) | |||
| { | |||
| HandleData handle_data = new(); | |||
| handle_data.IsSet = true; | |||
| handle_data.ShapeAndType.Add(new HandleShapeAndType() | |||
| { | |||
| Shape = shape.as_proto(), | |||
| Dtype = dtype.as_datatype_enum() | |||
| }); | |||
| return handle_data; | |||
| } | |||
| public static void set_handle_data(Tensor target_t, HandleData handle_data) | |||
| { | |||
| if(target_t is EagerTensor) | |||
| { | |||
| target_t.HandleData = handle_data; | |||
| return; | |||
| } | |||
| Status status = new(); | |||
| var proto = handle_data.ToByteArray(); | |||
| c_api.TFC_SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), proto, proto.Length, status); | |||
| status.Check(true); | |||
| } | |||
| public static HandleData get_resource_handle_data(Tensor graph_op) => ops.get_resource_handle_data(graph_op); | |||
| } | |||
| } | |||
| @@ -21,6 +21,11 @@ using Tensorflow.Train; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| using Tensorflow.Variables; | |||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Operations; | |||
| using System.Buffers; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Graphs; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -31,18 +36,14 @@ namespace Tensorflow | |||
| { | |||
| public static Operation shape_safe_assign_variable_handle(Tensor handle, int[] shape, Tensor value, string name = null) | |||
| { | |||
| // TODO(Rinne): deal with `_handle_graph`. | |||
| var value_tensor = ops.convert_to_tensor(value); | |||
| return gen_resource_variable_ops.assign_variable_op(handle, | |||
| value_tensor, | |||
| name: name); | |||
| } | |||
| public static bool is_resource_variable(IVariableV1 var) | |||
| { | |||
| return var is ResourceVariable; | |||
| } | |||
| public static bool is_resource_variable(Trackable var) | |||
| public static bool is_resource_variable(object var) | |||
| { | |||
| return var is BaseResourceVariable; | |||
| } | |||
| @@ -78,6 +79,18 @@ namespace Tensorflow | |||
| string shared_name, string name, bool graph_mode, Tensor initial_value = null) | |||
| { | |||
| var container = ops.get_default_graph().Container; | |||
| if(container is null) | |||
| { | |||
| container = ""; | |||
| } | |||
| if (!graph_mode) | |||
| { | |||
| if(shared_name is not null) | |||
| { | |||
| throw new Exception("Using an explicit shared_name is not allowed when executing eagerly."); | |||
| } | |||
| shared_name = tf.Context.anonymous_name(); | |||
| } | |||
| var handle = gen_resource_variable_ops.var_handle_op(shape: shape, | |||
| dtype: dtype, | |||
| shared_name: shared_name, | |||
| @@ -95,26 +108,20 @@ namespace Tensorflow | |||
| } | |||
| else | |||
| { | |||
| // We do not want two distinct ResourceVariable objects for the same | |||
| // underlying resource in the runtime. | |||
| // When in eager mode, explicitly ensure so here. When in graph mode, it's | |||
| // ensured by always generating different variable names. | |||
| var exists = gen_resource_variable_ops.var_is_initialized_op(handle); | |||
| // We create an assert Op instead of checking right away in order to be | |||
| // compatible with ASYNC execution mode. Further, since not all devices | |||
| // support string tensors, we encode the assertion string in the Op name | |||
| /*gen_logging_ops.assert(gen_math_ops.logical_not(exists), | |||
| new[] { exists }, | |||
| name: "EagerVariableNameReuse");*/ | |||
| var handle_data = new HandleData(); | |||
| handle_data.IsSet = true; | |||
| handle_data.ShapeAndType.Add(new HandleShapeAndType | |||
| var handle_data = handle_data_util.create_handle_data(shape, dtype); | |||
| if (initial_value is not null && initial_value.dtype == dtypes.variant) | |||
| { | |||
| Dtype = dtype.as_datatype_enum(), | |||
| Shape = shape.as_proto() | |||
| }); | |||
| var extra_handle_data = get_eager_safe_handle_data(initial_value); | |||
| if (extra_handle_data is not null && extra_handle_data.IsSet) | |||
| { | |||
| if (!handle_data.IsSet || handle_data.ShapeAndType.Count != 1) | |||
| { | |||
| throw new RuntimeError($"Expected VarHandleOp to return a length==1 shape_and_type, " + | |||
| $"but saw: '{handle_data}'"); | |||
| } | |||
| handle_data.ShapeAndType.AddRange(extra_handle_data.ShapeAndType); | |||
| } | |||
| } | |||
| _set_handle_shapes_and_types(handle, handle_data, graph_mode); | |||
| return handle; | |||
| } | |||
| @@ -126,7 +133,7 @@ namespace Tensorflow | |||
| /// <param name="handle"></param> | |||
| /// <param name="handle_data"></param> | |||
| /// <param name="graph_mode"></param> | |||
| private static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||
| internal unsafe static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) | |||
| { | |||
| if (!graph_mode) | |||
| return; | |||
| @@ -144,6 +151,47 @@ namespace Tensorflow | |||
| ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count; | |||
| var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray(); | |||
| } | |||
| //tensor.HandleData = handle_data; | |||
| //if (!graph_mode) | |||
| // return; | |||
| //var shapes = handle_data.ShapeAndType.Select(x => x.Shape); | |||
| //var types = handle_data.ShapeAndType.Select(x => x.Dtype).ToArray(); | |||
| //var ranks = shapes.Select(s => s.UnknownRank ? -1 : s.Dim.Count).ToArray(); | |||
| //var converted_shapes = shapes.Select<TensorShapeProto, Memory<int>>(s => | |||
| //{ | |||
| // if (!s.UnknownRank) | |||
| // { | |||
| // return s.Dim.Select(d => (int)d.Size).ToArray(); | |||
| // } | |||
| // else | |||
| // { | |||
| // return Memory<int>.Empty; | |||
| // } | |||
| //}).ToArray(); | |||
| //List<MemoryHandle> handles = new(); | |||
| //IntPtr[] shapes_with_ptr = new IntPtr[converted_shapes.Length]; | |||
| //foreach(var (i, m) in enumerate(converted_shapes)) | |||
| //{ | |||
| // if(m.IsEmpty) | |||
| // { | |||
| // shapes_with_ptr[i] = IntPtr.Zero; | |||
| // } | |||
| // else | |||
| // { | |||
| // var handle = m.Pin(); | |||
| // handles.Add(handle); | |||
| // shapes_with_ptr[i] = new IntPtr(handle.Pointer); | |||
| // } | |||
| //} | |||
| //Status status = new(); | |||
| //// TODO(Rinne): enable it. | |||
| //c_api.TF_GraphSetOutputHandleShapesAndTypes(tensor.op.graph.c_graph, tensor._as_tf_output(), | |||
| // shapes_with_ptr.Length, shapes_with_ptr, ranks, types, status); | |||
| //handles = null; | |||
| } | |||
| /// <summary> | |||
| @@ -162,24 +210,6 @@ namespace Tensorflow | |||
| throw new NotImplementedException(""); | |||
| } | |||
| private static HandleData get_eager_safe_handle_data(Tensor handle) | |||
| { | |||
| if (handle.Handle == null) | |||
| { | |||
| var data = new HandleData(); | |||
| data.ShapeAndType.Add(new HandleShapeAndType | |||
| { | |||
| Shape = handle.shape.as_shape_proto(), | |||
| Dtype = handle.dtype.as_datatype_enum() | |||
| }); | |||
| return data; | |||
| } | |||
| else | |||
| { | |||
| return HandleData.Parser.ParseFrom(handle.BufferToArray()); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Copies an existing variable to a new graph, with no initializer. | |||
| /// </summary> | |||
| @@ -231,5 +261,60 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| } | |||
| public static void _maybe_set_handle_data(TF_DataType dtype, Tensor handle, Tensor tensor) | |||
| { | |||
| if(dtype == dtypes.variant) | |||
| { | |||
| var handle_data = get_eager_safe_handle_data(handle); | |||
| if(handle_data.IsSet && handle_data.ShapeAndType.Count > 1) | |||
| { | |||
| tensor.HandleData = new HandleData() | |||
| { | |||
| IsSet = true | |||
| }; | |||
| tensor.HandleData.ShapeAndType.AddRange(handle_data.ShapeAndType.Skip(1)); | |||
| } | |||
| } | |||
| } | |||
| public static HandleData get_eager_safe_handle_data(Tensor handle) | |||
| { | |||
| if (handle.Handle == null) | |||
| { | |||
| var data = new HandleData(); | |||
| data.ShapeAndType.Add(new HandleShapeAndType | |||
| { | |||
| Shape = handle.shape.as_shape_proto(), | |||
| Dtype = handle.dtype.as_datatype_enum() | |||
| }); | |||
| return data; | |||
| } | |||
| else | |||
| { | |||
| return HandleData.Parser.ParseFrom(handle.BufferToArray()); | |||
| } | |||
| //if(handle is EagerTensor) | |||
| //{ | |||
| // return handle.HandleData; | |||
| //} | |||
| //else | |||
| //{ | |||
| // return handle_data_util.get_resource_handle_data(handle); | |||
| //} | |||
| } | |||
| public static void variable_accessed(IVariableV1 variable) | |||
| { | |||
| if (ops.get_default_graph() is FuncGraph func_graph) | |||
| { | |||
| func_graph.watch_variable(variable); | |||
| } | |||
| if (variable.Trainable) | |||
| { | |||
| foreach (var tape in tf.GetTapeSet()) | |||
| tape.VariableAccessed(variable); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -2,7 +2,7 @@ | |||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| // source: tensorflow/core/framework/allocation_description.proto | |||
| // </auto-generated> | |||
| #pragma warning disable 1591, 0612, 3021 | |||
| #pragma warning disable 1591, 0612, 3021, 8981 | |||
| #region Designer generated code | |||
| using pb = global::Google.Protobuf; | |||
| @@ -43,23 +43,31 @@ namespace Tensorflow { | |||
| } | |||
| #region Messages | |||
| public sealed partial class AllocationDescription : pb::IMessage<AllocationDescription> { | |||
| public sealed partial class AllocationDescription : pb::IMessage<AllocationDescription> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<AllocationDescription> _parser = new pb::MessageParser<AllocationDescription>(() => new AllocationDescription()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<AllocationDescription> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.AllocationDescriptionReflection.Descriptor.MessageTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public AllocationDescription() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -67,6 +75,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public AllocationDescription(AllocationDescription other) : this() { | |||
| requestedBytes_ = other.requestedBytes_; | |||
| allocatedBytes_ = other.allocatedBytes_; | |||
| @@ -78,6 +87,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public AllocationDescription Clone() { | |||
| return new AllocationDescription(this); | |||
| } | |||
| @@ -89,6 +99,7 @@ namespace Tensorflow { | |||
| /// Total number of bytes requested | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public long RequestedBytes { | |||
| get { return requestedBytes_; } | |||
| set { | |||
| @@ -103,6 +114,7 @@ namespace Tensorflow { | |||
| /// Total number of bytes allocated if known | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public long AllocatedBytes { | |||
| get { return allocatedBytes_; } | |||
| set { | |||
| @@ -117,6 +129,7 @@ namespace Tensorflow { | |||
| /// Name of the allocator used | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string AllocatorName { | |||
| get { return allocatorName_; } | |||
| set { | |||
| @@ -131,6 +144,7 @@ namespace Tensorflow { | |||
| /// Identifier of the allocated buffer if known | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public long AllocationId { | |||
| get { return allocationId_; } | |||
| set { | |||
| @@ -145,6 +159,7 @@ namespace Tensorflow { | |||
| /// Set if this tensor only has one remaining reference | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool HasSingleReference { | |||
| get { return hasSingleReference_; } | |||
| set { | |||
| @@ -159,6 +174,7 @@ namespace Tensorflow { | |||
| /// Address of the allocation. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public ulong Ptr { | |||
| get { return ptr_; } | |||
| set { | |||
| @@ -167,11 +183,13 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as AllocationDescription); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(AllocationDescription other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -189,6 +207,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (RequestedBytes != 0L) hash ^= RequestedBytes.GetHashCode(); | |||
| @@ -204,12 +223,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (RequestedBytes != 0L) { | |||
| output.WriteRawTag(8); | |||
| output.WriteInt64(RequestedBytes); | |||
| @@ -237,9 +261,45 @@ namespace Tensorflow { | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (RequestedBytes != 0L) { | |||
| output.WriteRawTag(8); | |||
| output.WriteInt64(RequestedBytes); | |||
| } | |||
| if (AllocatedBytes != 0L) { | |||
| output.WriteRawTag(16); | |||
| output.WriteInt64(AllocatedBytes); | |||
| } | |||
| if (AllocatorName.Length != 0) { | |||
| output.WriteRawTag(26); | |||
| output.WriteString(AllocatorName); | |||
| } | |||
| if (AllocationId != 0L) { | |||
| output.WriteRawTag(32); | |||
| output.WriteInt64(AllocationId); | |||
| } | |||
| if (HasSingleReference != false) { | |||
| output.WriteRawTag(40); | |||
| output.WriteBool(HasSingleReference); | |||
| } | |||
| if (Ptr != 0UL) { | |||
| output.WriteRawTag(48); | |||
| output.WriteUInt64(Ptr); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (RequestedBytes != 0L) { | |||
| @@ -267,6 +327,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(AllocationDescription other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -293,7 +354,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -326,7 +391,47 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 8: { | |||
| RequestedBytes = input.ReadInt64(); | |||
| break; | |||
| } | |||
| case 16: { | |||
| AllocatedBytes = input.ReadInt64(); | |||
| break; | |||
| } | |||
| case 26: { | |||
| AllocatorName = input.ReadString(); | |||
| break; | |||
| } | |||
| case 32: { | |||
| AllocationId = input.ReadInt64(); | |||
| break; | |||
| } | |||
| case 40: { | |||
| HasSingleReference = input.ReadBool(); | |||
| break; | |||
| } | |||
| case 48: { | |||
| Ptr = input.ReadUInt64(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| @@ -2,7 +2,7 @@ | |||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| // source: tensorflow/core/framework/attr_value.proto | |||
| // </auto-generated> | |||
| #pragma warning disable 1591, 0612, 3021 | |||
| #pragma warning disable 1591, 0612, 3021, 8981 | |||
| #region Designer generated code | |||
| using pb = global::Google.Protobuf; | |||
| @@ -63,23 +63,31 @@ namespace Tensorflow { | |||
| /// Comment indicates the corresponding attr type. Only the field matching the | |||
| /// attr type may be filled. | |||
| /// </summary> | |||
| public sealed partial class AttrValue : pb::IMessage<AttrValue> { | |||
| public sealed partial class AttrValue : pb::IMessage<AttrValue> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<AttrValue> _parser = new pb::MessageParser<AttrValue>(() => new AttrValue()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<AttrValue> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.AttrValueReflection.Descriptor.MessageTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public AttrValue() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -87,6 +95,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public AttrValue(AttrValue other) : this() { | |||
| switch (other.ValueCase) { | |||
| case ValueOneofCase.S: | |||
| @@ -125,6 +134,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public AttrValue Clone() { | |||
| return new AttrValue(this); | |||
| } | |||
| @@ -135,6 +145,7 @@ namespace Tensorflow { | |||
| /// "string" | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pb::ByteString S { | |||
| get { return valueCase_ == ValueOneofCase.S ? (pb::ByteString) value_ : pb::ByteString.Empty; } | |||
| set { | |||
| @@ -149,6 +160,7 @@ namespace Tensorflow { | |||
| /// "int" | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public long I { | |||
| get { return valueCase_ == ValueOneofCase.I ? (long) value_ : 0L; } | |||
| set { | |||
| @@ -163,6 +175,7 @@ namespace Tensorflow { | |||
| /// "float" | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public float F { | |||
| get { return valueCase_ == ValueOneofCase.F ? (float) value_ : 0F; } | |||
| set { | |||
| @@ -177,6 +190,7 @@ namespace Tensorflow { | |||
| /// "bool" | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool B { | |||
| get { return valueCase_ == ValueOneofCase.B ? (bool) value_ : false; } | |||
| set { | |||
| @@ -191,6 +205,7 @@ namespace Tensorflow { | |||
| /// "type" | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Tensorflow.DataType Type { | |||
| get { return valueCase_ == ValueOneofCase.Type ? (global::Tensorflow.DataType) value_ : global::Tensorflow.DataType.DtInvalid; } | |||
| set { | |||
| @@ -205,6 +220,7 @@ namespace Tensorflow { | |||
| /// "shape" | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Tensorflow.TensorShapeProto Shape { | |||
| get { return valueCase_ == ValueOneofCase.Shape ? (global::Tensorflow.TensorShapeProto) value_ : null; } | |||
| set { | |||
| @@ -219,6 +235,7 @@ namespace Tensorflow { | |||
| /// "tensor" | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Tensorflow.TensorProto Tensor { | |||
| get { return valueCase_ == ValueOneofCase.Tensor ? (global::Tensorflow.TensorProto) value_ : null; } | |||
| set { | |||
| @@ -233,6 +250,7 @@ namespace Tensorflow { | |||
| /// any "list(...)" | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Tensorflow.AttrValue.Types.ListValue List { | |||
| get { return valueCase_ == ValueOneofCase.List ? (global::Tensorflow.AttrValue.Types.ListValue) value_ : null; } | |||
| set { | |||
| @@ -250,6 +268,7 @@ namespace Tensorflow { | |||
| /// that attr in the instantiation. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Tensorflow.NameAttrList Func { | |||
| get { return valueCase_ == ValueOneofCase.Func ? (global::Tensorflow.NameAttrList) value_ : null; } | |||
| set { | |||
| @@ -270,6 +289,7 @@ namespace Tensorflow { | |||
| /// given the value "bar". | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string Placeholder { | |||
| get { return valueCase_ == ValueOneofCase.Placeholder ? (string) value_ : ""; } | |||
| set { | |||
| @@ -295,22 +315,26 @@ namespace Tensorflow { | |||
| } | |||
| private ValueOneofCase valueCase_ = ValueOneofCase.None; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public ValueOneofCase ValueCase { | |||
| get { return valueCase_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void ClearValue() { | |||
| valueCase_ = ValueOneofCase.None; | |||
| value_ = null; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as AttrValue); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(AttrValue other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -333,6 +357,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (valueCase_ == ValueOneofCase.S) hash ^= S.GetHashCode(); | |||
| @@ -353,12 +378,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (valueCase_ == ValueOneofCase.List) { | |||
| output.WriteRawTag(10); | |||
| output.WriteMessage(List); | |||
| @@ -402,9 +432,61 @@ namespace Tensorflow { | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (valueCase_ == ValueOneofCase.List) { | |||
| output.WriteRawTag(10); | |||
| output.WriteMessage(List); | |||
| } | |||
| if (valueCase_ == ValueOneofCase.S) { | |||
| output.WriteRawTag(18); | |||
| output.WriteBytes(S); | |||
| } | |||
| if (valueCase_ == ValueOneofCase.I) { | |||
| output.WriteRawTag(24); | |||
| output.WriteInt64(I); | |||
| } | |||
| if (valueCase_ == ValueOneofCase.F) { | |||
| output.WriteRawTag(37); | |||
| output.WriteFloat(F); | |||
| } | |||
| if (valueCase_ == ValueOneofCase.B) { | |||
| output.WriteRawTag(40); | |||
| output.WriteBool(B); | |||
| } | |||
| if (valueCase_ == ValueOneofCase.Type) { | |||
| output.WriteRawTag(48); | |||
| output.WriteEnum((int) Type); | |||
| } | |||
| if (valueCase_ == ValueOneofCase.Shape) { | |||
| output.WriteRawTag(58); | |||
| output.WriteMessage(Shape); | |||
| } | |||
| if (valueCase_ == ValueOneofCase.Tensor) { | |||
| output.WriteRawTag(66); | |||
| output.WriteMessage(Tensor); | |||
| } | |||
| if (valueCase_ == ValueOneofCase.Placeholder) { | |||
| output.WriteRawTag(74); | |||
| output.WriteString(Placeholder); | |||
| } | |||
| if (valueCase_ == ValueOneofCase.Func) { | |||
| output.WriteRawTag(82); | |||
| output.WriteMessage(Func); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (valueCase_ == ValueOneofCase.S) { | |||
| @@ -444,6 +526,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(AttrValue other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -497,7 +580,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -567,32 +654,118 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| global::Tensorflow.AttrValue.Types.ListValue subBuilder = new global::Tensorflow.AttrValue.Types.ListValue(); | |||
| if (valueCase_ == ValueOneofCase.List) { | |||
| subBuilder.MergeFrom(List); | |||
| } | |||
| input.ReadMessage(subBuilder); | |||
| List = subBuilder; | |||
| break; | |||
| } | |||
| case 18: { | |||
| S = input.ReadBytes(); | |||
| break; | |||
| } | |||
| case 24: { | |||
| I = input.ReadInt64(); | |||
| break; | |||
| } | |||
| case 37: { | |||
| F = input.ReadFloat(); | |||
| break; | |||
| } | |||
| case 40: { | |||
| B = input.ReadBool(); | |||
| break; | |||
| } | |||
| case 48: { | |||
| value_ = input.ReadEnum(); | |||
| valueCase_ = ValueOneofCase.Type; | |||
| break; | |||
| } | |||
| case 58: { | |||
| global::Tensorflow.TensorShapeProto subBuilder = new global::Tensorflow.TensorShapeProto(); | |||
| if (valueCase_ == ValueOneofCase.Shape) { | |||
| subBuilder.MergeFrom(Shape); | |||
| } | |||
| input.ReadMessage(subBuilder); | |||
| Shape = subBuilder; | |||
| break; | |||
| } | |||
| case 66: { | |||
| global::Tensorflow.TensorProto subBuilder = new global::Tensorflow.TensorProto(); | |||
| if (valueCase_ == ValueOneofCase.Tensor) { | |||
| subBuilder.MergeFrom(Tensor); | |||
| } | |||
| input.ReadMessage(subBuilder); | |||
| Tensor = subBuilder; | |||
| break; | |||
| } | |||
| case 74: { | |||
| Placeholder = input.ReadString(); | |||
| break; | |||
| } | |||
| case 82: { | |||
| global::Tensorflow.NameAttrList subBuilder = new global::Tensorflow.NameAttrList(); | |||
| if (valueCase_ == ValueOneofCase.Func) { | |||
| subBuilder.MergeFrom(Func); | |||
| } | |||
| input.ReadMessage(subBuilder); | |||
| Func = subBuilder; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| #region Nested types | |||
| /// <summary>Container for nested types declared in the AttrValue message type.</summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static partial class Types { | |||
| /// <summary> | |||
| /// LINT.IfChange | |||
| /// </summary> | |||
| public sealed partial class ListValue : pb::IMessage<ListValue> { | |||
| public sealed partial class ListValue : pb::IMessage<ListValue> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<ListValue> _parser = new pb::MessageParser<ListValue>(() => new ListValue()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<ListValue> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.AttrValue.Descriptor.NestedTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public ListValue() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -600,6 +773,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public ListValue(ListValue other) : this() { | |||
| s_ = other.s_.Clone(); | |||
| i_ = other.i_.Clone(); | |||
| @@ -613,6 +787,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public ListValue Clone() { | |||
| return new ListValue(this); | |||
| } | |||
| @@ -626,6 +801,7 @@ namespace Tensorflow { | |||
| /// "list(string)" | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<pb::ByteString> S { | |||
| get { return s_; } | |||
| } | |||
| @@ -639,6 +815,7 @@ namespace Tensorflow { | |||
| /// "list(int)" | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<long> I { | |||
| get { return i_; } | |||
| } | |||
| @@ -652,6 +829,7 @@ namespace Tensorflow { | |||
| /// "list(float)" | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<float> F { | |||
| get { return f_; } | |||
| } | |||
| @@ -665,6 +843,7 @@ namespace Tensorflow { | |||
| /// "list(bool)" | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<bool> B { | |||
| get { return b_; } | |||
| } | |||
| @@ -678,6 +857,7 @@ namespace Tensorflow { | |||
| /// "list(type)" | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<global::Tensorflow.DataType> Type { | |||
| get { return type_; } | |||
| } | |||
| @@ -691,6 +871,7 @@ namespace Tensorflow { | |||
| /// "list(shape)" | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<global::Tensorflow.TensorShapeProto> Shape { | |||
| get { return shape_; } | |||
| } | |||
| @@ -704,6 +885,7 @@ namespace Tensorflow { | |||
| /// "list(tensor)" | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<global::Tensorflow.TensorProto> Tensor { | |||
| get { return tensor_; } | |||
| } | |||
| @@ -717,16 +899,19 @@ namespace Tensorflow { | |||
| /// "list(attr)" | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<global::Tensorflow.NameAttrList> Func { | |||
| get { return func_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as ListValue); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(ListValue other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -746,6 +931,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| hash ^= s_.GetHashCode(); | |||
| @@ -763,12 +949,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| s_.WriteTo(output, _repeated_s_codec); | |||
| i_.WriteTo(output, _repeated_i_codec); | |||
| f_.WriteTo(output, _repeated_f_codec); | |||
| @@ -780,9 +971,29 @@ namespace Tensorflow { | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| s_.WriteTo(ref output, _repeated_s_codec); | |||
| i_.WriteTo(ref output, _repeated_i_codec); | |||
| f_.WriteTo(ref output, _repeated_f_codec); | |||
| b_.WriteTo(ref output, _repeated_b_codec); | |||
| type_.WriteTo(ref output, _repeated_type_codec); | |||
| shape_.WriteTo(ref output, _repeated_shape_codec); | |||
| tensor_.WriteTo(ref output, _repeated_tensor_codec); | |||
| func_.WriteTo(ref output, _repeated_func_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| size += s_.CalculateSize(_repeated_s_codec); | |||
| @@ -800,6 +1011,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(ListValue other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -816,7 +1028,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -861,7 +1077,59 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 18: { | |||
| s_.AddEntriesFrom(ref input, _repeated_s_codec); | |||
| break; | |||
| } | |||
| case 26: | |||
| case 24: { | |||
| i_.AddEntriesFrom(ref input, _repeated_i_codec); | |||
| break; | |||
| } | |||
| case 34: | |||
| case 37: { | |||
| f_.AddEntriesFrom(ref input, _repeated_f_codec); | |||
| break; | |||
| } | |||
| case 42: | |||
| case 40: { | |||
| b_.AddEntriesFrom(ref input, _repeated_b_codec); | |||
| break; | |||
| } | |||
| case 50: | |||
| case 48: { | |||
| type_.AddEntriesFrom(ref input, _repeated_type_codec); | |||
| break; | |||
| } | |||
| case 58: { | |||
| shape_.AddEntriesFrom(ref input, _repeated_shape_codec); | |||
| break; | |||
| } | |||
| case 66: { | |||
| tensor_.AddEntriesFrom(ref input, _repeated_tensor_codec); | |||
| break; | |||
| } | |||
| case 74: { | |||
| func_.AddEntriesFrom(ref input, _repeated_func_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| @@ -874,23 +1142,31 @@ namespace Tensorflow { | |||
| /// A list of attr names and their values. The whole list is attached | |||
| /// with a string name. E.g., MatMul[T=float]. | |||
| /// </summary> | |||
| public sealed partial class NameAttrList : pb::IMessage<NameAttrList> { | |||
| public sealed partial class NameAttrList : pb::IMessage<NameAttrList> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<NameAttrList> _parser = new pb::MessageParser<NameAttrList>(() => new NameAttrList()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<NameAttrList> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.AttrValueReflection.Descriptor.MessageTypes[1]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public NameAttrList() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -898,6 +1174,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public NameAttrList(NameAttrList other) : this() { | |||
| name_ = other.name_; | |||
| attr_ = other.attr_.Clone(); | |||
| @@ -905,6 +1182,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public NameAttrList Clone() { | |||
| return new NameAttrList(this); | |||
| } | |||
| @@ -913,6 +1191,7 @@ namespace Tensorflow { | |||
| public const int NameFieldNumber = 1; | |||
| private string name_ = ""; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string Name { | |||
| get { return name_; } | |||
| set { | |||
| @@ -926,16 +1205,19 @@ namespace Tensorflow { | |||
| = new pbc::MapField<string, global::Tensorflow.AttrValue>.Codec(pb::FieldCodec.ForString(10, ""), pb::FieldCodec.ForMessage(18, global::Tensorflow.AttrValue.Parser), 18); | |||
| private readonly pbc::MapField<string, global::Tensorflow.AttrValue> attr_ = new pbc::MapField<string, global::Tensorflow.AttrValue>(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::MapField<string, global::Tensorflow.AttrValue> Attr { | |||
| get { return attr_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as NameAttrList); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(NameAttrList other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -949,6 +1231,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (Name.Length != 0) hash ^= Name.GetHashCode(); | |||
| @@ -960,12 +1243,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (Name.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(Name); | |||
| @@ -974,9 +1262,26 @@ namespace Tensorflow { | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (Name.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(Name); | |||
| } | |||
| attr_.WriteTo(ref output, _map_attr_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (Name.Length != 0) { | |||
| @@ -990,6 +1295,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(NameAttrList other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -1002,7 +1308,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -1019,7 +1329,31 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| Name = input.ReadString(); | |||
| break; | |||
| } | |||
| case 18: { | |||
| attr_.AddEntriesFrom(ref input, _map_attr_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| @@ -2,7 +2,7 @@ | |||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| // source: tensorflow/python/training/checkpoint_state.proto | |||
| // </auto-generated> | |||
| #pragma warning disable 1591, 0612, 3021 | |||
| #pragma warning disable 1591, 0612, 3021, 8981 | |||
| #region Designer generated code | |||
| using pb = global::Google.Protobuf; | |||
| @@ -43,23 +43,31 @@ namespace Tensorflow { | |||
| /// <summary> | |||
| /// Protocol buffer representing the checkpoint state. | |||
| /// </summary> | |||
| public sealed partial class CheckpointState : pb::IMessage<CheckpointState> { | |||
| public sealed partial class CheckpointState : pb::IMessage<CheckpointState> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<CheckpointState> _parser = new pb::MessageParser<CheckpointState>(() => new CheckpointState()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<CheckpointState> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.CheckpointStateReflection.Descriptor.MessageTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CheckpointState() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -67,6 +75,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CheckpointState(CheckpointState other) : this() { | |||
| modelCheckpointPath_ = other.modelCheckpointPath_; | |||
| allModelCheckpointPaths_ = other.allModelCheckpointPaths_.Clone(); | |||
| @@ -76,6 +85,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CheckpointState Clone() { | |||
| return new CheckpointState(this); | |||
| } | |||
| @@ -87,6 +97,7 @@ namespace Tensorflow { | |||
| /// Path to the most-recent model checkpoint. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string ModelCheckpointPath { | |||
| get { return modelCheckpointPath_; } | |||
| set { | |||
| @@ -106,6 +117,7 @@ namespace Tensorflow { | |||
| /// this list. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<string> AllModelCheckpointPaths { | |||
| get { return allModelCheckpointPaths_; } | |||
| } | |||
| @@ -120,6 +132,7 @@ namespace Tensorflow { | |||
| /// when each checkpoint was created. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<double> AllModelCheckpointTimestamps { | |||
| get { return allModelCheckpointTimestamps_; } | |||
| } | |||
| @@ -132,6 +145,7 @@ namespace Tensorflow { | |||
| /// checkpoint. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public double LastPreservedTimestamp { | |||
| get { return lastPreservedTimestamp_; } | |||
| set { | |||
| @@ -140,11 +154,13 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as CheckpointState); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(CheckpointState other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -160,6 +176,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (ModelCheckpointPath.Length != 0) hash ^= ModelCheckpointPath.GetHashCode(); | |||
| @@ -173,12 +190,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (ModelCheckpointPath.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(ModelCheckpointPath); | |||
| @@ -192,9 +214,31 @@ namespace Tensorflow { | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (ModelCheckpointPath.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(ModelCheckpointPath); | |||
| } | |||
| allModelCheckpointPaths_.WriteTo(ref output, _repeated_allModelCheckpointPaths_codec); | |||
| allModelCheckpointTimestamps_.WriteTo(ref output, _repeated_allModelCheckpointTimestamps_codec); | |||
| if (LastPreservedTimestamp != 0D) { | |||
| output.WriteRawTag(33); | |||
| output.WriteDouble(LastPreservedTimestamp); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (ModelCheckpointPath.Length != 0) { | |||
| @@ -212,6 +256,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(CheckpointState other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -228,7 +273,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -254,7 +303,40 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| ModelCheckpointPath = input.ReadString(); | |||
| break; | |||
| } | |||
| case 18: { | |||
| allModelCheckpointPaths_.AddEntriesFrom(ref input, _repeated_allModelCheckpointPaths_codec); | |||
| break; | |||
| } | |||
| case 26: | |||
| case 25: { | |||
| allModelCheckpointTimestamps_.AddEntriesFrom(ref input, _repeated_allModelCheckpointTimestamps_codec); | |||
| break; | |||
| } | |||
| case 33: { | |||
| LastPreservedTimestamp = input.ReadDouble(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| @@ -2,7 +2,7 @@ | |||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| // source: tensorflow/core/protobuf/cluster.proto | |||
| // </auto-generated> | |||
| #pragma warning disable 1591, 0612, 3021 | |||
| #pragma warning disable 1591, 0612, 3021, 8981 | |||
| #region Designer generated code | |||
| using pb = global::Google.Protobuf; | |||
| @@ -47,23 +47,31 @@ namespace Tensorflow { | |||
| /// <summary> | |||
| /// Defines a single job in a TensorFlow cluster. | |||
| /// </summary> | |||
| public sealed partial class JobDef : pb::IMessage<JobDef> { | |||
| public sealed partial class JobDef : pb::IMessage<JobDef> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<JobDef> _parser = new pb::MessageParser<JobDef>(() => new JobDef()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<JobDef> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.ClusterReflection.Descriptor.MessageTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public JobDef() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -71,6 +79,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public JobDef(JobDef other) : this() { | |||
| name_ = other.name_; | |||
| tasks_ = other.tasks_.Clone(); | |||
| @@ -78,6 +87,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public JobDef Clone() { | |||
| return new JobDef(this); | |||
| } | |||
| @@ -89,6 +99,7 @@ namespace Tensorflow { | |||
| /// The name of this job. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string Name { | |||
| get { return name_; } | |||
| set { | |||
| @@ -109,16 +120,19 @@ namespace Tensorflow { | |||
| /// "/job:worker/task:7" will be assigned to "example.org:2222". | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::MapField<int, string> Tasks { | |||
| get { return tasks_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as JobDef); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(JobDef other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -132,6 +146,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (Name.Length != 0) hash ^= Name.GetHashCode(); | |||
| @@ -143,12 +158,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (Name.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(Name); | |||
| @@ -157,9 +177,26 @@ namespace Tensorflow { | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (Name.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(Name); | |||
| } | |||
| tasks_.WriteTo(ref output, _map_tasks_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (Name.Length != 0) { | |||
| @@ -173,6 +210,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(JobDef other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -185,7 +223,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -202,30 +244,62 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| Name = input.ReadString(); | |||
| break; | |||
| } | |||
| case 18: { | |||
| tasks_.AddEntriesFrom(ref input, _map_tasks_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| /// <summary> | |||
| /// Defines a TensorFlow cluster as a set of jobs. | |||
| /// </summary> | |||
| public sealed partial class ClusterDef : pb::IMessage<ClusterDef> { | |||
| public sealed partial class ClusterDef : pb::IMessage<ClusterDef> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<ClusterDef> _parser = new pb::MessageParser<ClusterDef>(() => new ClusterDef()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<ClusterDef> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.ClusterReflection.Descriptor.MessageTypes[1]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public ClusterDef() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -233,12 +307,14 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public ClusterDef(ClusterDef other) : this() { | |||
| job_ = other.job_.Clone(); | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public ClusterDef Clone() { | |||
| return new ClusterDef(this); | |||
| } | |||
| @@ -252,16 +328,19 @@ namespace Tensorflow { | |||
| /// The jobs that comprise the cluster. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<global::Tensorflow.JobDef> Job { | |||
| get { return job_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as ClusterDef); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(ClusterDef other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -274,6 +353,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| hash ^= job_.GetHashCode(); | |||
| @@ -284,19 +364,37 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| job_.WriteTo(output, _repeated_job_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| job_.WriteTo(ref output, _repeated_job_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| size += job_.CalculateSize(_repeated_job_codec); | |||
| @@ -307,6 +405,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(ClusterDef other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -316,7 +415,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -329,7 +432,27 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| job_.AddEntriesFrom(ref input, _repeated_job_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| @@ -2,7 +2,7 @@ | |||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| // source: tensorflow/core/protobuf/control_flow.proto | |||
| // </auto-generated> | |||
| #pragma warning disable 1591, 0612, 3021 | |||
| #pragma warning disable 1591, 0612, 3021, 8981 | |||
| #region Designer generated code | |||
| using pb = global::Google.Protobuf; | |||
| @@ -64,23 +64,31 @@ namespace Tensorflow { | |||
| /// <summary> | |||
| /// Protocol buffer representing the values in ControlFlowContext. | |||
| /// </summary> | |||
| public sealed partial class ValuesDef : pb::IMessage<ValuesDef> { | |||
| public sealed partial class ValuesDef : pb::IMessage<ValuesDef> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<ValuesDef> _parser = new pb::MessageParser<ValuesDef>(() => new ValuesDef()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<ValuesDef> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public ValuesDef() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -88,6 +96,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public ValuesDef(ValuesDef other) : this() { | |||
| values_ = other.values_.Clone(); | |||
| externalValues_ = other.externalValues_.Clone(); | |||
| @@ -95,6 +104,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public ValuesDef Clone() { | |||
| return new ValuesDef(this); | |||
| } | |||
| @@ -108,6 +118,7 @@ namespace Tensorflow { | |||
| /// Value names that have been seen in this context. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<string> Values { | |||
| get { return values_; } | |||
| } | |||
| @@ -121,16 +132,19 @@ namespace Tensorflow { | |||
| /// Value names referenced by but external to this context. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::MapField<string, string> ExternalValues { | |||
| get { return externalValues_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as ValuesDef); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(ValuesDef other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -144,6 +158,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| hash ^= values_.GetHashCode(); | |||
| @@ -155,20 +170,39 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| values_.WriteTo(output, _repeated_values_codec); | |||
| externalValues_.WriteTo(output, _map_externalValues_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| values_.WriteTo(ref output, _repeated_values_codec); | |||
| externalValues_.WriteTo(ref output, _map_externalValues_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| size += values_.CalculateSize(_repeated_values_codec); | |||
| @@ -180,6 +214,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(ValuesDef other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -190,7 +225,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -207,7 +246,31 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| values_.AddEntriesFrom(ref input, _repeated_values_codec); | |||
| break; | |||
| } | |||
| case 18: { | |||
| externalValues_.AddEntriesFrom(ref input, _map_externalValues_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| @@ -215,23 +278,31 @@ namespace Tensorflow { | |||
| /// Container for any kind of control flow context. Any other control flow | |||
| /// contexts that are added below should also be added here. | |||
| /// </summary> | |||
| public sealed partial class ControlFlowContextDef : pb::IMessage<ControlFlowContextDef> { | |||
| public sealed partial class ControlFlowContextDef : pb::IMessage<ControlFlowContextDef> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<ControlFlowContextDef> _parser = new pb::MessageParser<ControlFlowContextDef>(() => new ControlFlowContextDef()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<ControlFlowContextDef> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[1]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public ControlFlowContextDef() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -239,6 +310,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public ControlFlowContextDef(ControlFlowContextDef other) : this() { | |||
| switch (other.CtxtCase) { | |||
| case CtxtOneofCase.CondCtxt: | |||
| @@ -253,6 +325,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public ControlFlowContextDef Clone() { | |||
| return new ControlFlowContextDef(this); | |||
| } | |||
| @@ -260,6 +333,7 @@ namespace Tensorflow { | |||
| /// <summary>Field number for the "cond_ctxt" field.</summary> | |||
| public const int CondCtxtFieldNumber = 1; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Tensorflow.CondContextDef CondCtxt { | |||
| get { return ctxtCase_ == CtxtOneofCase.CondCtxt ? (global::Tensorflow.CondContextDef) ctxt_ : null; } | |||
| set { | |||
| @@ -271,6 +345,7 @@ namespace Tensorflow { | |||
| /// <summary>Field number for the "while_ctxt" field.</summary> | |||
| public const int WhileCtxtFieldNumber = 2; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Tensorflow.WhileContextDef WhileCtxt { | |||
| get { return ctxtCase_ == CtxtOneofCase.WhileCtxt ? (global::Tensorflow.WhileContextDef) ctxt_ : null; } | |||
| set { | |||
| @@ -288,22 +363,26 @@ namespace Tensorflow { | |||
| } | |||
| private CtxtOneofCase ctxtCase_ = CtxtOneofCase.None; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CtxtOneofCase CtxtCase { | |||
| get { return ctxtCase_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void ClearCtxt() { | |||
| ctxtCase_ = CtxtOneofCase.None; | |||
| ctxt_ = null; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as ControlFlowContextDef); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(ControlFlowContextDef other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -318,6 +397,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (ctxtCase_ == CtxtOneofCase.CondCtxt) hash ^= CondCtxt.GetHashCode(); | |||
| @@ -330,12 +410,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (ctxtCase_ == CtxtOneofCase.CondCtxt) { | |||
| output.WriteRawTag(10); | |||
| output.WriteMessage(CondCtxt); | |||
| @@ -347,9 +432,29 @@ namespace Tensorflow { | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (ctxtCase_ == CtxtOneofCase.CondCtxt) { | |||
| output.WriteRawTag(10); | |||
| output.WriteMessage(CondCtxt); | |||
| } | |||
| if (ctxtCase_ == CtxtOneofCase.WhileCtxt) { | |||
| output.WriteRawTag(18); | |||
| output.WriteMessage(WhileCtxt); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (ctxtCase_ == CtxtOneofCase.CondCtxt) { | |||
| @@ -365,6 +470,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(ControlFlowContextDef other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -388,7 +494,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -415,30 +525,72 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| global::Tensorflow.CondContextDef subBuilder = new global::Tensorflow.CondContextDef(); | |||
| if (ctxtCase_ == CtxtOneofCase.CondCtxt) { | |||
| subBuilder.MergeFrom(CondCtxt); | |||
| } | |||
| input.ReadMessage(subBuilder); | |||
| CondCtxt = subBuilder; | |||
| break; | |||
| } | |||
| case 18: { | |||
| global::Tensorflow.WhileContextDef subBuilder = new global::Tensorflow.WhileContextDef(); | |||
| if (ctxtCase_ == CtxtOneofCase.WhileCtxt) { | |||
| subBuilder.MergeFrom(WhileCtxt); | |||
| } | |||
| input.ReadMessage(subBuilder); | |||
| WhileCtxt = subBuilder; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| /// <summary> | |||
| /// Protocol buffer representing a CondContext object. | |||
| /// </summary> | |||
| public sealed partial class CondContextDef : pb::IMessage<CondContextDef> { | |||
| public sealed partial class CondContextDef : pb::IMessage<CondContextDef> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<CondContextDef> _parser = new pb::MessageParser<CondContextDef>(() => new CondContextDef()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<CondContextDef> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[2]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CondContextDef() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -446,6 +598,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CondContextDef(CondContextDef other) : this() { | |||
| contextName_ = other.contextName_; | |||
| predName_ = other.predName_; | |||
| @@ -457,6 +610,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CondContextDef Clone() { | |||
| return new CondContextDef(this); | |||
| } | |||
| @@ -468,6 +622,7 @@ namespace Tensorflow { | |||
| /// Name of the context. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string ContextName { | |||
| get { return contextName_; } | |||
| set { | |||
| @@ -482,6 +637,7 @@ namespace Tensorflow { | |||
| /// Name of the pred tensor. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string PredName { | |||
| get { return predName_; } | |||
| set { | |||
| @@ -496,6 +652,7 @@ namespace Tensorflow { | |||
| /// Name of the pivot tensor. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string PivotName { | |||
| get { return pivotName_; } | |||
| set { | |||
| @@ -510,6 +667,7 @@ namespace Tensorflow { | |||
| /// Branch prediction. 0 or 1. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int Branch { | |||
| get { return branch_; } | |||
| set { | |||
| @@ -524,6 +682,7 @@ namespace Tensorflow { | |||
| /// Values and external values in control flow context. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Tensorflow.ValuesDef ValuesDef { | |||
| get { return valuesDef_; } | |||
| set { | |||
| @@ -540,16 +699,19 @@ namespace Tensorflow { | |||
| /// Contexts contained inside this context (e.g. nested conds). | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<global::Tensorflow.ControlFlowContextDef> NestedContexts { | |||
| get { return nestedContexts_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as CondContextDef); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(CondContextDef other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -567,6 +729,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (ContextName.Length != 0) hash ^= ContextName.GetHashCode(); | |||
| @@ -582,12 +745,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (ContextName.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(ContextName); | |||
| @@ -612,9 +780,42 @@ namespace Tensorflow { | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (ContextName.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(ContextName); | |||
| } | |||
| if (PredName.Length != 0) { | |||
| output.WriteRawTag(18); | |||
| output.WriteString(PredName); | |||
| } | |||
| if (PivotName.Length != 0) { | |||
| output.WriteRawTag(26); | |||
| output.WriteString(PivotName); | |||
| } | |||
| if (Branch != 0) { | |||
| output.WriteRawTag(32); | |||
| output.WriteInt32(Branch); | |||
| } | |||
| if (valuesDef_ != null) { | |||
| output.WriteRawTag(42); | |||
| output.WriteMessage(ValuesDef); | |||
| } | |||
| nestedContexts_.WriteTo(ref output, _repeated_nestedContexts_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (ContextName.Length != 0) { | |||
| @@ -640,6 +841,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(CondContextDef other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -667,7 +869,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -703,30 +909,81 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| ContextName = input.ReadString(); | |||
| break; | |||
| } | |||
| case 18: { | |||
| PredName = input.ReadString(); | |||
| break; | |||
| } | |||
| case 26: { | |||
| PivotName = input.ReadString(); | |||
| break; | |||
| } | |||
| case 32: { | |||
| Branch = input.ReadInt32(); | |||
| break; | |||
| } | |||
| case 42: { | |||
| if (valuesDef_ == null) { | |||
| ValuesDef = new global::Tensorflow.ValuesDef(); | |||
| } | |||
| input.ReadMessage(ValuesDef); | |||
| break; | |||
| } | |||
| case 50: { | |||
| nestedContexts_.AddEntriesFrom(ref input, _repeated_nestedContexts_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| /// <summary> | |||
| /// Protocol buffer representing a WhileContext object. | |||
| /// </summary> | |||
| public sealed partial class WhileContextDef : pb::IMessage<WhileContextDef> { | |||
| public sealed partial class WhileContextDef : pb::IMessage<WhileContextDef> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<WhileContextDef> _parser = new pb::MessageParser<WhileContextDef>(() => new WhileContextDef()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<WhileContextDef> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[3]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public WhileContextDef() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -734,6 +991,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public WhileContextDef(WhileContextDef other) : this() { | |||
| contextName_ = other.contextName_; | |||
| parallelIterations_ = other.parallelIterations_; | |||
| @@ -751,6 +1009,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public WhileContextDef Clone() { | |||
| return new WhileContextDef(this); | |||
| } | |||
| @@ -762,6 +1021,7 @@ namespace Tensorflow { | |||
| /// Name of the context. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string ContextName { | |||
| get { return contextName_; } | |||
| set { | |||
| @@ -776,6 +1036,7 @@ namespace Tensorflow { | |||
| /// The number of iterations allowed to run in parallel. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int ParallelIterations { | |||
| get { return parallelIterations_; } | |||
| set { | |||
| @@ -790,6 +1051,7 @@ namespace Tensorflow { | |||
| /// Whether backprop is enabled for this while loop. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool BackProp { | |||
| get { return backProp_; } | |||
| set { | |||
| @@ -804,6 +1066,7 @@ namespace Tensorflow { | |||
| /// Whether GPU-CPU memory swap is enabled for this loop. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool SwapMemory { | |||
| get { return swapMemory_; } | |||
| set { | |||
| @@ -818,6 +1081,7 @@ namespace Tensorflow { | |||
| /// Name of the pivot tensor. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string PivotName { | |||
| get { return pivotName_; } | |||
| set { | |||
| @@ -832,6 +1096,7 @@ namespace Tensorflow { | |||
| /// Name of the pivot_for_pred tensor. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string PivotForPredName { | |||
| get { return pivotForPredName_; } | |||
| set { | |||
| @@ -846,6 +1111,7 @@ namespace Tensorflow { | |||
| /// Name of the pivot_for_body tensor. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string PivotForBodyName { | |||
| get { return pivotForBodyName_; } | |||
| set { | |||
| @@ -862,6 +1128,7 @@ namespace Tensorflow { | |||
| /// List of names for exit tensors. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<string> LoopExitNames { | |||
| get { return loopExitNames_; } | |||
| } | |||
| @@ -875,6 +1142,7 @@ namespace Tensorflow { | |||
| /// List of names for enter tensors. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<string> LoopEnterNames { | |||
| get { return loopEnterNames_; } | |||
| } | |||
| @@ -886,6 +1154,7 @@ namespace Tensorflow { | |||
| /// Values and external values in control flow context. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Tensorflow.ValuesDef ValuesDef { | |||
| get { return valuesDef_; } | |||
| set { | |||
| @@ -900,6 +1169,7 @@ namespace Tensorflow { | |||
| /// Optional name of the maximum_iterations tensor. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string MaximumIterationsName { | |||
| get { return maximumIterationsName_; } | |||
| set { | |||
| @@ -916,16 +1186,19 @@ namespace Tensorflow { | |||
| /// Contexts contained inside this context (e.g. nested whiles). | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<global::Tensorflow.ControlFlowContextDef> NestedContexts { | |||
| get { return nestedContexts_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as WhileContextDef); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(WhileContextDef other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -949,6 +1222,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (ContextName.Length != 0) hash ^= ContextName.GetHashCode(); | |||
| @@ -970,12 +1244,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (ContextName.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(ContextName); | |||
| @@ -1018,9 +1297,60 @@ namespace Tensorflow { | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (ContextName.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(ContextName); | |||
| } | |||
| if (ParallelIterations != 0) { | |||
| output.WriteRawTag(16); | |||
| output.WriteInt32(ParallelIterations); | |||
| } | |||
| if (BackProp != false) { | |||
| output.WriteRawTag(24); | |||
| output.WriteBool(BackProp); | |||
| } | |||
| if (SwapMemory != false) { | |||
| output.WriteRawTag(32); | |||
| output.WriteBool(SwapMemory); | |||
| } | |||
| if (PivotName.Length != 0) { | |||
| output.WriteRawTag(42); | |||
| output.WriteString(PivotName); | |||
| } | |||
| if (PivotForPredName.Length != 0) { | |||
| output.WriteRawTag(50); | |||
| output.WriteString(PivotForPredName); | |||
| } | |||
| if (PivotForBodyName.Length != 0) { | |||
| output.WriteRawTag(58); | |||
| output.WriteString(PivotForBodyName); | |||
| } | |||
| loopExitNames_.WriteTo(ref output, _repeated_loopExitNames_codec); | |||
| if (valuesDef_ != null) { | |||
| output.WriteRawTag(74); | |||
| output.WriteMessage(ValuesDef); | |||
| } | |||
| loopEnterNames_.WriteTo(ref output, _repeated_loopEnterNames_codec); | |||
| if (MaximumIterationsName.Length != 0) { | |||
| output.WriteRawTag(90); | |||
| output.WriteString(MaximumIterationsName); | |||
| } | |||
| nestedContexts_.WriteTo(ref output, _repeated_nestedContexts_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (ContextName.Length != 0) { | |||
| @@ -1060,6 +1390,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(WhileContextDef other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -1101,7 +1432,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -1161,7 +1496,74 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| ContextName = input.ReadString(); | |||
| break; | |||
| } | |||
| case 16: { | |||
| ParallelIterations = input.ReadInt32(); | |||
| break; | |||
| } | |||
| case 24: { | |||
| BackProp = input.ReadBool(); | |||
| break; | |||
| } | |||
| case 32: { | |||
| SwapMemory = input.ReadBool(); | |||
| break; | |||
| } | |||
| case 42: { | |||
| PivotName = input.ReadString(); | |||
| break; | |||
| } | |||
| case 50: { | |||
| PivotForPredName = input.ReadString(); | |||
| break; | |||
| } | |||
| case 58: { | |||
| PivotForBodyName = input.ReadString(); | |||
| break; | |||
| } | |||
| case 66: { | |||
| loopExitNames_.AddEntriesFrom(ref input, _repeated_loopExitNames_codec); | |||
| break; | |||
| } | |||
| case 74: { | |||
| if (valuesDef_ == null) { | |||
| ValuesDef = new global::Tensorflow.ValuesDef(); | |||
| } | |||
| input.ReadMessage(ValuesDef); | |||
| break; | |||
| } | |||
| case 82: { | |||
| loopEnterNames_.AddEntriesFrom(ref input, _repeated_loopEnterNames_codec); | |||
| break; | |||
| } | |||
| case 90: { | |||
| MaximumIterationsName = input.ReadString(); | |||
| break; | |||
| } | |||
| case 98: { | |||
| nestedContexts_.AddEntriesFrom(ref input, _repeated_nestedContexts_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| @@ -0,0 +1,791 @@ | |||
| // <auto-generated> | |||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| // source: tensorflow/core/protobuf/coordination_config.proto | |||
| // </auto-generated> | |||
| #pragma warning disable 1591, 0612, 3021, 8981 | |||
| #region Designer generated code | |||
| using pb = global::Google.Protobuf; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| using pbr = global::Google.Protobuf.Reflection; | |||
| using scg = global::System.Collections.Generic; | |||
| namespace Tensorflow { | |||
| /// <summary>Holder for reflection information generated from tensorflow/core/protobuf/coordination_config.proto</summary> | |||
| public static partial class CoordinationConfigReflection { | |||
| #region Descriptor | |||
| /// <summary>File descriptor for tensorflow/core/protobuf/coordination_config.proto</summary> | |||
| public static pbr::FileDescriptor Descriptor { | |||
| get { return descriptor; } | |||
| } | |||
| private static pbr::FileDescriptor descriptor; | |||
| static CoordinationConfigReflection() { | |||
| byte[] descriptorData = global::System.Convert.FromBase64String( | |||
| string.Concat( | |||
| "CjJ0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvY29vcmRpbmF0aW9uX2NvbmZp", | |||
| "Zy5wcm90bxIKdGVuc29yZmxvdyIxCg5Db29yZGluYXRlZEpvYhIMCgRuYW1l", | |||
| "GAEgASgJEhEKCW51bV90YXNrcxgCIAEoBSLdAgoZQ29vcmRpbmF0aW9uU2Vy", | |||
| "dmljZUNvbmZpZxIUCgxzZXJ2aWNlX3R5cGUYASABKAkSFgoOc2VydmljZV9s", | |||
| "ZWFkZXIYAiABKAkSGwoTZW5hYmxlX2hlYWx0aF9jaGVjaxgDIAEoCBImCh5j", | |||
| "bHVzdGVyX3JlZ2lzdGVyX3RpbWVvdXRfaW5fbXMYBCABKAMSHwoXaGVhcnRi", | |||
| "ZWF0X3RpbWVvdXRfaW5fbXMYBSABKAMSOAoUY29vcmRpbmF0ZWRfam9iX2xp", | |||
| "c3QYCiADKAsyGi50ZW5zb3JmbG93LkNvb3JkaW5hdGVkSm9iEiYKHnNodXRk", | |||
| "b3duX2JhcnJpZXJfdGltZW91dF9pbl9tcxgHIAEoAxIqCiJhZ2VudF9kZXN0", | |||
| "cnVjdGlvbl93aXRob3V0X3NodXRkb3duGAggASgIEhgKEHJlY292ZXJhYmxl", | |||
| "X2pvYnMYCSADKAlKBAgGEAdCV1pVZ2l0aHViLmNvbS90ZW5zb3JmbG93L3Rl", | |||
| "bnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL3Byb3RvYnVmL2Zvcl9jb3Jl", | |||
| "X3Byb3Rvc19nb19wcm90b2IGcHJvdG8z")); | |||
| descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
| new pbr::FileDescriptor[] { }, | |||
| new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CoordinatedJob), global::Tensorflow.CoordinatedJob.Parser, new[]{ "Name", "NumTasks" }, null, null, null, null), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CoordinationServiceConfig), global::Tensorflow.CoordinationServiceConfig.Parser, new[]{ "ServiceType", "ServiceLeader", "EnableHealthCheck", "ClusterRegisterTimeoutInMs", "HeartbeatTimeoutInMs", "CoordinatedJobList", "ShutdownBarrierTimeoutInMs", "AgentDestructionWithoutShutdown", "RecoverableJobs" }, null, null, null, null) | |||
| })); | |||
| } | |||
| #endregion | |||
| } | |||
| #region Messages | |||
| /// <summary> | |||
| /// Represents a job type and the number of tasks under this job. | |||
| /// For example, ("worker", 20) implies that there will be 20 worker tasks. | |||
| /// </summary> | |||
| public sealed partial class CoordinatedJob : pb::IMessage<CoordinatedJob> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<CoordinatedJob> _parser = new pb::MessageParser<CoordinatedJob>(() => new CoordinatedJob()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<CoordinatedJob> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.CoordinationConfigReflection.Descriptor.MessageTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CoordinatedJob() { | |||
| OnConstruction(); | |||
| } | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CoordinatedJob(CoordinatedJob other) : this() { | |||
| name_ = other.name_; | |||
| numTasks_ = other.numTasks_; | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CoordinatedJob Clone() { | |||
| return new CoordinatedJob(this); | |||
| } | |||
| /// <summary>Field number for the "name" field.</summary> | |||
| public const int NameFieldNumber = 1; | |||
| private string name_ = ""; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string Name { | |||
| get { return name_; } | |||
| set { | |||
| name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||
| } | |||
| } | |||
| /// <summary>Field number for the "num_tasks" field.</summary> | |||
| public const int NumTasksFieldNumber = 2; | |||
| private int numTasks_; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int NumTasks { | |||
| get { return numTasks_; } | |||
| set { | |||
| numTasks_ = value; | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as CoordinatedJob); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(CoordinatedJob other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| } | |||
| if (ReferenceEquals(other, this)) { | |||
| return true; | |||
| } | |||
| if (Name != other.Name) return false; | |||
| if (NumTasks != other.NumTasks) return false; | |||
| return Equals(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (Name.Length != 0) hash ^= Name.GetHashCode(); | |||
| if (NumTasks != 0) hash ^= NumTasks.GetHashCode(); | |||
| if (_unknownFields != null) { | |||
| hash ^= _unknownFields.GetHashCode(); | |||
| } | |||
| return hash; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (Name.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(Name); | |||
| } | |||
| if (NumTasks != 0) { | |||
| output.WriteRawTag(16); | |||
| output.WriteInt32(NumTasks); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (Name.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(Name); | |||
| } | |||
| if (NumTasks != 0) { | |||
| output.WriteRawTag(16); | |||
| output.WriteInt32(NumTasks); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (Name.Length != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); | |||
| } | |||
| if (NumTasks != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumTasks); | |||
| } | |||
| if (_unknownFields != null) { | |||
| size += _unknownFields.CalculateSize(); | |||
| } | |||
| return size; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(CoordinatedJob other) { | |||
| if (other == null) { | |||
| return; | |||
| } | |||
| if (other.Name.Length != 0) { | |||
| Name = other.Name; | |||
| } | |||
| if (other.NumTasks != 0) { | |||
| NumTasks = other.NumTasks; | |||
| } | |||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
| break; | |||
| case 10: { | |||
| Name = input.ReadString(); | |||
| break; | |||
| } | |||
| case 16: { | |||
| NumTasks = input.ReadInt32(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| Name = input.ReadString(); | |||
| break; | |||
| } | |||
| case 16: { | |||
| NumTasks = input.ReadInt32(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| /// <summary> | |||
| /// Coordination service configuration parameters. | |||
| /// The system picks appropriate values for fields that are not set. | |||
| /// </summary> | |||
| public sealed partial class CoordinationServiceConfig : pb::IMessage<CoordinationServiceConfig> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<CoordinationServiceConfig> _parser = new pb::MessageParser<CoordinationServiceConfig>(() => new CoordinationServiceConfig()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<CoordinationServiceConfig> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.CoordinationConfigReflection.Descriptor.MessageTypes[1]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CoordinationServiceConfig() { | |||
| OnConstruction(); | |||
| } | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CoordinationServiceConfig(CoordinationServiceConfig other) : this() { | |||
| serviceType_ = other.serviceType_; | |||
| serviceLeader_ = other.serviceLeader_; | |||
| enableHealthCheck_ = other.enableHealthCheck_; | |||
| clusterRegisterTimeoutInMs_ = other.clusterRegisterTimeoutInMs_; | |||
| heartbeatTimeoutInMs_ = other.heartbeatTimeoutInMs_; | |||
| coordinatedJobList_ = other.coordinatedJobList_.Clone(); | |||
| shutdownBarrierTimeoutInMs_ = other.shutdownBarrierTimeoutInMs_; | |||
| agentDestructionWithoutShutdown_ = other.agentDestructionWithoutShutdown_; | |||
| recoverableJobs_ = other.recoverableJobs_.Clone(); | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CoordinationServiceConfig Clone() { | |||
| return new CoordinationServiceConfig(this); | |||
| } | |||
| /// <summary>Field number for the "service_type" field.</summary> | |||
| public const int ServiceTypeFieldNumber = 1; | |||
| private string serviceType_ = ""; | |||
| /// <summary> | |||
| /// Type of coordination service implementation to enable. | |||
| /// For example, setting the service type as "standalone" starts a service | |||
| /// instance on the leader task to provide the coordination services such as | |||
| /// heartbeats and consistent key-value store. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string ServiceType { | |||
| get { return serviceType_; } | |||
| set { | |||
| serviceType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||
| } | |||
| } | |||
| /// <summary>Field number for the "service_leader" field.</summary> | |||
| public const int ServiceLeaderFieldNumber = 2; | |||
| private string serviceLeader_ = ""; | |||
| /// <summary> | |||
| /// Address where the coordination service instance is hosted. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string ServiceLeader { | |||
| get { return serviceLeader_; } | |||
| set { | |||
| serviceLeader_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); | |||
| } | |||
| } | |||
| /// <summary>Field number for the "enable_health_check" field.</summary> | |||
| public const int EnableHealthCheckFieldNumber = 3; | |||
| private bool enableHealthCheck_; | |||
| /// <summary> | |||
| /// Whether to enable the health check mechanism. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool EnableHealthCheck { | |||
| get { return enableHealthCheck_; } | |||
| set { | |||
| enableHealthCheck_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "cluster_register_timeout_in_ms" field.</summary> | |||
| public const int ClusterRegisterTimeoutInMsFieldNumber = 4; | |||
| private long clusterRegisterTimeoutInMs_; | |||
| /// <summary> | |||
| /// Maximum wait time for all members in the cluster to be registered. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public long ClusterRegisterTimeoutInMs { | |||
| get { return clusterRegisterTimeoutInMs_; } | |||
| set { | |||
| clusterRegisterTimeoutInMs_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "heartbeat_timeout_in_ms" field.</summary> | |||
| public const int HeartbeatTimeoutInMsFieldNumber = 5; | |||
| private long heartbeatTimeoutInMs_; | |||
| /// <summary> | |||
| /// Heartbeat timeout, if a task does not record heartbeat in this time | |||
| /// window, it will be considered disconnected. | |||
| /// Note: This is also used as a grace period to accept any heartbeats after | |||
| /// the agent has disconnected, to account for the lag time between the service | |||
| /// recording the state change and the agent stopping heartbeats. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public long HeartbeatTimeoutInMs { | |||
| get { return heartbeatTimeoutInMs_; } | |||
| set { | |||
| heartbeatTimeoutInMs_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "coordinated_job_list" field.</summary> | |||
| public const int CoordinatedJobListFieldNumber = 10; | |||
| private static readonly pb::FieldCodec<global::Tensorflow.CoordinatedJob> _repeated_coordinatedJobList_codec | |||
| = pb::FieldCodec.ForMessage(82, global::Tensorflow.CoordinatedJob.Parser); | |||
| private readonly pbc::RepeatedField<global::Tensorflow.CoordinatedJob> coordinatedJobList_ = new pbc::RepeatedField<global::Tensorflow.CoordinatedJob>(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<global::Tensorflow.CoordinatedJob> CoordinatedJobList { | |||
| get { return coordinatedJobList_; } | |||
| } | |||
| /// <summary>Field number for the "shutdown_barrier_timeout_in_ms" field.</summary> | |||
| public const int ShutdownBarrierTimeoutInMsFieldNumber = 7; | |||
| private long shutdownBarrierTimeoutInMs_; | |||
| /// <summary> | |||
| /// Denotes how long to wait for all coordination agents to reach the barriers | |||
| /// (after the first shutdown request) before disconnecting together. If | |||
| /// set to 0, no barrier is imposed upon shutdown and each worker can | |||
| /// disconnect individually. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public long ShutdownBarrierTimeoutInMs { | |||
| get { return shutdownBarrierTimeoutInMs_; } | |||
| set { | |||
| shutdownBarrierTimeoutInMs_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "agent_destruction_without_shutdown" field.</summary> | |||
| public const int AgentDestructionWithoutShutdownFieldNumber = 8; | |||
| private bool agentDestructionWithoutShutdown_; | |||
| /// <summary> | |||
| /// If set, agents do not make an explicit Shutdown() call. Service will only | |||
| /// find out about the disconnecte agent via stale heartbeats. Used for | |||
| /// testing. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool AgentDestructionWithoutShutdown { | |||
| get { return agentDestructionWithoutShutdown_; } | |||
| set { | |||
| agentDestructionWithoutShutdown_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "recoverable_jobs" field.</summary> | |||
| public const int RecoverableJobsFieldNumber = 9; | |||
| private static readonly pb::FieldCodec<string> _repeated_recoverableJobs_codec | |||
| = pb::FieldCodec.ForString(74); | |||
| private readonly pbc::RepeatedField<string> recoverableJobs_ = new pbc::RepeatedField<string>(); | |||
| /// <summary> | |||
| /// The list of jobs which are recoverable. If a task in this list fails, | |||
| /// it will not propagate error to other tasks. | |||
| /// If empty, no jobs will be recoverable and every task failure will cause | |||
| /// error propagation to other tasks. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<string> RecoverableJobs { | |||
| get { return recoverableJobs_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as CoordinationServiceConfig); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(CoordinationServiceConfig other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| } | |||
| if (ReferenceEquals(other, this)) { | |||
| return true; | |||
| } | |||
| if (ServiceType != other.ServiceType) return false; | |||
| if (ServiceLeader != other.ServiceLeader) return false; | |||
| if (EnableHealthCheck != other.EnableHealthCheck) return false; | |||
| if (ClusterRegisterTimeoutInMs != other.ClusterRegisterTimeoutInMs) return false; | |||
| if (HeartbeatTimeoutInMs != other.HeartbeatTimeoutInMs) return false; | |||
| if(!coordinatedJobList_.Equals(other.coordinatedJobList_)) return false; | |||
| if (ShutdownBarrierTimeoutInMs != other.ShutdownBarrierTimeoutInMs) return false; | |||
| if (AgentDestructionWithoutShutdown != other.AgentDestructionWithoutShutdown) return false; | |||
| if(!recoverableJobs_.Equals(other.recoverableJobs_)) return false; | |||
| return Equals(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (ServiceType.Length != 0) hash ^= ServiceType.GetHashCode(); | |||
| if (ServiceLeader.Length != 0) hash ^= ServiceLeader.GetHashCode(); | |||
| if (EnableHealthCheck != false) hash ^= EnableHealthCheck.GetHashCode(); | |||
| if (ClusterRegisterTimeoutInMs != 0L) hash ^= ClusterRegisterTimeoutInMs.GetHashCode(); | |||
| if (HeartbeatTimeoutInMs != 0L) hash ^= HeartbeatTimeoutInMs.GetHashCode(); | |||
| hash ^= coordinatedJobList_.GetHashCode(); | |||
| if (ShutdownBarrierTimeoutInMs != 0L) hash ^= ShutdownBarrierTimeoutInMs.GetHashCode(); | |||
| if (AgentDestructionWithoutShutdown != false) hash ^= AgentDestructionWithoutShutdown.GetHashCode(); | |||
| hash ^= recoverableJobs_.GetHashCode(); | |||
| if (_unknownFields != null) { | |||
| hash ^= _unknownFields.GetHashCode(); | |||
| } | |||
| return hash; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (ServiceType.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(ServiceType); | |||
| } | |||
| if (ServiceLeader.Length != 0) { | |||
| output.WriteRawTag(18); | |||
| output.WriteString(ServiceLeader); | |||
| } | |||
| if (EnableHealthCheck != false) { | |||
| output.WriteRawTag(24); | |||
| output.WriteBool(EnableHealthCheck); | |||
| } | |||
| if (ClusterRegisterTimeoutInMs != 0L) { | |||
| output.WriteRawTag(32); | |||
| output.WriteInt64(ClusterRegisterTimeoutInMs); | |||
| } | |||
| if (HeartbeatTimeoutInMs != 0L) { | |||
| output.WriteRawTag(40); | |||
| output.WriteInt64(HeartbeatTimeoutInMs); | |||
| } | |||
| if (ShutdownBarrierTimeoutInMs != 0L) { | |||
| output.WriteRawTag(56); | |||
| output.WriteInt64(ShutdownBarrierTimeoutInMs); | |||
| } | |||
| if (AgentDestructionWithoutShutdown != false) { | |||
| output.WriteRawTag(64); | |||
| output.WriteBool(AgentDestructionWithoutShutdown); | |||
| } | |||
| recoverableJobs_.WriteTo(output, _repeated_recoverableJobs_codec); | |||
| coordinatedJobList_.WriteTo(output, _repeated_coordinatedJobList_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (ServiceType.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(ServiceType); | |||
| } | |||
| if (ServiceLeader.Length != 0) { | |||
| output.WriteRawTag(18); | |||
| output.WriteString(ServiceLeader); | |||
| } | |||
| if (EnableHealthCheck != false) { | |||
| output.WriteRawTag(24); | |||
| output.WriteBool(EnableHealthCheck); | |||
| } | |||
| if (ClusterRegisterTimeoutInMs != 0L) { | |||
| output.WriteRawTag(32); | |||
| output.WriteInt64(ClusterRegisterTimeoutInMs); | |||
| } | |||
| if (HeartbeatTimeoutInMs != 0L) { | |||
| output.WriteRawTag(40); | |||
| output.WriteInt64(HeartbeatTimeoutInMs); | |||
| } | |||
| if (ShutdownBarrierTimeoutInMs != 0L) { | |||
| output.WriteRawTag(56); | |||
| output.WriteInt64(ShutdownBarrierTimeoutInMs); | |||
| } | |||
| if (AgentDestructionWithoutShutdown != false) { | |||
| output.WriteRawTag(64); | |||
| output.WriteBool(AgentDestructionWithoutShutdown); | |||
| } | |||
| recoverableJobs_.WriteTo(ref output, _repeated_recoverableJobs_codec); | |||
| coordinatedJobList_.WriteTo(ref output, _repeated_coordinatedJobList_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (ServiceType.Length != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeStringSize(ServiceType); | |||
| } | |||
| if (ServiceLeader.Length != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeStringSize(ServiceLeader); | |||
| } | |||
| if (EnableHealthCheck != false) { | |||
| size += 1 + 1; | |||
| } | |||
| if (ClusterRegisterTimeoutInMs != 0L) { | |||
| size += 1 + pb::CodedOutputStream.ComputeInt64Size(ClusterRegisterTimeoutInMs); | |||
| } | |||
| if (HeartbeatTimeoutInMs != 0L) { | |||
| size += 1 + pb::CodedOutputStream.ComputeInt64Size(HeartbeatTimeoutInMs); | |||
| } | |||
| size += coordinatedJobList_.CalculateSize(_repeated_coordinatedJobList_codec); | |||
| if (ShutdownBarrierTimeoutInMs != 0L) { | |||
| size += 1 + pb::CodedOutputStream.ComputeInt64Size(ShutdownBarrierTimeoutInMs); | |||
| } | |||
| if (AgentDestructionWithoutShutdown != false) { | |||
| size += 1 + 1; | |||
| } | |||
| size += recoverableJobs_.CalculateSize(_repeated_recoverableJobs_codec); | |||
| if (_unknownFields != null) { | |||
| size += _unknownFields.CalculateSize(); | |||
| } | |||
| return size; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(CoordinationServiceConfig other) { | |||
| if (other == null) { | |||
| return; | |||
| } | |||
| if (other.ServiceType.Length != 0) { | |||
| ServiceType = other.ServiceType; | |||
| } | |||
| if (other.ServiceLeader.Length != 0) { | |||
| ServiceLeader = other.ServiceLeader; | |||
| } | |||
| if (other.EnableHealthCheck != false) { | |||
| EnableHealthCheck = other.EnableHealthCheck; | |||
| } | |||
| if (other.ClusterRegisterTimeoutInMs != 0L) { | |||
| ClusterRegisterTimeoutInMs = other.ClusterRegisterTimeoutInMs; | |||
| } | |||
| if (other.HeartbeatTimeoutInMs != 0L) { | |||
| HeartbeatTimeoutInMs = other.HeartbeatTimeoutInMs; | |||
| } | |||
| coordinatedJobList_.Add(other.coordinatedJobList_); | |||
| if (other.ShutdownBarrierTimeoutInMs != 0L) { | |||
| ShutdownBarrierTimeoutInMs = other.ShutdownBarrierTimeoutInMs; | |||
| } | |||
| if (other.AgentDestructionWithoutShutdown != false) { | |||
| AgentDestructionWithoutShutdown = other.AgentDestructionWithoutShutdown; | |||
| } | |||
| recoverableJobs_.Add(other.recoverableJobs_); | |||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
| break; | |||
| case 10: { | |||
| ServiceType = input.ReadString(); | |||
| break; | |||
| } | |||
| case 18: { | |||
| ServiceLeader = input.ReadString(); | |||
| break; | |||
| } | |||
| case 24: { | |||
| EnableHealthCheck = input.ReadBool(); | |||
| break; | |||
| } | |||
| case 32: { | |||
| ClusterRegisterTimeoutInMs = input.ReadInt64(); | |||
| break; | |||
| } | |||
| case 40: { | |||
| HeartbeatTimeoutInMs = input.ReadInt64(); | |||
| break; | |||
| } | |||
| case 56: { | |||
| ShutdownBarrierTimeoutInMs = input.ReadInt64(); | |||
| break; | |||
| } | |||
| case 64: { | |||
| AgentDestructionWithoutShutdown = input.ReadBool(); | |||
| break; | |||
| } | |||
| case 74: { | |||
| recoverableJobs_.AddEntriesFrom(input, _repeated_recoverableJobs_codec); | |||
| break; | |||
| } | |||
| case 82: { | |||
| coordinatedJobList_.AddEntriesFrom(input, _repeated_coordinatedJobList_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| ServiceType = input.ReadString(); | |||
| break; | |||
| } | |||
| case 18: { | |||
| ServiceLeader = input.ReadString(); | |||
| break; | |||
| } | |||
| case 24: { | |||
| EnableHealthCheck = input.ReadBool(); | |||
| break; | |||
| } | |||
| case 32: { | |||
| ClusterRegisterTimeoutInMs = input.ReadInt64(); | |||
| break; | |||
| } | |||
| case 40: { | |||
| HeartbeatTimeoutInMs = input.ReadInt64(); | |||
| break; | |||
| } | |||
| case 56: { | |||
| ShutdownBarrierTimeoutInMs = input.ReadInt64(); | |||
| break; | |||
| } | |||
| case 64: { | |||
| AgentDestructionWithoutShutdown = input.ReadBool(); | |||
| break; | |||
| } | |||
| case 74: { | |||
| recoverableJobs_.AddEntriesFrom(ref input, _repeated_recoverableJobs_codec); | |||
| break; | |||
| } | |||
| case 82: { | |||
| coordinatedJobList_.AddEntriesFrom(ref input, _repeated_coordinatedJobList_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #endregion | |||
| } | |||
| #endregion Designer generated code | |||
| @@ -2,7 +2,7 @@ | |||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| // source: tensorflow/python/framework/cpp_shape_inference.proto | |||
| // </auto-generated> | |||
| #pragma warning disable 1591, 0612, 3021 | |||
| #pragma warning disable 1591, 0612, 3021, 8981 | |||
| #region Designer generated code | |||
| using pb = global::Google.Protobuf; | |||
| @@ -55,23 +55,31 @@ namespace Tensorflow { | |||
| } | |||
| #region Messages | |||
| public sealed partial class CppShapeInferenceResult : pb::IMessage<CppShapeInferenceResult> { | |||
| public sealed partial class CppShapeInferenceResult : pb::IMessage<CppShapeInferenceResult> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<CppShapeInferenceResult> _parser = new pb::MessageParser<CppShapeInferenceResult>(() => new CppShapeInferenceResult()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<CppShapeInferenceResult> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.CppShapeInferenceReflection.Descriptor.MessageTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CppShapeInferenceResult() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -79,6 +87,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CppShapeInferenceResult(CppShapeInferenceResult other) : this() { | |||
| shape_ = other.shape_ != null ? other.shape_.Clone() : null; | |||
| handleData_ = other.handleData_ != null ? other.handleData_.Clone() : null; | |||
| @@ -86,6 +95,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CppShapeInferenceResult Clone() { | |||
| return new CppShapeInferenceResult(this); | |||
| } | |||
| @@ -94,6 +104,7 @@ namespace Tensorflow { | |||
| public const int ShapeFieldNumber = 1; | |||
| private global::Tensorflow.TensorShapeProto shape_; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Tensorflow.TensorShapeProto Shape { | |||
| get { return shape_; } | |||
| set { | |||
| @@ -105,6 +116,7 @@ namespace Tensorflow { | |||
| public const int HandleDataFieldNumber = 4; | |||
| private global::Tensorflow.CppShapeInferenceResult.Types.HandleData handleData_; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Tensorflow.CppShapeInferenceResult.Types.HandleData HandleData { | |||
| get { return handleData_; } | |||
| set { | |||
| @@ -113,11 +125,13 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as CppShapeInferenceResult); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(CppShapeInferenceResult other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -131,6 +145,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (shape_ != null) hash ^= Shape.GetHashCode(); | |||
| @@ -142,12 +157,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (shape_ != null) { | |||
| output.WriteRawTag(10); | |||
| output.WriteMessage(Shape); | |||
| @@ -159,9 +179,29 @@ namespace Tensorflow { | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (shape_ != null) { | |||
| output.WriteRawTag(10); | |||
| output.WriteMessage(Shape); | |||
| } | |||
| if (handleData_ != null) { | |||
| output.WriteRawTag(34); | |||
| output.WriteMessage(HandleData); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (shape_ != null) { | |||
| @@ -177,6 +217,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(CppShapeInferenceResult other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -197,7 +238,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -220,29 +265,68 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| if (shape_ == null) { | |||
| Shape = new global::Tensorflow.TensorShapeProto(); | |||
| } | |||
| input.ReadMessage(Shape); | |||
| break; | |||
| } | |||
| case 34: { | |||
| if (handleData_ == null) { | |||
| HandleData = new global::Tensorflow.CppShapeInferenceResult.Types.HandleData(); | |||
| } | |||
| input.ReadMessage(HandleData); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| #region Nested types | |||
| /// <summary>Container for nested types declared in the CppShapeInferenceResult message type.</summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static partial class Types { | |||
| public sealed partial class HandleShapeAndType : pb::IMessage<HandleShapeAndType> { | |||
| public sealed partial class HandleShapeAndType : pb::IMessage<HandleShapeAndType> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<HandleShapeAndType> _parser = new pb::MessageParser<HandleShapeAndType>(() => new HandleShapeAndType()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<HandleShapeAndType> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.CppShapeInferenceResult.Descriptor.NestedTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public HandleShapeAndType() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -250,6 +334,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public HandleShapeAndType(HandleShapeAndType other) : this() { | |||
| shape_ = other.shape_ != null ? other.shape_.Clone() : null; | |||
| dtype_ = other.dtype_; | |||
| @@ -258,6 +343,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public HandleShapeAndType Clone() { | |||
| return new HandleShapeAndType(this); | |||
| } | |||
| @@ -266,6 +352,7 @@ namespace Tensorflow { | |||
| public const int ShapeFieldNumber = 1; | |||
| private global::Tensorflow.TensorShapeProto shape_; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Tensorflow.TensorShapeProto Shape { | |||
| get { return shape_; } | |||
| set { | |||
| @@ -277,6 +364,7 @@ namespace Tensorflow { | |||
| public const int DtypeFieldNumber = 2; | |||
| private global::Tensorflow.DataType dtype_ = global::Tensorflow.DataType.DtInvalid; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Tensorflow.DataType Dtype { | |||
| get { return dtype_; } | |||
| set { | |||
| @@ -288,6 +376,7 @@ namespace Tensorflow { | |||
| public const int TypeFieldNumber = 4; | |||
| private global::Tensorflow.FullTypeDef type_; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Tensorflow.FullTypeDef Type { | |||
| get { return type_; } | |||
| set { | |||
| @@ -296,11 +385,13 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as HandleShapeAndType); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(HandleShapeAndType other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -315,6 +406,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (shape_ != null) hash ^= Shape.GetHashCode(); | |||
| @@ -327,12 +419,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (shape_ != null) { | |||
| output.WriteRawTag(10); | |||
| output.WriteMessage(Shape); | |||
| @@ -348,9 +445,33 @@ namespace Tensorflow { | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (shape_ != null) { | |||
| output.WriteRawTag(10); | |||
| output.WriteMessage(Shape); | |||
| } | |||
| if (Dtype != global::Tensorflow.DataType.DtInvalid) { | |||
| output.WriteRawTag(16); | |||
| output.WriteEnum((int) Dtype); | |||
| } | |||
| if (type_ != null) { | |||
| output.WriteRawTag(34); | |||
| output.WriteMessage(Type); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (shape_ != null) { | |||
| @@ -369,6 +490,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(HandleShapeAndType other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -392,7 +514,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -419,27 +545,69 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| if (shape_ == null) { | |||
| Shape = new global::Tensorflow.TensorShapeProto(); | |||
| } | |||
| input.ReadMessage(Shape); | |||
| break; | |||
| } | |||
| case 16: { | |||
| Dtype = (global::Tensorflow.DataType) input.ReadEnum(); | |||
| break; | |||
| } | |||
| case 34: { | |||
| if (type_ == null) { | |||
| Type = new global::Tensorflow.FullTypeDef(); | |||
| } | |||
| input.ReadMessage(Type); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| public sealed partial class HandleData : pb::IMessage<HandleData> { | |||
| public sealed partial class HandleData : pb::IMessage<HandleData> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<HandleData> _parser = new pb::MessageParser<HandleData>(() => new HandleData()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<HandleData> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.CppShapeInferenceResult.Descriptor.NestedTypes[1]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public HandleData() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -447,6 +615,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public HandleData(HandleData other) : this() { | |||
| isSet_ = other.isSet_; | |||
| shapeAndType_ = other.shapeAndType_.Clone(); | |||
| @@ -454,6 +623,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public HandleData Clone() { | |||
| return new HandleData(this); | |||
| } | |||
| @@ -462,6 +632,7 @@ namespace Tensorflow { | |||
| public const int IsSetFieldNumber = 1; | |||
| private bool isSet_; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool IsSet { | |||
| get { return isSet_; } | |||
| set { | |||
| @@ -478,16 +649,19 @@ namespace Tensorflow { | |||
| /// Only valid if <is_set>. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType> ShapeAndType { | |||
| get { return shapeAndType_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as HandleData); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(HandleData other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -501,6 +675,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (IsSet != false) hash ^= IsSet.GetHashCode(); | |||
| @@ -512,12 +687,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (IsSet != false) { | |||
| output.WriteRawTag(8); | |||
| output.WriteBool(IsSet); | |||
| @@ -526,9 +706,26 @@ namespace Tensorflow { | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (IsSet != false) { | |||
| output.WriteRawTag(8); | |||
| output.WriteBool(IsSet); | |||
| } | |||
| shapeAndType_.WriteTo(ref output, _repeated_shapeAndType_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (IsSet != false) { | |||
| @@ -542,6 +739,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(HandleData other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -554,7 +752,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -571,8 +773,32 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 8: { | |||
| IsSet = input.ReadBool(); | |||
| break; | |||
| } | |||
| case 18: { | |||
| shapeAndType_.AddEntriesFrom(ref input, _repeated_shapeAndType_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| } | |||
| @@ -580,23 +806,31 @@ namespace Tensorflow { | |||
| } | |||
| public sealed partial class CppShapeInferenceInputsNeeded : pb::IMessage<CppShapeInferenceInputsNeeded> { | |||
| public sealed partial class CppShapeInferenceInputsNeeded : pb::IMessage<CppShapeInferenceInputsNeeded> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<CppShapeInferenceInputsNeeded> _parser = new pb::MessageParser<CppShapeInferenceInputsNeeded>(() => new CppShapeInferenceInputsNeeded()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<CppShapeInferenceInputsNeeded> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.CppShapeInferenceReflection.Descriptor.MessageTypes[1]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CppShapeInferenceInputsNeeded() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -604,6 +838,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CppShapeInferenceInputsNeeded(CppShapeInferenceInputsNeeded other) : this() { | |||
| inputTensorsNeeded_ = other.inputTensorsNeeded_.Clone(); | |||
| inputTensorsAsShapesNeeded_ = other.inputTensorsAsShapesNeeded_.Clone(); | |||
| @@ -611,6 +846,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public CppShapeInferenceInputsNeeded Clone() { | |||
| return new CppShapeInferenceInputsNeeded(this); | |||
| } | |||
| @@ -621,6 +857,7 @@ namespace Tensorflow { | |||
| = pb::FieldCodec.ForInt32(10); | |||
| private readonly pbc::RepeatedField<int> inputTensorsNeeded_ = new pbc::RepeatedField<int>(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<int> InputTensorsNeeded { | |||
| get { return inputTensorsNeeded_; } | |||
| } | |||
| @@ -631,16 +868,19 @@ namespace Tensorflow { | |||
| = pb::FieldCodec.ForInt32(18); | |||
| private readonly pbc::RepeatedField<int> inputTensorsAsShapesNeeded_ = new pbc::RepeatedField<int>(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<int> InputTensorsAsShapesNeeded { | |||
| get { return inputTensorsAsShapesNeeded_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as CppShapeInferenceInputsNeeded); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(CppShapeInferenceInputsNeeded other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -654,6 +894,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| hash ^= inputTensorsNeeded_.GetHashCode(); | |||
| @@ -665,20 +906,39 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| inputTensorsNeeded_.WriteTo(output, _repeated_inputTensorsNeeded_codec); | |||
| inputTensorsAsShapesNeeded_.WriteTo(output, _repeated_inputTensorsAsShapesNeeded_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| inputTensorsNeeded_.WriteTo(ref output, _repeated_inputTensorsNeeded_codec); | |||
| inputTensorsAsShapesNeeded_.WriteTo(ref output, _repeated_inputTensorsAsShapesNeeded_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| size += inputTensorsNeeded_.CalculateSize(_repeated_inputTensorsNeeded_codec); | |||
| @@ -690,6 +950,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(CppShapeInferenceInputsNeeded other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -700,7 +961,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -719,7 +984,33 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: | |||
| case 8: { | |||
| inputTensorsNeeded_.AddEntriesFrom(ref input, _repeated_inputTensorsNeeded_codec); | |||
| break; | |||
| } | |||
| case 18: | |||
| case 16: { | |||
| inputTensorsAsShapesNeeded_.AddEntriesFrom(ref input, _repeated_inputTensorsAsShapesNeeded_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| @@ -2,7 +2,7 @@ | |||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| // source: tensorflow/core/protobuf/debug.proto | |||
| // </auto-generated> | |||
| #pragma warning disable 1591, 0612, 3021 | |||
| #pragma warning disable 1591, 0612, 3021, 8981 | |||
| #region Designer generated code | |||
| using pb = global::Google.Protobuf; | |||
| @@ -55,23 +55,31 @@ namespace Tensorflow { | |||
| /// <summary> | |||
| /// Option for watching a node in TensorFlow Debugger (tfdbg). | |||
| /// </summary> | |||
| public sealed partial class DebugTensorWatch : pb::IMessage<DebugTensorWatch> { | |||
| public sealed partial class DebugTensorWatch : pb::IMessage<DebugTensorWatch> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<DebugTensorWatch> _parser = new pb::MessageParser<DebugTensorWatch>(() => new DebugTensorWatch()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<DebugTensorWatch> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DebugTensorWatch() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -79,6 +87,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DebugTensorWatch(DebugTensorWatch other) : this() { | |||
| nodeName_ = other.nodeName_; | |||
| outputSlot_ = other.outputSlot_; | |||
| @@ -89,6 +98,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DebugTensorWatch Clone() { | |||
| return new DebugTensorWatch(this); | |||
| } | |||
| @@ -102,6 +112,7 @@ namespace Tensorflow { | |||
| /// general. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string NodeName { | |||
| get { return nodeName_; } | |||
| set { | |||
| @@ -120,6 +131,7 @@ namespace Tensorflow { | |||
| /// errors currently. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int OutputSlot { | |||
| get { return outputSlot_; } | |||
| set { | |||
| @@ -138,6 +150,7 @@ namespace Tensorflow { | |||
| /// e.g., {"DebugIdentity", "DebugNanCount"} | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<string> DebugOps { | |||
| get { return debugOps_; } | |||
| } | |||
| @@ -170,6 +183,7 @@ namespace Tensorflow { | |||
| /// TODO(cais): More visible documentation of this in g3docs. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<string> DebugUrls { | |||
| get { return debugUrls_; } | |||
| } | |||
| @@ -182,6 +196,7 @@ namespace Tensorflow { | |||
| /// incompatibility). Instead, just log the failure. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool TolerateDebugOpCreationFailures { | |||
| get { return tolerateDebugOpCreationFailures_; } | |||
| set { | |||
| @@ -190,11 +205,13 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as DebugTensorWatch); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(DebugTensorWatch other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -211,6 +228,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (NodeName.Length != 0) hash ^= NodeName.GetHashCode(); | |||
| @@ -225,12 +243,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (NodeName.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(NodeName); | |||
| @@ -248,9 +271,35 @@ namespace Tensorflow { | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (NodeName.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(NodeName); | |||
| } | |||
| if (OutputSlot != 0) { | |||
| output.WriteRawTag(16); | |||
| output.WriteInt32(OutputSlot); | |||
| } | |||
| debugOps_.WriteTo(ref output, _repeated_debugOps_codec); | |||
| debugUrls_.WriteTo(ref output, _repeated_debugUrls_codec); | |||
| if (TolerateDebugOpCreationFailures != false) { | |||
| output.WriteRawTag(40); | |||
| output.WriteBool(TolerateDebugOpCreationFailures); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (NodeName.Length != 0) { | |||
| @@ -271,6 +320,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(DebugTensorWatch other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -290,7 +340,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -319,30 +373,74 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| NodeName = input.ReadString(); | |||
| break; | |||
| } | |||
| case 16: { | |||
| OutputSlot = input.ReadInt32(); | |||
| break; | |||
| } | |||
| case 26: { | |||
| debugOps_.AddEntriesFrom(ref input, _repeated_debugOps_codec); | |||
| break; | |||
| } | |||
| case 34: { | |||
| debugUrls_.AddEntriesFrom(ref input, _repeated_debugUrls_codec); | |||
| break; | |||
| } | |||
| case 40: { | |||
| TolerateDebugOpCreationFailures = input.ReadBool(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| /// <summary> | |||
| /// Options for initializing DebuggerState in TensorFlow Debugger (tfdbg). | |||
| /// </summary> | |||
| public sealed partial class DebugOptions : pb::IMessage<DebugOptions> { | |||
| public sealed partial class DebugOptions : pb::IMessage<DebugOptions> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<DebugOptions> _parser = new pb::MessageParser<DebugOptions>(() => new DebugOptions()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<DebugOptions> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[1]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DebugOptions() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -350,6 +448,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DebugOptions(DebugOptions other) : this() { | |||
| debugTensorWatchOpts_ = other.debugTensorWatchOpts_.Clone(); | |||
| globalStep_ = other.globalStep_; | |||
| @@ -358,6 +457,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DebugOptions Clone() { | |||
| return new DebugOptions(this); | |||
| } | |||
| @@ -371,6 +471,7 @@ namespace Tensorflow { | |||
| /// Debugging options | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<global::Tensorflow.DebugTensorWatch> DebugTensorWatchOpts { | |||
| get { return debugTensorWatchOpts_; } | |||
| } | |||
| @@ -384,6 +485,7 @@ namespace Tensorflow { | |||
| /// step count. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public long GlobalStep { | |||
| get { return globalStep_; } | |||
| set { | |||
| @@ -401,6 +503,7 @@ namespace Tensorflow { | |||
| /// are cleaned up from the disk after each Session.run. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool ResetDiskByteUsage { | |||
| get { return resetDiskByteUsage_; } | |||
| set { | |||
| @@ -409,11 +512,13 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as DebugOptions); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(DebugOptions other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -428,6 +533,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| hash ^= debugTensorWatchOpts_.GetHashCode(); | |||
| @@ -440,12 +546,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| debugTensorWatchOpts_.WriteTo(output, _repeated_debugTensorWatchOpts_codec); | |||
| if (GlobalStep != 0L) { | |||
| output.WriteRawTag(80); | |||
| @@ -458,9 +569,30 @@ namespace Tensorflow { | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| debugTensorWatchOpts_.WriteTo(ref output, _repeated_debugTensorWatchOpts_codec); | |||
| if (GlobalStep != 0L) { | |||
| output.WriteRawTag(80); | |||
| output.WriteInt64(GlobalStep); | |||
| } | |||
| if (ResetDiskByteUsage != false) { | |||
| output.WriteRawTag(88); | |||
| output.WriteBool(ResetDiskByteUsage); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| size += debugTensorWatchOpts_.CalculateSize(_repeated_debugTensorWatchOpts_codec); | |||
| @@ -477,6 +609,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(DebugOptions other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -492,7 +625,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -513,27 +650,63 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 34: { | |||
| debugTensorWatchOpts_.AddEntriesFrom(ref input, _repeated_debugTensorWatchOpts_codec); | |||
| break; | |||
| } | |||
| case 80: { | |||
| GlobalStep = input.ReadInt64(); | |||
| break; | |||
| } | |||
| case 88: { | |||
| ResetDiskByteUsage = input.ReadBool(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| public sealed partial class DebuggedSourceFile : pb::IMessage<DebuggedSourceFile> { | |||
| public sealed partial class DebuggedSourceFile : pb::IMessage<DebuggedSourceFile> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<DebuggedSourceFile> _parser = new pb::MessageParser<DebuggedSourceFile>(() => new DebuggedSourceFile()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<DebuggedSourceFile> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[2]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DebuggedSourceFile() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -541,6 +714,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DebuggedSourceFile(DebuggedSourceFile other) : this() { | |||
| host_ = other.host_; | |||
| filePath_ = other.filePath_; | |||
| @@ -551,6 +725,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DebuggedSourceFile Clone() { | |||
| return new DebuggedSourceFile(this); | |||
| } | |||
| @@ -562,6 +737,7 @@ namespace Tensorflow { | |||
| /// The host name on which a source code file is located. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string Host { | |||
| get { return host_; } | |||
| set { | |||
| @@ -576,6 +752,7 @@ namespace Tensorflow { | |||
| /// Path to the source code file. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string FilePath { | |||
| get { return filePath_; } | |||
| set { | |||
| @@ -590,6 +767,7 @@ namespace Tensorflow { | |||
| /// The timestamp at which the source code file is last modified. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public long LastModified { | |||
| get { return lastModified_; } | |||
| set { | |||
| @@ -604,6 +782,7 @@ namespace Tensorflow { | |||
| /// Byte size of the file. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public long Bytes { | |||
| get { return bytes_; } | |||
| set { | |||
| @@ -620,16 +799,19 @@ namespace Tensorflow { | |||
| /// Line-by-line content of the source code file. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<string> Lines { | |||
| get { return lines_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as DebuggedSourceFile); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(DebuggedSourceFile other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -646,6 +828,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (Host.Length != 0) hash ^= Host.GetHashCode(); | |||
| @@ -660,12 +843,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (Host.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(Host); | |||
| @@ -686,9 +874,38 @@ namespace Tensorflow { | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (Host.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(Host); | |||
| } | |||
| if (FilePath.Length != 0) { | |||
| output.WriteRawTag(18); | |||
| output.WriteString(FilePath); | |||
| } | |||
| if (LastModified != 0L) { | |||
| output.WriteRawTag(24); | |||
| output.WriteInt64(LastModified); | |||
| } | |||
| if (Bytes != 0L) { | |||
| output.WriteRawTag(32); | |||
| output.WriteInt64(Bytes); | |||
| } | |||
| lines_.WriteTo(ref output, _repeated_lines_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (Host.Length != 0) { | |||
| @@ -711,6 +928,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(DebuggedSourceFile other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -732,7 +950,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -761,27 +983,71 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| Host = input.ReadString(); | |||
| break; | |||
| } | |||
| case 18: { | |||
| FilePath = input.ReadString(); | |||
| break; | |||
| } | |||
| case 24: { | |||
| LastModified = input.ReadInt64(); | |||
| break; | |||
| } | |||
| case 32: { | |||
| Bytes = input.ReadInt64(); | |||
| break; | |||
| } | |||
| case 42: { | |||
| lines_.AddEntriesFrom(ref input, _repeated_lines_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| public sealed partial class DebuggedSourceFiles : pb::IMessage<DebuggedSourceFiles> { | |||
| public sealed partial class DebuggedSourceFiles : pb::IMessage<DebuggedSourceFiles> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<DebuggedSourceFiles> _parser = new pb::MessageParser<DebuggedSourceFiles>(() => new DebuggedSourceFiles()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<DebuggedSourceFiles> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.DebugReflection.Descriptor.MessageTypes[3]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DebuggedSourceFiles() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -789,12 +1055,14 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DebuggedSourceFiles(DebuggedSourceFiles other) : this() { | |||
| sourceFiles_ = other.sourceFiles_.Clone(); | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DebuggedSourceFiles Clone() { | |||
| return new DebuggedSourceFiles(this); | |||
| } | |||
| @@ -808,16 +1076,19 @@ namespace Tensorflow { | |||
| /// A collection of source code files. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<global::Tensorflow.DebuggedSourceFile> SourceFiles { | |||
| get { return sourceFiles_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as DebuggedSourceFiles); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(DebuggedSourceFiles other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -830,6 +1101,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| hash ^= sourceFiles_.GetHashCode(); | |||
| @@ -840,19 +1112,37 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| sourceFiles_.WriteTo(output, _repeated_sourceFiles_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| sourceFiles_.WriteTo(ref output, _repeated_sourceFiles_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| size += sourceFiles_.CalculateSize(_repeated_sourceFiles_codec); | |||
| @@ -863,6 +1153,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(DebuggedSourceFiles other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -872,7 +1163,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -885,7 +1180,27 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| sourceFiles_.AddEntriesFrom(ref input, _repeated_sourceFiles_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| @@ -2,7 +2,7 @@ | |||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| // source: tensorflow/core/framework/device_attributes.proto | |||
| // </auto-generated> | |||
| #pragma warning disable 1591, 0612, 3021 | |||
| #pragma warning disable 1591, 0612, 3021, 8981 | |||
| #region Designer generated code | |||
| using pb = global::Google.Protobuf; | |||
| @@ -30,44 +30,53 @@ namespace Tensorflow { | |||
| "OAoKTG9jYWxMaW5rcxIqCgRsaW5rGAEgAygLMhwudGVuc29yZmxvdy5JbnRl", | |||
| "cmNvbm5lY3RMaW5rIloKDkRldmljZUxvY2FsaXR5Eg4KBmJ1c19pZBgBIAEo", | |||
| "BRIRCgludW1hX25vZGUYAiABKAUSJQoFbGlua3MYAyABKAsyFi50ZW5zb3Jm", | |||
| "bG93LkxvY2FsTGlua3MirAEKEERldmljZUF0dHJpYnV0ZXMSDAoEbmFtZRgB", | |||
| "bG93LkxvY2FsTGlua3MiwwEKEERldmljZUF0dHJpYnV0ZXMSDAoEbmFtZRgB", | |||
| "IAEoCRITCgtkZXZpY2VfdHlwZRgCIAEoCRIUCgxtZW1vcnlfbGltaXQYBCAB", | |||
| "KAMSLAoIbG9jYWxpdHkYBSABKAsyGi50ZW5zb3JmbG93LkRldmljZUxvY2Fs", | |||
| "aXR5EhMKC2luY2FybmF0aW9uGAYgASgGEhwKFHBoeXNpY2FsX2RldmljZV9k", | |||
| "ZXNjGAcgASgJQpEBChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtCFkRldmlj", | |||
| "ZUF0dHJpYnV0ZXNQcm90b3NQAVpYZ2l0aHViLmNvbS90ZW5zb3JmbG93L3Rl", | |||
| "bnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2ZyYW1ld29yay9kZXZpY2Vf", | |||
| "YXR0cmlidXRlc19nb19wcm90b/gBAWIGcHJvdG8z")); | |||
| "ZXNjGAcgASgJEhUKDXhsYV9nbG9iYWxfaWQYCCABKANCkQEKGG9yZy50ZW5z", | |||
| "b3JmbG93LmZyYW1ld29ya0IWRGV2aWNlQXR0cmlidXRlc1Byb3Rvc1ABWlhn", | |||
| "aXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5zb3JmbG93L2dv", | |||
| "L2NvcmUvZnJhbWV3b3JrL2RldmljZV9hdHRyaWJ1dGVzX2dvX3Byb3Rv+AEB", | |||
| "YgZwcm90bzM=")); | |||
| descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
| new pbr::FileDescriptor[] { }, | |||
| new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.InterconnectLink), global::Tensorflow.InterconnectLink.Parser, new[]{ "DeviceId", "Type", "Strength" }, null, null, null, null), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.LocalLinks), global::Tensorflow.LocalLinks.Parser, new[]{ "Link" }, null, null, null, null), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeviceLocality), global::Tensorflow.DeviceLocality.Parser, new[]{ "BusId", "NumaNode", "Links" }, null, null, null, null), | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeviceAttributes), global::Tensorflow.DeviceAttributes.Parser, new[]{ "Name", "DeviceType", "MemoryLimit", "Locality", "Incarnation", "PhysicalDeviceDesc" }, null, null, null, null) | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.DeviceAttributes), global::Tensorflow.DeviceAttributes.Parser, new[]{ "Name", "DeviceType", "MemoryLimit", "Locality", "Incarnation", "PhysicalDeviceDesc", "XlaGlobalId" }, null, null, null, null) | |||
| })); | |||
| } | |||
| #endregion | |||
| } | |||
| #region Messages | |||
| public sealed partial class InterconnectLink : pb::IMessage<InterconnectLink> { | |||
| public sealed partial class InterconnectLink : pb::IMessage<InterconnectLink> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<InterconnectLink> _parser = new pb::MessageParser<InterconnectLink>(() => new InterconnectLink()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<InterconnectLink> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public InterconnectLink() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -75,6 +84,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public InterconnectLink(InterconnectLink other) : this() { | |||
| deviceId_ = other.deviceId_; | |||
| type_ = other.type_; | |||
| @@ -83,6 +93,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public InterconnectLink Clone() { | |||
| return new InterconnectLink(this); | |||
| } | |||
| @@ -91,6 +102,7 @@ namespace Tensorflow { | |||
| public const int DeviceIdFieldNumber = 1; | |||
| private int deviceId_; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int DeviceId { | |||
| get { return deviceId_; } | |||
| set { | |||
| @@ -102,6 +114,7 @@ namespace Tensorflow { | |||
| public const int TypeFieldNumber = 2; | |||
| private string type_ = ""; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string Type { | |||
| get { return type_; } | |||
| set { | |||
| @@ -113,6 +126,7 @@ namespace Tensorflow { | |||
| public const int StrengthFieldNumber = 3; | |||
| private int strength_; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int Strength { | |||
| get { return strength_; } | |||
| set { | |||
| @@ -121,11 +135,13 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as InterconnectLink); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(InterconnectLink other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -140,6 +156,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (DeviceId != 0) hash ^= DeviceId.GetHashCode(); | |||
| @@ -152,12 +169,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (DeviceId != 0) { | |||
| output.WriteRawTag(8); | |||
| output.WriteInt32(DeviceId); | |||
| @@ -173,9 +195,33 @@ namespace Tensorflow { | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (DeviceId != 0) { | |||
| output.WriteRawTag(8); | |||
| output.WriteInt32(DeviceId); | |||
| } | |||
| if (Type.Length != 0) { | |||
| output.WriteRawTag(18); | |||
| output.WriteString(Type); | |||
| } | |||
| if (Strength != 0) { | |||
| output.WriteRawTag(24); | |||
| output.WriteInt32(Strength); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (DeviceId != 0) { | |||
| @@ -194,6 +240,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(InterconnectLink other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -211,7 +258,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -232,27 +283,63 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 8: { | |||
| DeviceId = input.ReadInt32(); | |||
| break; | |||
| } | |||
| case 18: { | |||
| Type = input.ReadString(); | |||
| break; | |||
| } | |||
| case 24: { | |||
| Strength = input.ReadInt32(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| public sealed partial class LocalLinks : pb::IMessage<LocalLinks> { | |||
| public sealed partial class LocalLinks : pb::IMessage<LocalLinks> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<LocalLinks> _parser = new pb::MessageParser<LocalLinks>(() => new LocalLinks()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<LocalLinks> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[1]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public LocalLinks() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -260,12 +347,14 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public LocalLinks(LocalLinks other) : this() { | |||
| link_ = other.link_.Clone(); | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public LocalLinks Clone() { | |||
| return new LocalLinks(this); | |||
| } | |||
| @@ -276,16 +365,19 @@ namespace Tensorflow { | |||
| = pb::FieldCodec.ForMessage(10, global::Tensorflow.InterconnectLink.Parser); | |||
| private readonly pbc::RepeatedField<global::Tensorflow.InterconnectLink> link_ = new pbc::RepeatedField<global::Tensorflow.InterconnectLink>(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<global::Tensorflow.InterconnectLink> Link { | |||
| get { return link_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as LocalLinks); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(LocalLinks other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -298,6 +390,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| hash ^= link_.GetHashCode(); | |||
| @@ -308,19 +401,37 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| link_.WriteTo(output, _repeated_link_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| link_.WriteTo(ref output, _repeated_link_codec); | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| size += link_.CalculateSize(_repeated_link_codec); | |||
| @@ -331,6 +442,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(LocalLinks other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -340,7 +452,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -353,27 +469,55 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| link_.AddEntriesFrom(ref input, _repeated_link_codec); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| public sealed partial class DeviceLocality : pb::IMessage<DeviceLocality> { | |||
| public sealed partial class DeviceLocality : pb::IMessage<DeviceLocality> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<DeviceLocality> _parser = new pb::MessageParser<DeviceLocality>(() => new DeviceLocality()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<DeviceLocality> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[2]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DeviceLocality() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -381,6 +525,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DeviceLocality(DeviceLocality other) : this() { | |||
| busId_ = other.busId_; | |||
| numaNode_ = other.numaNode_; | |||
| @@ -389,6 +534,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DeviceLocality Clone() { | |||
| return new DeviceLocality(this); | |||
| } | |||
| @@ -401,6 +547,7 @@ namespace Tensorflow { | |||
| /// no specific locality. Specific localities are indexed from 1. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int BusId { | |||
| get { return busId_; } | |||
| set { | |||
| @@ -415,6 +562,7 @@ namespace Tensorflow { | |||
| /// Optional NUMA locality of device. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int NumaNode { | |||
| get { return numaNode_; } | |||
| set { | |||
| @@ -429,6 +577,7 @@ namespace Tensorflow { | |||
| /// Optional local interconnect links to other devices. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Tensorflow.LocalLinks Links { | |||
| get { return links_; } | |||
| set { | |||
| @@ -437,11 +586,13 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as DeviceLocality); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(DeviceLocality other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -456,6 +607,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (BusId != 0) hash ^= BusId.GetHashCode(); | |||
| @@ -468,12 +620,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (BusId != 0) { | |||
| output.WriteRawTag(8); | |||
| output.WriteInt32(BusId); | |||
| @@ -489,9 +646,33 @@ namespace Tensorflow { | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (BusId != 0) { | |||
| output.WriteRawTag(8); | |||
| output.WriteInt32(BusId); | |||
| } | |||
| if (NumaNode != 0) { | |||
| output.WriteRawTag(16); | |||
| output.WriteInt32(NumaNode); | |||
| } | |||
| if (links_ != null) { | |||
| output.WriteRawTag(26); | |||
| output.WriteMessage(Links); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (BusId != 0) { | |||
| @@ -510,6 +691,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(DeviceLocality other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -530,7 +712,11 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -554,27 +740,66 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 8: { | |||
| BusId = input.ReadInt32(); | |||
| break; | |||
| } | |||
| case 16: { | |||
| NumaNode = input.ReadInt32(); | |||
| break; | |||
| } | |||
| case 26: { | |||
| if (links_ == null) { | |||
| Links = new global::Tensorflow.LocalLinks(); | |||
| } | |||
| input.ReadMessage(Links); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| public sealed partial class DeviceAttributes : pb::IMessage<DeviceAttributes> { | |||
| public sealed partial class DeviceAttributes : pb::IMessage<DeviceAttributes> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<DeviceAttributes> _parser = new pb::MessageParser<DeviceAttributes>(() => new DeviceAttributes()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<DeviceAttributes> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.DeviceAttributesReflection.Descriptor.MessageTypes[3]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DeviceAttributes() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -582,6 +807,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DeviceAttributes(DeviceAttributes other) : this() { | |||
| name_ = other.name_; | |||
| deviceType_ = other.deviceType_; | |||
| @@ -589,10 +815,12 @@ namespace Tensorflow { | |||
| locality_ = other.locality_ != null ? other.locality_.Clone() : null; | |||
| incarnation_ = other.incarnation_; | |||
| physicalDeviceDesc_ = other.physicalDeviceDesc_; | |||
| xlaGlobalId_ = other.xlaGlobalId_; | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public DeviceAttributes Clone() { | |||
| return new DeviceAttributes(this); | |||
| } | |||
| @@ -604,6 +832,7 @@ namespace Tensorflow { | |||
| /// Fully specified name of the device within a cluster. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string Name { | |||
| get { return name_; } | |||
| set { | |||
| @@ -618,6 +847,7 @@ namespace Tensorflow { | |||
| /// String representation of device_type. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string DeviceType { | |||
| get { return deviceType_; } | |||
| set { | |||
| @@ -632,6 +862,7 @@ namespace Tensorflow { | |||
| /// Memory capacity of device in bytes. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public long MemoryLimit { | |||
| get { return memoryLimit_; } | |||
| set { | |||
| @@ -647,6 +878,7 @@ namespace Tensorflow { | |||
| /// for supporting efficient data transfers. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Tensorflow.DeviceLocality Locality { | |||
| get { return locality_; } | |||
| set { | |||
| @@ -662,6 +894,7 @@ namespace Tensorflow { | |||
| /// initialized. "incarnation" should never be 0. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public ulong Incarnation { | |||
| get { return incarnation_; } | |||
| set { | |||
| @@ -676,6 +909,7 @@ namespace Tensorflow { | |||
| /// String representation of the physical device that this device maps to. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string PhysicalDeviceDesc { | |||
| get { return physicalDeviceDesc_; } | |||
| set { | |||
| @@ -683,12 +917,31 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| /// <summary>Field number for the "xla_global_id" field.</summary> | |||
| public const int XlaGlobalIdFieldNumber = 8; | |||
| private long xlaGlobalId_; | |||
| /// <summary> | |||
| /// A physical device ID for use in XLA DeviceAssignments, unique across | |||
| /// clients in a multi-client setup. Set to -1 if unavailable, non-negative | |||
| /// otherwise. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public long XlaGlobalId { | |||
| get { return xlaGlobalId_; } | |||
| set { | |||
| xlaGlobalId_ = value; | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as DeviceAttributes); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(DeviceAttributes other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -702,10 +955,12 @@ namespace Tensorflow { | |||
| if (!object.Equals(Locality, other.Locality)) return false; | |||
| if (Incarnation != other.Incarnation) return false; | |||
| if (PhysicalDeviceDesc != other.PhysicalDeviceDesc) return false; | |||
| if (XlaGlobalId != other.XlaGlobalId) return false; | |||
| return Equals(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (Name.Length != 0) hash ^= Name.GetHashCode(); | |||
| @@ -714,6 +969,7 @@ namespace Tensorflow { | |||
| if (locality_ != null) hash ^= Locality.GetHashCode(); | |||
| if (Incarnation != 0UL) hash ^= Incarnation.GetHashCode(); | |||
| if (PhysicalDeviceDesc.Length != 0) hash ^= PhysicalDeviceDesc.GetHashCode(); | |||
| if (XlaGlobalId != 0L) hash ^= XlaGlobalId.GetHashCode(); | |||
| if (_unknownFields != null) { | |||
| hash ^= _unknownFields.GetHashCode(); | |||
| } | |||
| @@ -721,12 +977,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (Name.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(Name); | |||
| @@ -751,12 +1012,56 @@ namespace Tensorflow { | |||
| output.WriteRawTag(58); | |||
| output.WriteString(PhysicalDeviceDesc); | |||
| } | |||
| if (XlaGlobalId != 0L) { | |||
| output.WriteRawTag(64); | |||
| output.WriteInt64(XlaGlobalId); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (Name.Length != 0) { | |||
| output.WriteRawTag(10); | |||
| output.WriteString(Name); | |||
| } | |||
| if (DeviceType.Length != 0) { | |||
| output.WriteRawTag(18); | |||
| output.WriteString(DeviceType); | |||
| } | |||
| if (MemoryLimit != 0L) { | |||
| output.WriteRawTag(32); | |||
| output.WriteInt64(MemoryLimit); | |||
| } | |||
| if (locality_ != null) { | |||
| output.WriteRawTag(42); | |||
| output.WriteMessage(Locality); | |||
| } | |||
| if (Incarnation != 0UL) { | |||
| output.WriteRawTag(49); | |||
| output.WriteFixed64(Incarnation); | |||
| } | |||
| if (PhysicalDeviceDesc.Length != 0) { | |||
| output.WriteRawTag(58); | |||
| output.WriteString(PhysicalDeviceDesc); | |||
| } | |||
| if (XlaGlobalId != 0L) { | |||
| output.WriteRawTag(64); | |||
| output.WriteInt64(XlaGlobalId); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (Name.Length != 0) { | |||
| @@ -777,6 +1082,9 @@ namespace Tensorflow { | |||
| if (PhysicalDeviceDesc.Length != 0) { | |||
| size += 1 + pb::CodedOutputStream.ComputeStringSize(PhysicalDeviceDesc); | |||
| } | |||
| if (XlaGlobalId != 0L) { | |||
| size += 1 + pb::CodedOutputStream.ComputeInt64Size(XlaGlobalId); | |||
| } | |||
| if (_unknownFields != null) { | |||
| size += _unknownFields.CalculateSize(); | |||
| } | |||
| @@ -784,6 +1092,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(DeviceAttributes other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -809,11 +1118,18 @@ namespace Tensorflow { | |||
| if (other.PhysicalDeviceDesc.Length != 0) { | |||
| PhysicalDeviceDesc = other.PhysicalDeviceDesc; | |||
| } | |||
| if (other.XlaGlobalId != 0L) { | |||
| XlaGlobalId = other.XlaGlobalId; | |||
| } | |||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -847,9 +1163,60 @@ namespace Tensorflow { | |||
| PhysicalDeviceDesc = input.ReadString(); | |||
| break; | |||
| } | |||
| case 64: { | |||
| XlaGlobalId = input.ReadInt64(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| Name = input.ReadString(); | |||
| break; | |||
| } | |||
| case 18: { | |||
| DeviceType = input.ReadString(); | |||
| break; | |||
| } | |||
| case 32: { | |||
| MemoryLimit = input.ReadInt64(); | |||
| break; | |||
| } | |||
| case 42: { | |||
| if (locality_ == null) { | |||
| Locality = new global::Tensorflow.DeviceLocality(); | |||
| } | |||
| input.ReadMessage(Locality); | |||
| break; | |||
| } | |||
| case 49: { | |||
| Incarnation = input.ReadFixed64(); | |||
| break; | |||
| } | |||
| case 58: { | |||
| PhysicalDeviceDesc = input.ReadString(); | |||
| break; | |||
| } | |||
| case 64: { | |||
| XlaGlobalId = input.ReadInt64(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| @@ -0,0 +1,340 @@ | |||
| // <auto-generated> | |||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| // source: tensorflow/compiler/xla/service/cpu/executable.proto | |||
| // </auto-generated> | |||
| #pragma warning disable 1591, 0612, 3021, 8981 | |||
| #region Designer generated code | |||
| using pb = global::Google.Protobuf; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| using pbr = global::Google.Protobuf.Reflection; | |||
| using scg = global::System.Collections.Generic; | |||
| namespace Xla.Cpu { | |||
| /// <summary>Holder for reflection information generated from tensorflow/compiler/xla/service/cpu/executable.proto</summary> | |||
| public static partial class ExecutableReflection { | |||
| #region Descriptor | |||
| /// <summary>File descriptor for tensorflow/compiler/xla/service/cpu/executable.proto</summary> | |||
| public static pbr::FileDescriptor Descriptor { | |||
| get { return descriptor; } | |||
| } | |||
| private static pbr::FileDescriptor descriptor; | |||
| static ExecutableReflection() { | |||
| byte[] descriptorData = global::System.Convert.FromBase64String( | |||
| string.Concat( | |||
| "CjR0ZW5zb3JmbG93L2NvbXBpbGVyL3hsYS9zZXJ2aWNlL2NwdS9leGVjdXRh", | |||
| "YmxlLnByb3RvEgd4bGEuY3B1Gjd0ZW5zb3JmbG93L2NvbXBpbGVyL3hsYS9z", | |||
| "ZXJ2aWNlL2NwdS94bGFfZnJhbWV3b3JrLnByb3RvGil0ZW5zb3JmbG93L2Nv", | |||
| "bXBpbGVyL3hsYS9zZXJ2aWNlL2hsby5wcm90byLXAQocWGxhUnVudGltZUNw", | |||
| "dUV4ZWN1dGFibGVQcm90bxI+ChZ4bGFfcnVudGltZV9leGVjdXRhYmxlGAEg", | |||
| "ASgLMh4ueGxhLlhsYVJ1bnRpbWVFeGVjdXRhYmxlUHJvdG8SQAoVeGxhX2Zy", | |||
| "YW1ld29ya19tYXBwaW5nGAIgASgLMiEueGxhLmNwdS5YbGFGcmFtZXdvcmtN", | |||
| "YXBwaW5nUHJvdG8SNQoRYnVmZmVyX2Fzc2lnbm1lbnQYAyABKAsyGi54bGEu", | |||
| "QnVmZmVyQXNzaWdubWVudFByb3Rv")); | |||
| descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
| new pbr::FileDescriptor[] { global::Xla.Cpu.XlaFrameworkReflection.Descriptor, global::Xla.HloReflection.Descriptor, }, | |||
| new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Xla.Cpu.XlaRuntimeCpuExecutableProto), global::Xla.Cpu.XlaRuntimeCpuExecutableProto.Parser, new[]{ "XlaRuntimeExecutable", "XlaFrameworkMapping", "BufferAssignment" }, null, null, null, null) | |||
| })); | |||
| } | |||
| #endregion | |||
| } | |||
| #region Messages | |||
| public sealed partial class XlaRuntimeCpuExecutableProto : pb::IMessage<XlaRuntimeCpuExecutableProto> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<XlaRuntimeCpuExecutableProto> _parser = new pb::MessageParser<XlaRuntimeCpuExecutableProto>(() => new XlaRuntimeCpuExecutableProto()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<XlaRuntimeCpuExecutableProto> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Xla.Cpu.ExecutableReflection.Descriptor.MessageTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public XlaRuntimeCpuExecutableProto() { | |||
| OnConstruction(); | |||
| } | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public XlaRuntimeCpuExecutableProto(XlaRuntimeCpuExecutableProto other) : this() { | |||
| xlaRuntimeExecutable_ = other.xlaRuntimeExecutable_ != null ? other.xlaRuntimeExecutable_.Clone() : null; | |||
| xlaFrameworkMapping_ = other.xlaFrameworkMapping_ != null ? other.xlaFrameworkMapping_.Clone() : null; | |||
| bufferAssignment_ = other.bufferAssignment_ != null ? other.bufferAssignment_.Clone() : null; | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public XlaRuntimeCpuExecutableProto Clone() { | |||
| return new XlaRuntimeCpuExecutableProto(this); | |||
| } | |||
| /// <summary>Field number for the "xla_runtime_executable" field.</summary> | |||
| public const int XlaRuntimeExecutableFieldNumber = 1; | |||
| private global::Xla.XlaRuntimeExecutableProto xlaRuntimeExecutable_; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Xla.XlaRuntimeExecutableProto XlaRuntimeExecutable { | |||
| get { return xlaRuntimeExecutable_; } | |||
| set { | |||
| xlaRuntimeExecutable_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "xla_framework_mapping" field.</summary> | |||
| public const int XlaFrameworkMappingFieldNumber = 2; | |||
| private global::Xla.Cpu.XlaFrameworkMappingProto xlaFrameworkMapping_; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Xla.Cpu.XlaFrameworkMappingProto XlaFrameworkMapping { | |||
| get { return xlaFrameworkMapping_; } | |||
| set { | |||
| xlaFrameworkMapping_ = value; | |||
| } | |||
| } | |||
| /// <summary>Field number for the "buffer_assignment" field.</summary> | |||
| public const int BufferAssignmentFieldNumber = 3; | |||
| private global::Xla.BufferAssignmentProto bufferAssignment_; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Xla.BufferAssignmentProto BufferAssignment { | |||
| get { return bufferAssignment_; } | |||
| set { | |||
| bufferAssignment_ = value; | |||
| } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as XlaRuntimeCpuExecutableProto); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(XlaRuntimeCpuExecutableProto other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| } | |||
| if (ReferenceEquals(other, this)) { | |||
| return true; | |||
| } | |||
| if (!object.Equals(XlaRuntimeExecutable, other.XlaRuntimeExecutable)) return false; | |||
| if (!object.Equals(XlaFrameworkMapping, other.XlaFrameworkMapping)) return false; | |||
| if (!object.Equals(BufferAssignment, other.BufferAssignment)) return false; | |||
| return Equals(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (xlaRuntimeExecutable_ != null) hash ^= XlaRuntimeExecutable.GetHashCode(); | |||
| if (xlaFrameworkMapping_ != null) hash ^= XlaFrameworkMapping.GetHashCode(); | |||
| if (bufferAssignment_ != null) hash ^= BufferAssignment.GetHashCode(); | |||
| if (_unknownFields != null) { | |||
| hash ^= _unknownFields.GetHashCode(); | |||
| } | |||
| return hash; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (xlaRuntimeExecutable_ != null) { | |||
| output.WriteRawTag(10); | |||
| output.WriteMessage(XlaRuntimeExecutable); | |||
| } | |||
| if (xlaFrameworkMapping_ != null) { | |||
| output.WriteRawTag(18); | |||
| output.WriteMessage(XlaFrameworkMapping); | |||
| } | |||
| if (bufferAssignment_ != null) { | |||
| output.WriteRawTag(26); | |||
| output.WriteMessage(BufferAssignment); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (xlaRuntimeExecutable_ != null) { | |||
| output.WriteRawTag(10); | |||
| output.WriteMessage(XlaRuntimeExecutable); | |||
| } | |||
| if (xlaFrameworkMapping_ != null) { | |||
| output.WriteRawTag(18); | |||
| output.WriteMessage(XlaFrameworkMapping); | |||
| } | |||
| if (bufferAssignment_ != null) { | |||
| output.WriteRawTag(26); | |||
| output.WriteMessage(BufferAssignment); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (xlaRuntimeExecutable_ != null) { | |||
| size += 1 + pb::CodedOutputStream.ComputeMessageSize(XlaRuntimeExecutable); | |||
| } | |||
| if (xlaFrameworkMapping_ != null) { | |||
| size += 1 + pb::CodedOutputStream.ComputeMessageSize(XlaFrameworkMapping); | |||
| } | |||
| if (bufferAssignment_ != null) { | |||
| size += 1 + pb::CodedOutputStream.ComputeMessageSize(BufferAssignment); | |||
| } | |||
| if (_unknownFields != null) { | |||
| size += _unknownFields.CalculateSize(); | |||
| } | |||
| return size; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(XlaRuntimeCpuExecutableProto other) { | |||
| if (other == null) { | |||
| return; | |||
| } | |||
| if (other.xlaRuntimeExecutable_ != null) { | |||
| if (xlaRuntimeExecutable_ == null) { | |||
| XlaRuntimeExecutable = new global::Xla.XlaRuntimeExecutableProto(); | |||
| } | |||
| XlaRuntimeExecutable.MergeFrom(other.XlaRuntimeExecutable); | |||
| } | |||
| if (other.xlaFrameworkMapping_ != null) { | |||
| if (xlaFrameworkMapping_ == null) { | |||
| XlaFrameworkMapping = new global::Xla.Cpu.XlaFrameworkMappingProto(); | |||
| } | |||
| XlaFrameworkMapping.MergeFrom(other.XlaFrameworkMapping); | |||
| } | |||
| if (other.bufferAssignment_ != null) { | |||
| if (bufferAssignment_ == null) { | |||
| BufferAssignment = new global::Xla.BufferAssignmentProto(); | |||
| } | |||
| BufferAssignment.MergeFrom(other.BufferAssignment); | |||
| } | |||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); | |||
| break; | |||
| case 10: { | |||
| if (xlaRuntimeExecutable_ == null) { | |||
| XlaRuntimeExecutable = new global::Xla.XlaRuntimeExecutableProto(); | |||
| } | |||
| input.ReadMessage(XlaRuntimeExecutable); | |||
| break; | |||
| } | |||
| case 18: { | |||
| if (xlaFrameworkMapping_ == null) { | |||
| XlaFrameworkMapping = new global::Xla.Cpu.XlaFrameworkMappingProto(); | |||
| } | |||
| input.ReadMessage(XlaFrameworkMapping); | |||
| break; | |||
| } | |||
| case 26: { | |||
| if (bufferAssignment_ == null) { | |||
| BufferAssignment = new global::Xla.BufferAssignmentProto(); | |||
| } | |||
| input.ReadMessage(BufferAssignment); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 10: { | |||
| if (xlaRuntimeExecutable_ == null) { | |||
| XlaRuntimeExecutable = new global::Xla.XlaRuntimeExecutableProto(); | |||
| } | |||
| input.ReadMessage(XlaRuntimeExecutable); | |||
| break; | |||
| } | |||
| case 18: { | |||
| if (xlaFrameworkMapping_ == null) { | |||
| XlaFrameworkMapping = new global::Xla.Cpu.XlaFrameworkMappingProto(); | |||
| } | |||
| input.ReadMessage(XlaFrameworkMapping); | |||
| break; | |||
| } | |||
| case 26: { | |||
| if (bufferAssignment_ == null) { | |||
| BufferAssignment = new global::Xla.BufferAssignmentProto(); | |||
| } | |||
| input.ReadMessage(BufferAssignment); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #endregion | |||
| } | |||
| #endregion Designer generated code | |||
| @@ -2,7 +2,7 @@ | |||
| // Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| // source: tensorflow/core/framework/full_type.proto | |||
| // </auto-generated> | |||
| #pragma warning disable 1591, 0612, 3021 | |||
| #pragma warning disable 1591, 0612, 3021, 8981 | |||
| #region Designer generated code | |||
| using pb = global::Google.Protobuf; | |||
| @@ -25,26 +25,30 @@ namespace Tensorflow { | |||
| byte[] descriptorData = global::System.Convert.FromBase64String( | |||
| string.Concat( | |||
| "Cil0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2Z1bGxfdHlwZS5wcm90bxIK", | |||
| "dGVuc29yZmxvdyJyCgtGdWxsVHlwZURlZhInCgd0eXBlX2lkGAEgASgOMhYu", | |||
| "dGVuc29yZmxvdyJ/CgtGdWxsVHlwZURlZhInCgd0eXBlX2lkGAEgASgOMhYu", | |||
| "dGVuc29yZmxvdy5GdWxsVHlwZUlkEiUKBGFyZ3MYAiADKAsyFy50ZW5zb3Jm", | |||
| "bG93LkZ1bGxUeXBlRGVmEgsKAXMYAyABKAlIAEIGCgRhdHRyKqwDCgpGdWxs", | |||
| "VHlwZUlkEg0KCVRGVF9VTlNFVBAAEgsKB1RGVF9WQVIQARILCgdURlRfQU5Z", | |||
| "EAISDwoLVEZUX1BST0RVQ1QQAxIQCgxURlRfQ0FMTEFCTEUQZBIPCgpURlRf", | |||
| "VEVOU09SEOgHEg4KCVRGVF9BUlJBWRDpBxIRCgxURlRfT1BUSU9OQUwQ6gcS", | |||
| "EAoLVEZUX0RBVEFTRVQQ9k4SDQoIVEZUX0JPT0wQyAESDgoJVEZUX1VJTlQ4", | |||
| "EMkBEg8KClRGVF9VSU5UMTYQygESDwoKVEZUX1VJTlQzMhDLARIPCgpURlRf", | |||
| "VUlOVDY0EMwBEg0KCFRGVF9JTlQ4EM0BEg4KCVRGVF9JTlQxNhDOARIOCglU", | |||
| "RlRfSU5UMzIQzwESDgoJVEZUX0lOVDY0ENABEg0KCFRGVF9IQUxGENEBEg4K", | |||
| "CVRGVF9GTE9BVBDSARIPCgpURlRfRE9VQkxFENMBEhEKDFRGVF9CRkxPQVQx", | |||
| "NhDXARISCg1URlRfQ09NUExFWDY0ENQBEhMKDlRGVF9DT01QTEVYMTI4ENUB", | |||
| "Eg8KClRGVF9TVFJJTkcQ1gFCfQoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3Jr", | |||
| "Qg5GdWxsVHlwZVByb3Rvc1ABWkxnaXRodWIuY29tL3RlbnNvcmZsb3cvdGVu", | |||
| "c29yZmxvdy90ZW5zb3JmbG93L2dvL2NvcmUvZnJhbWV3b3JrL3R5cGVzX2dv", | |||
| "X3Byb3Rv+AEBYgZwcm90bzM=")); | |||
| "bG93LkZ1bGxUeXBlRGVmEgsKAXMYAyABKAlIABILCgFpGAQgASgDSABCBgoE", | |||
| "YXR0cirDBAoKRnVsbFR5cGVJZBINCglURlRfVU5TRVQQABILCgdURlRfVkFS", | |||
| "EAESCwoHVEZUX0FOWRACEg8KC1RGVF9QUk9EVUNUEAMSDQoJVEZUX05BTUVE", | |||
| "EAQSEAoMVEZUX0ZPUl9FQUNIEBQSEAoMVEZUX0NBTExBQkxFEGQSDwoKVEZU", | |||
| "X1RFTlNPUhDoBxIOCglURlRfQVJSQVkQ6QcSEQoMVEZUX09QVElPTkFMEOoH", | |||
| "EhAKC1RGVF9MSVRFUkFMEOsHEhAKC1RGVF9FTkNPREVEEOwHEg0KCFRGVF9C", | |||
| "T09MEMgBEg4KCVRGVF9VSU5UOBDJARIPCgpURlRfVUlOVDE2EMoBEg8KClRG", | |||
| "VF9VSU5UMzIQywESDwoKVEZUX1VJTlQ2NBDMARINCghURlRfSU5UOBDNARIO", | |||
| "CglURlRfSU5UMTYQzgESDgoJVEZUX0lOVDMyEM8BEg4KCVRGVF9JTlQ2NBDQ", | |||
| "ARINCghURlRfSEFMRhDRARIOCglURlRfRkxPQVQQ0gESDwoKVEZUX0RPVUJM", | |||
| "RRDTARIRCgxURlRfQkZMT0FUMTYQ1wESEgoNVEZUX0NPTVBMRVg2NBDUARIT", | |||
| "Cg5URlRfQ09NUExFWDEyOBDVARIPCgpURlRfU1RSSU5HENYBEhAKC1RGVF9E", | |||
| "QVRBU0VUEPZOEg8KClRGVF9SQUdHRUQQ904SEQoMVEZUX0lURVJBVE9SEPhO", | |||
| "EhMKDlRGVF9NVVRFWF9MT0NLENpPEhcKElRGVF9MRUdBQ1lfVkFSSUFOVBDb", | |||
| "T0KBAQoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3JrQg5GdWxsVHlwZVByb3Rv", | |||
| "c1ABWlBnaXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5zb3Jm", | |||
| "bG93L2dvL2NvcmUvZnJhbWV3b3JrL2Z1bGxfdHlwZV9nb19wcm90b/gBAWIG", | |||
| "cHJvdG8z")); | |||
| descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, | |||
| new pbr::FileDescriptor[] { }, | |||
| new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Tensorflow.FullTypeId), }, null, new pbr::GeneratedClrTypeInfo[] { | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FullTypeDef), global::Tensorflow.FullTypeDef.Parser, new[]{ "TypeId", "Args", "S" }, new[]{ "Attr" }, null, null, null) | |||
| new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FullTypeDef), global::Tensorflow.FullTypeDef.Parser, new[]{ "TypeId", "Args", "S", "I" }, new[]{ "Attr" }, null, null, null) | |||
| })); | |||
| } | |||
| #endregion | |||
| @@ -52,6 +56,7 @@ namespace Tensorflow { | |||
| } | |||
| #region Enums | |||
| /// <summary> | |||
| /// LINT.IfChange | |||
| /// Experimental. Represents the complete type information of a TensorFlow value. | |||
| /// </summary> | |||
| public enum FullTypeId { | |||
| @@ -69,7 +74,7 @@ namespace Tensorflow { | |||
| /// TFT_TENSOR[TFT_VAR["T"]], TFT_TENSOR[TFT_VAR["T"]] are two tensors of | |||
| /// identical element types. | |||
| /// TFT_TENSOR[TFT_VAR["P"]], TFT_TENSOR[TFT_VAR["Q"]] are two tensors of | |||
| /// potentially different element types. | |||
| /// independent element types. | |||
| /// </summary> | |||
| [pbr::OriginalName("TFT_VAR")] TftVar = 1, | |||
| /// <summary> | |||
| @@ -90,14 +95,55 @@ namespace Tensorflow { | |||
| /// </summary> | |||
| [pbr::OriginalName("TFT_PRODUCT")] TftProduct = 3, | |||
| /// <summary> | |||
| /// Represents a named field, with the name stored in the attribute. | |||
| /// | |||
| /// Parametrization: | |||
| /// TFT_NAMED[<type>]{<name>} | |||
| /// * <type> is the type of the field | |||
| /// * <name> is the field name, as string (thpugh can theoretically be an int | |||
| /// as well) | |||
| /// | |||
| /// Example: | |||
| /// TFT_RECORD[ | |||
| /// TFT_NAMED[TFT_TENSOR[TFT_INT32]]{'foo'}, | |||
| /// TFT_NAMED[TFT_TENSOR[TFT_FLOAT32]]{'bar'}, | |||
| /// ] | |||
| /// is a structure with two fields, an int tensor "foo" and a float tensor | |||
| /// "bar". | |||
| /// </summary> | |||
| [pbr::OriginalName("TFT_NAMED")] TftNamed = 4, | |||
| /// <summary> | |||
| /// Template definition. Expands the variables by repeating a template as | |||
| /// arguments of container. | |||
| /// | |||
| /// Parametrization: | |||
| /// TFT_FOR_EACH[<container_type>, <template>, <expansions>] | |||
| /// * <container_type> is the type of the container that the template will be | |||
| /// expanded into | |||
| /// * <template> is any type definition that potentially contains type | |||
| /// variables | |||
| /// * <expansions> is a TFT_VAR and may include more types in the future | |||
| /// | |||
| /// Example: | |||
| /// TFT_FOR_EACH[ | |||
| /// TFT_PRODUCT, | |||
| /// TFT_TENSOR[TFT_VAR["t"]], | |||
| /// TFT_VAR["t"] | |||
| /// ] | |||
| /// will substitute a T = TFT_INT32 to TFT_PRODUCT[TFT_TENSOR[TFT_INT32]] | |||
| /// and a T = (TFT_INT32, TFT_INT64) to | |||
| /// TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_INT64]]. | |||
| /// </summary> | |||
| [pbr::OriginalName("TFT_FOR_EACH")] TftForEach = 20, | |||
| /// <summary> | |||
| /// Callable types describe functions and ops. | |||
| /// | |||
| /// Parametrization: | |||
| /// TFT_CALLABLE[<arg type>, <return type>] | |||
| /// * <arg_type> is the type of the arguments; TFT_PRODUCT represents | |||
| /// * <arg type> is the type of the arguments; TFT_PRODUCT represents | |||
| /// multiple | |||
| /// arguments. | |||
| /// * <return_type> is the return type; TFT_PRODUCT represents multiple | |||
| /// * <return type> is the return type; TFT_PRODUCT represents multiple | |||
| /// return values (that means that callables returning multiple things | |||
| /// don't necessarily return a single tuple). | |||
| /// | |||
| @@ -115,9 +161,9 @@ namespace Tensorflow { | |||
| /// | |||
| /// Parametrization: | |||
| /// TFT_TENSOR[<element type>, <shape type>] | |||
| /// * <element_type> is currently limited to one of the element types | |||
| /// * <element type> is currently limited to one of the element types | |||
| /// defined below. | |||
| /// * <shape_type> is not yet defined, and may only be TFT_UNKNOWN for now. | |||
| /// * <shape type> is not yet defined, and may only be TFT_UNKNOWN for now. | |||
| /// | |||
| /// A TFT_SHAPE type will be defined in the future. | |||
| /// | |||
| @@ -140,7 +186,7 @@ namespace Tensorflow { | |||
| /// | |||
| /// Parametrization: | |||
| /// TFT_ARRAY[<element type>] | |||
| /// * <element_type> may be any concrete type. | |||
| /// * <element type> may be any concrete type. | |||
| /// | |||
| /// Examples: | |||
| /// TFT_ARRAY[TFT_TENSOR[TFT_INT32]] is a TensorArray holding int32 Tensors | |||
| @@ -159,7 +205,7 @@ namespace Tensorflow { | |||
| /// | |||
| /// Parametrization: | |||
| /// TFT_OPTIONAL[<element type>] | |||
| /// * <element_type> may be any concrete type. | |||
| /// * <element type> may be any concrete type. | |||
| /// | |||
| /// Examples: | |||
| /// TFT_OPTIONAL[TFT_TENSOR[TFT_INT32]] is an Optional holding an int32 | |||
| @@ -167,28 +213,31 @@ namespace Tensorflow { | |||
| /// </summary> | |||
| [pbr::OriginalName("TFT_OPTIONAL")] TftOptional = 1002, | |||
| /// <summary> | |||
| /// Datasets created by tf.data ops and APIs. Datasets have generator/iterable | |||
| /// semantics, that is, one can construct an iterator from them. Like | |||
| /// Array, they are considered to return elements that can be described | |||
| /// by a single type. Unlike Array, they do not support random access or | |||
| /// mutation, and can potentially produce an infinite number of elements. | |||
| /// A datasets can produce logical structures (e.g. multiple elements). This | |||
| /// is expressed using TFT_PRODUCT. | |||
| /// Literal types describe compile-time constant values. | |||
| /// Literal types may also participate in dependent types. | |||
| /// | |||
| /// Parametrization: TFT_ARRAY[<element type>]. | |||
| /// <element_type> may be a concrete type or a type symbol. It represents the | |||
| /// data type of the elements produced by the dataset. | |||
| /// Parametrization: | |||
| /// TFT_LITERAL[<value type>]{<value>} | |||
| /// * <value type> may be any concrete type compatible that can hold <value> | |||
| /// * <value> is the type's attribute, and holds the actual literal value | |||
| /// | |||
| /// Examples: | |||
| /// TFT_DATSET[TFT_TENSOR[TFT_INT32]] is a Dataset producing single int32 | |||
| /// Tensors of unknown shape. | |||
| /// TFT_DATSET[TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_FLOAT32]] is | |||
| /// a | |||
| /// Dataset producing pairs of Tensors, one integer and one float. | |||
| /// Note: The high ID number is to prepare for the eventuality that Datasets | |||
| /// will be supported by user types in the future. | |||
| /// TFT_LITERAL[TFT_INT32]{1} is the compile-time constant 1. | |||
| /// </summary> | |||
| [pbr::OriginalName("TFT_DATASET")] TftDataset = 10102, | |||
| [pbr::OriginalName("TFT_LITERAL")] TftLiteral = 1003, | |||
| /// <summary> | |||
| /// Encoding types describe a value of a certain type, encoded as a different | |||
| /// type. | |||
| /// | |||
| /// Parametrization: | |||
| /// TFT_ENCODED[<encoded type>, <encoding type>] | |||
| /// * <encoded type> may be any type | |||
| /// * <encoding type> may be any type | |||
| /// | |||
| /// Examples: | |||
| /// TFT_ENCODING[TFT_INT32, TFT_STRING] is an integer encoded as string. | |||
| /// </summary> | |||
| [pbr::OriginalName("TFT_ENCODED")] TftEncoded = 1004, | |||
| /// <summary> | |||
| /// The bool element type. | |||
| /// TODO(mdan): Quantized types, legacy representations (e.g. ref) | |||
| @@ -222,6 +271,62 @@ namespace Tensorflow { | |||
| /// The string element type. | |||
| /// </summary> | |||
| [pbr::OriginalName("TFT_STRING")] TftString = 214, | |||
| /// <summary> | |||
| /// Datasets created by tf.data ops and APIs. Datasets have generator/iterable | |||
| /// semantics, that is, one can construct an iterator from them. Like | |||
| /// Array, they are considered to return elements that can be described | |||
| /// by a single type. Unlike Array, they do not support random access or | |||
| /// mutation, and can potentially produce an infinite number of elements. | |||
| /// A datasets can produce logical structures (e.g. multiple elements). This | |||
| /// is expressed using TFT_PRODUCT. | |||
| /// | |||
| /// Parametrization: TFT_DATASET[<element type>]. | |||
| /// * <element type> may be a concrete type or a type symbol. It represents | |||
| /// the data type of the elements produced by the dataset. | |||
| /// | |||
| /// Examples: | |||
| /// TFT_DATSET[TFT_TENSOR[TFT_INT32]] is a Dataset producing single int32 | |||
| /// Tensors of unknown shape. | |||
| /// TFT_DATSET[TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_FLOAT32]] is | |||
| /// a Dataset producing pairs of Tensors, one integer and one float. | |||
| /// Note: The high ID number is to prepare for the eventuality that Datasets | |||
| /// will be supported by user types in the future. | |||
| /// </summary> | |||
| [pbr::OriginalName("TFT_DATASET")] TftDataset = 10102, | |||
| /// <summary> | |||
| /// A ragged tensor created by tf.ragged ops and APIs. | |||
| /// | |||
| /// Parametrization: TFT_RAGGED[<element_type>]. | |||
| /// </summary> | |||
| [pbr::OriginalName("TFT_RAGGED")] TftRagged = 10103, | |||
| /// <summary> | |||
| /// Iterators created by tf.data ops and APIs. Very similar to Datasets, except | |||
| /// they are mutable. | |||
| /// | |||
| /// Parametrization: TFT_ITERATOR[<element type>]. | |||
| /// * <element type> may be a concrete type or a type symbol. It represents | |||
| /// the data type of the elements produced by the dataset. | |||
| /// </summary> | |||
| [pbr::OriginalName("TFT_ITERATOR")] TftIterator = 10104, | |||
| /// <summary> | |||
| /// A mutex lock tensor, produced by tf.raw_ops.MutexLock. | |||
| /// Unlike strict execution models, where ownership of a lock is denoted by | |||
| /// "running after the lock has been acquired", in non-strict mode, lock | |||
| /// ownership is in the true sense: "the op argument representing the lock is | |||
| /// available". | |||
| /// Mutex locks are the dynamic counterpart of control dependencies. | |||
| /// TODO(mdan): Properly document this thing. | |||
| /// | |||
| /// Parametrization: TFT_MUTEX_LOCK[]. | |||
| /// </summary> | |||
| [pbr::OriginalName("TFT_MUTEX_LOCK")] TftMutexLock = 10202, | |||
| /// <summary> | |||
| /// The equivalent of a Tensor with DT_VARIANT dtype, kept here to simplify | |||
| /// translation. This type should not normally appear after type inference. | |||
| /// Note that LEGACY_VARIANT != ANY: TENSOR[INT32] is a subtype of ANY, but is | |||
| /// not a subtype of LEGACY_VARIANT. | |||
| /// </summary> | |||
| [pbr::OriginalName("TFT_LEGACY_VARIANT")] TftLegacyVariant = 10203, | |||
| } | |||
| #endregion | |||
| @@ -233,23 +338,31 @@ namespace Tensorflow { | |||
| /// particular the encoding imposes no restrictions on what the parameters of any | |||
| /// type should be, which in particular needs to be true for type symbols. | |||
| /// </summary> | |||
| public sealed partial class FullTypeDef : pb::IMessage<FullTypeDef> { | |||
| public sealed partial class FullTypeDef : pb::IMessage<FullTypeDef> | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| , pb::IBufferMessage | |||
| #endif | |||
| { | |||
| private static readonly pb::MessageParser<FullTypeDef> _parser = new pb::MessageParser<FullTypeDef>(() => new FullTypeDef()); | |||
| private pb::UnknownFieldSet _unknownFields; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pb::MessageParser<FullTypeDef> Parser { get { return _parser; } } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public static pbr::MessageDescriptor Descriptor { | |||
| get { return global::Tensorflow.FullTypeReflection.Descriptor.MessageTypes[0]; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| pbr::MessageDescriptor pb::IMessage.Descriptor { | |||
| get { return Descriptor; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public FullTypeDef() { | |||
| OnConstruction(); | |||
| } | |||
| @@ -257,6 +370,7 @@ namespace Tensorflow { | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public FullTypeDef(FullTypeDef other) : this() { | |||
| typeId_ = other.typeId_; | |||
| args_ = other.args_.Clone(); | |||
| @@ -264,12 +378,16 @@ namespace Tensorflow { | |||
| case AttrOneofCase.S: | |||
| S = other.S; | |||
| break; | |||
| case AttrOneofCase.I: | |||
| I = other.I; | |||
| break; | |||
| } | |||
| _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public FullTypeDef Clone() { | |||
| return new FullTypeDef(this); | |||
| } | |||
| @@ -283,6 +401,7 @@ namespace Tensorflow { | |||
| /// symbol (Any, Union). See FullTypeId for details. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public global::Tensorflow.FullTypeId TypeId { | |||
| get { return typeId_; } | |||
| set { | |||
| @@ -296,6 +415,7 @@ namespace Tensorflow { | |||
| = pb::FieldCodec.ForMessage(18, global::Tensorflow.FullTypeDef.Parser); | |||
| private readonly pbc::RepeatedField<global::Tensorflow.FullTypeDef> args_ = new pbc::RepeatedField<global::Tensorflow.FullTypeDef>(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public pbc::RepeatedField<global::Tensorflow.FullTypeDef> Args { | |||
| get { return args_; } | |||
| } | |||
| @@ -303,6 +423,7 @@ namespace Tensorflow { | |||
| /// <summary>Field number for the "s" field.</summary> | |||
| public const int SFieldNumber = 3; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public string S { | |||
| get { return attrCase_ == AttrOneofCase.S ? (string) attr_ : ""; } | |||
| set { | |||
| @@ -311,30 +432,50 @@ namespace Tensorflow { | |||
| } | |||
| } | |||
| /// <summary>Field number for the "i" field.</summary> | |||
| public const int IFieldNumber = 4; | |||
| /// <summary> | |||
| /// TODO(mdan): list/tensor, map? Need to reconcile with TFT_RECORD, etc. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public long I { | |||
| get { return attrCase_ == AttrOneofCase.I ? (long) attr_ : 0L; } | |||
| set { | |||
| attr_ = value; | |||
| attrCase_ = AttrOneofCase.I; | |||
| } | |||
| } | |||
| private object attr_; | |||
| /// <summary>Enum of possible cases for the "attr" oneof.</summary> | |||
| public enum AttrOneofCase { | |||
| None = 0, | |||
| S = 3, | |||
| I = 4, | |||
| } | |||
| private AttrOneofCase attrCase_ = AttrOneofCase.None; | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public AttrOneofCase AttrCase { | |||
| get { return attrCase_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void ClearAttr() { | |||
| attrCase_ = AttrOneofCase.None; | |||
| attr_ = null; | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override bool Equals(object other) { | |||
| return Equals(other as FullTypeDef); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public bool Equals(FullTypeDef other) { | |||
| if (ReferenceEquals(other, null)) { | |||
| return false; | |||
| @@ -345,16 +486,19 @@ namespace Tensorflow { | |||
| if (TypeId != other.TypeId) return false; | |||
| if(!args_.Equals(other.args_)) return false; | |||
| if (S != other.S) return false; | |||
| if (I != other.I) return false; | |||
| if (AttrCase != other.AttrCase) return false; | |||
| return Equals(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override int GetHashCode() { | |||
| int hash = 1; | |||
| if (TypeId != global::Tensorflow.FullTypeId.TftUnset) hash ^= TypeId.GetHashCode(); | |||
| hash ^= args_.GetHashCode(); | |||
| if (attrCase_ == AttrOneofCase.S) hash ^= S.GetHashCode(); | |||
| if (attrCase_ == AttrOneofCase.I) hash ^= I.GetHashCode(); | |||
| hash ^= (int) attrCase_; | |||
| if (_unknownFields != null) { | |||
| hash ^= _unknownFields.GetHashCode(); | |||
| @@ -363,12 +507,17 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public override string ToString() { | |||
| return pb::JsonFormatter.ToDiagnosticString(this); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void WriteTo(pb::CodedOutputStream output) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| output.WriteRawMessage(this); | |||
| #else | |||
| if (TypeId != global::Tensorflow.FullTypeId.TftUnset) { | |||
| output.WriteRawTag(8); | |||
| output.WriteEnum((int) TypeId); | |||
| @@ -378,12 +527,41 @@ namespace Tensorflow { | |||
| output.WriteRawTag(26); | |||
| output.WriteString(S); | |||
| } | |||
| if (attrCase_ == AttrOneofCase.I) { | |||
| output.WriteRawTag(32); | |||
| output.WriteInt64(I); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(output); | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalWriteTo(ref pb::WriteContext output) { | |||
| if (TypeId != global::Tensorflow.FullTypeId.TftUnset) { | |||
| output.WriteRawTag(8); | |||
| output.WriteEnum((int) TypeId); | |||
| } | |||
| args_.WriteTo(ref output, _repeated_args_codec); | |||
| if (attrCase_ == AttrOneofCase.S) { | |||
| output.WriteRawTag(26); | |||
| output.WriteString(S); | |||
| } | |||
| if (attrCase_ == AttrOneofCase.I) { | |||
| output.WriteRawTag(32); | |||
| output.WriteInt64(I); | |||
| } | |||
| if (_unknownFields != null) { | |||
| _unknownFields.WriteTo(ref output); | |||
| } | |||
| } | |||
| #endif | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public int CalculateSize() { | |||
| int size = 0; | |||
| if (TypeId != global::Tensorflow.FullTypeId.TftUnset) { | |||
| @@ -393,6 +571,9 @@ namespace Tensorflow { | |||
| if (attrCase_ == AttrOneofCase.S) { | |||
| size += 1 + pb::CodedOutputStream.ComputeStringSize(S); | |||
| } | |||
| if (attrCase_ == AttrOneofCase.I) { | |||
| size += 1 + pb::CodedOutputStream.ComputeInt64Size(I); | |||
| } | |||
| if (_unknownFields != null) { | |||
| size += _unknownFields.CalculateSize(); | |||
| } | |||
| @@ -400,6 +581,7 @@ namespace Tensorflow { | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(FullTypeDef other) { | |||
| if (other == null) { | |||
| return; | |||
| @@ -412,13 +594,20 @@ namespace Tensorflow { | |||
| case AttrOneofCase.S: | |||
| S = other.S; | |||
| break; | |||
| case AttrOneofCase.I: | |||
| I = other.I; | |||
| break; | |||
| } | |||
| _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| public void MergeFrom(pb::CodedInputStream input) { | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| input.ReadRawMessage(this); | |||
| #else | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| @@ -437,9 +626,45 @@ namespace Tensorflow { | |||
| S = input.ReadString(); | |||
| break; | |||
| } | |||
| case 32: { | |||
| I = input.ReadInt64(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #if !GOOGLE_PROTOBUF_REFSTRUCT_COMPATIBILITY_MODE | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| [global::System.CodeDom.Compiler.GeneratedCode("protoc", null)] | |||
| void pb::IBufferMessage.InternalMergeFrom(ref pb::ParseContext input) { | |||
| uint tag; | |||
| while ((tag = input.ReadTag()) != 0) { | |||
| switch(tag) { | |||
| default: | |||
| _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, ref input); | |||
| break; | |||
| case 8: { | |||
| TypeId = (global::Tensorflow.FullTypeId) input.ReadEnum(); | |||
| break; | |||
| } | |||
| case 18: { | |||
| args_.AddEntriesFrom(ref input, _repeated_args_codec); | |||
| break; | |||
| } | |||
| case 26: { | |||
| S = input.ReadString(); | |||
| break; | |||
| } | |||
| case 32: { | |||
| I = input.ReadInt64(); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||