diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000..fdf00590 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,12 @@ +# These are supported funding model platforms + +github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +patreon: # Replace with a single Patreon username +open_collective: # Replace with a single Open Collective username +ko_fi: # Replace with a single Ko-fi username +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +otechie: # Replace with a single Otechie username +custom: ['https://bit.ly/2op1mu5']# Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/README.md b/README.md index 04df88ea..a80191a7 100644 --- a/README.md +++ b/README.md @@ -107,34 +107,9 @@ Run this example in [Jupyter Notebook](https://github.com/SciSharp/SciSharpCube) Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflownet.readthedocs.io/en/latest/FrontCover.html). -### More examples: +Many more examples are available at [TensorFlow.NET Examples](https://github.com/SciSharp/TensorFlow.NET-Examples). -Run specific example in shell: - -```cs -dotnet TensorFlowNET.Examples.dll -ex "MNIST CNN" -``` - -Example runner will download all the required files like training data and model pb files. - -* [Hello World](test/TensorFlowNET.Examples/HelloWorld.cs) -* [Basic Operations](test/TensorFlowNET.Examples/BasicOperations.cs) -* [Linear Regression](test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs) -* [Logistic Regression](test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs) -* [Nearest Neighbor](test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs) -* [Naive Bayes Classification](test/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs) -* [Full Connected Neural Network](test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs) -* [Image Processing](test/TensorFlowNET.Examples/ImageProcessing) -* [K-means Clustering](test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs) -* [NN XOR](test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs) -* [Object Detection](test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs) -* [Text Classification](test/TensorFlowNET.Examples/TextProcessing/BinaryTextClassification.cs) -* [CNN Text Classification](test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs) -* [MNIST CNN](test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs) -* [Named Entity Recognition](test/TensorFlowNET.Examples/TextProcessing/NER) -* [Transfer Learning for Image Classification in InceptionV3](test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs) - -More troubleshooting of running example refer [here](tensorflowlib/README.md). +For troubleshooting when running examples or installing, please refer [here](tensorflowlib/README.md). 
### Contribute: diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 8d230d26..96a8af5c 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -5,23 +5,15 @@ VisualStudioVersion = 16.0.29102.190 MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.UnitTest", "test\TensorFlowNET.UnitTest\TensorFlowNET.UnitTest.csproj", "{029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj", "{1FE60088-157C-4140-91AB-E96B915E4BAE}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\TensorFlowNET.Core\TensorFlowNET.Core.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Keras.Core", "src\KerasNET.Core\Keras.Core.csproj", "{902E188F-A953-43B4-9991-72BAB1697BC3}" -EndProject -Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "TensorFlowNET.Examples.FSharp", "test\TensorFlowNET.Examples.FSharp\TensorFlowNET.Examples.FSharp.fsproj", "{62BC3801-F0D3-44A9-A0AC-712F40C8F961}" -EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowBenchmark", "src\TensorFlowNet.Benchmarks\TensorFlowBenchmark.csproj", "{68861442-971A-4196-876E-C9330F0B3C54}" -EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowHub", "src\TensorFlowHub\TensorFlowHub.csproj", "{8FD59A5A-97EB-457E-B9F1-D88B0C822C6E}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Models", "src\TensorFlowNET.Models\TensorFlowNET.Models.csproj", "{D03F94CF-B283-4730-B177-21A57641061F}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowText", "src\TensorFlowText\TensorFlowText.csproj", "{B598E5D5-BD2D-4191-8532-F2FBAC31AB81}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Text", "src\TensorFlowNET.Text\TensorFlowNET.Text.csproj", "{904472F8-40E1-4650-AA6F-C7F209B3691B}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowDatasets", "src\TensorFlowDatasets\TensorFlowDatasets.csproj", "{DF151A51-E9FD-41BD-B0F4-08A743755D44}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Hub", "src\TensorFlowNET.Hub\TensorFlowNET.Hub.csproj", "{4EAFAE19-C832-47C6-B01E-0F4268C9072C}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples.GPU", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.GPU.csproj", "{6F6B3382-8F87-4CD9-BF87-C81D5405685A}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Datasets", "src\TensorFlowNET.Datasets\TensorFlowNET.Datasets.csproj", "{494D6CAD-2C0D-4C0B-90E2-B097DB039383}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -33,42 +25,26 @@ Global {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Debug|Any CPU.Build.0 = Debug|Any CPU {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Release|Any CPU.ActiveCfg = Release|Any CPU {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Release|Any CPU.Build.0 = Release|Any CPU - {1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {1FE60088-157C-4140-91AB-E96B915E4BAE}.Debug|Any CPU.Build.0 = Debug|Any CPU - {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.ActiveCfg = Release|Any CPU - {1FE60088-157C-4140-91AB-E96B915E4BAE}.Release|Any CPU.Build.0 = Release|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU - {902E188F-A953-43B4-9991-72BAB1697BC3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {902E188F-A953-43B4-9991-72BAB1697BC3}.Debug|Any CPU.Build.0 = Debug|Any CPU - {902E188F-A953-43B4-9991-72BAB1697BC3}.Release|Any CPU.ActiveCfg = Release|Any CPU - {902E188F-A953-43B4-9991-72BAB1697BC3}.Release|Any CPU.Build.0 = Release|Any CPU - {62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Debug|Any CPU.Build.0 = Debug|Any CPU - {62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.ActiveCfg = Release|Any CPU - {62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.Build.0 = Release|Any CPU - {68861442-971A-4196-876E-C9330F0B3C54}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {68861442-971A-4196-876E-C9330F0B3C54}.Debug|Any CPU.Build.0 = Debug|Any CPU - {68861442-971A-4196-876E-C9330F0B3C54}.Release|Any CPU.ActiveCfg = Release|Any CPU - {68861442-971A-4196-876E-C9330F0B3C54}.Release|Any CPU.Build.0 = Release|Any CPU - {8FD59A5A-97EB-457E-B9F1-D88B0C822C6E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {8FD59A5A-97EB-457E-B9F1-D88B0C822C6E}.Debug|Any CPU.Build.0 = Debug|Any CPU - {8FD59A5A-97EB-457E-B9F1-D88B0C822C6E}.Release|Any CPU.ActiveCfg = Release|Any CPU - {8FD59A5A-97EB-457E-B9F1-D88B0C822C6E}.Release|Any CPU.Build.0 = Release|Any CPU - {B598E5D5-BD2D-4191-8532-F2FBAC31AB81}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {B598E5D5-BD2D-4191-8532-F2FBAC31AB81}.Debug|Any CPU.Build.0 = Debug|Any CPU - {B598E5D5-BD2D-4191-8532-F2FBAC31AB81}.Release|Any CPU.ActiveCfg = Release|Any CPU - {B598E5D5-BD2D-4191-8532-F2FBAC31AB81}.Release|Any CPU.Build.0 = Release|Any CPU - {DF151A51-E9FD-41BD-B0F4-08A743755D44}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {DF151A51-E9FD-41BD-B0F4-08A743755D44}.Debug|Any CPU.Build.0 = Debug|Any CPU - {DF151A51-E9FD-41BD-B0F4-08A743755D44}.Release|Any CPU.ActiveCfg = Release|Any CPU - {DF151A51-E9FD-41BD-B0F4-08A743755D44}.Release|Any CPU.Build.0 = Release|Any CPU - {6F6B3382-8F87-4CD9-BF87-C81D5405685A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {6F6B3382-8F87-4CD9-BF87-C81D5405685A}.Debug|Any CPU.Build.0 = Debug|Any CPU - {6F6B3382-8F87-4CD9-BF87-C81D5405685A}.Release|Any CPU.ActiveCfg = Release|Any CPU - {6F6B3382-8F87-4CD9-BF87-C81D5405685A}.Release|Any CPU.Build.0 = Release|Any CPU + {D03F94CF-B283-4730-B177-21A57641061F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D03F94CF-B283-4730-B177-21A57641061F}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D03F94CF-B283-4730-B177-21A57641061F}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D03F94CF-B283-4730-B177-21A57641061F}.Release|Any CPU.Build.0 = Release|Any CPU + {904472F8-40E1-4650-AA6F-C7F209B3691B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {904472F8-40E1-4650-AA6F-C7F209B3691B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {904472F8-40E1-4650-AA6F-C7F209B3691B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {904472F8-40E1-4650-AA6F-C7F209B3691B}.Release|Any CPU.Build.0 = Release|Any CPU + {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Release|Any CPU.Build.0 = Release|Any CPU + 
{494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Debug|Any CPU.Build.0 = Debug|Any CPU + {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Release|Any CPU.ActiveCfg = Release|Any CPU + {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/TensorFlow.NET.sln.DotSettings b/TensorFlow.NET.sln.DotSettings new file mode 100644 index 00000000..aba8725c --- /dev/null +++ b/TensorFlow.NET.sln.DotSettings @@ -0,0 +1,2 @@ + + True \ No newline at end of file diff --git a/docs/source/Queue.md b/docs/source/Queue.md new file mode 100644 index 00000000..7f137fb3 --- /dev/null +++ b/docs/source/Queue.md @@ -0,0 +1,157 @@ +# Chapter. Queue + +TensorFlow is capable of handling multiple threads, and queues are a powerful mechanism for asynchronous computation. If we have large datasets, this can significantly speed up the training process of our models. This functionality is especially handy when reading, pre-processing and extracting our training data in mini-batches. The secret to professional, high-performance training of our model is understanding TensorFlow's queuing operations. TensorFlow implements four types of queue: **FIFOQueue**, **PaddingFIFOQueue**, **PriorityQueue** and **RandomShuffleQueue**. + +![FIFOQueue](_static/FIFOQueue-example.jpg) + +Like everything in TensorFlow, a queue is a node in a computation graph. It's a stateful node, like a variable: other nodes can modify its content. In particular, nodes can enqueue new items into the queue, or dequeue existing items from the queue. + +To get started with queues, let's consider a simple example. We will create a "first in, first out" queue (FIFOQueue) and fill it with numbers. Then we'll construct a graph that takes an item off the queue, adds one to that item, and puts it back on the end of the queue. + +```csharp +[TestMethod] +public void FIFOQueue() +{ + // create a first in first out queue with capacity up to 2 + // and data type set as int32 + var queue = tf.FIFOQueue(2, tf.int32); + // init queue, push 2 elements into queue. + var init = queue.enqueue_many(new[] { 10, 20 }); + // pop out the first element + var x = queue.dequeue(); + // add 1 + var y = x + 1; + // push back into queue + var inc = queue.enqueue(y); + + using (var sess = tf.Session()) + { + // init queue + init.run(); + + // pop out first element and push back calculated y + (int dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(10, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(20, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(11, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(21, dequeued); + + // thread will hang or block if you run sess.run(x) again + // until queue has more elements. + } +} +``` + +`Enqueue`, `EnqueueMany` and `Dequeue` are special nodes. They take a pointer to the queue instead of a normal value, allowing them to change it. We first create a FIFOQueue *queue* with a capacity of 2 and enqueue two values into the *queue*. Then we immediately attempt to *dequeue* a value from it and assign it to *y*, where we simply add 1 to the dequeued value. Next, we start up a *session* and run.
After we've run this operation a few times, the queue will be empty. If we try to run the operation again, the main thread of the program will hang or block, because it will be waiting for another operation to run and put more values in the queue. + +#### FIFOQueue + +Creates a queue that dequeues elements in a first-in first-out order. A `FIFOQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `FIFOQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument. + +#### PaddingFIFOQueue + +A FIFOQueue that supports batching variable-sized tensors by padding. A `PaddingFIFOQueue` may contain components with dynamic shape, while also supporting `dequeue_many`. A `PaddingFIFOQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are described by the `shapes` argument. + +```csharp +[TestMethod] +public void PaddingFIFOQueue() +{ + var numbers = tf.placeholder(tf.int32); + var queue = tf.PaddingFIFOQueue(10, tf.int32, new TensorShape(-1)); + var enqueue = queue.enqueue(numbers); + var dequeue_many = queue.dequeue_many(n: 3); + + using(var sess = tf.Session()) + { + sess.run(enqueue, (numbers, new[] { 1 })); + sess.run(enqueue, (numbers, new[] { 2, 3 })); + sess.run(enqueue, (numbers, new[] { 3, 4, 5 })); + + var result = sess.run(dequeue_many[0]); + + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray())); + } +} +``` + + + +#### PriorityQueue + +A queue implementation that dequeues elements in prioritized order. A `PriorityQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `PriorityQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `types`, and whose shapes are optionally described by the `shapes` argument. + +```csharp +[TestMethod] +public void PriorityQueue() +{ + var queue = tf.PriorityQueue(3, tf.@string); + var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" }); + var x = queue.dequeue(); + + using (var sess = tf.Session()) + { + init.run(); + + // output will be 2, 3, 4 + var result = sess.run(x); + Assert.AreEqual(result[0].GetInt64(), 2L); + + result = sess.run(x); + Assert.AreEqual(result[0].GetInt64(), 3L); + + result = sess.run(x); + Assert.AreEqual(result[0].GetInt64(), 4L); + } +} +``` + + + +#### RandomShuffleQueue + +A queue implementation that dequeues elements in a random order. A `RandomShuffleQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `RandomShuffleQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument. 
+ +```csharp +[TestMethod] +public void RandomShuffleQueue() +{ + var queue = tf.RandomShuffleQueue(10, min_after_dequeue: 1, dtype: tf.int32); + var init = queue.enqueue_many(new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }); + var x = queue.dequeue(); + + string results = ""; + using (var sess = tf.Session()) + { + init.run(); + + foreach(var i in range(9)) + results += (int)sess.run(x) + "."; + + // output in random order + // 1.2.3.4.5.6.7.8.9. + } +} +``` + + + +Queue methods must run on the same device as the queue. `FIFOQueue` and `RandomShuffleQueue` are important TensorFlow objects for computing tensors asynchronously in a graph. For example, a typical input architecture is to use a `RandomShuffleQueue` to prepare inputs for training a model: + +* Multiple threads prepare training examples and push them into the queue. +* A training thread executes a training op that dequeues mini-batches from the queue. + +This architecture simplifies the construction of input pipelines. + + + +In the example above, once the queue runs out of values you’ll actually have to terminate the program, because it is blocked. Now, this isn’t very useful. What we really want is for our little program to reload or enqueue more values whenever our queue is empty or is about to become empty. We could fix this by explicitly running our *enqueue_op* again in the code above to reload our queue with values. However, for larger, more realistic programs, this becomes unwieldy. Thankfully, TensorFlow has a solution. + +TensorFlow provides two classes to help with multi-threading tasks: `tf.Coordinator` and `tf.QueueRunner`. These two classes are designed to be used together. The `Coordinator` class helps multiple threads stop together and report exceptions to a main thread. The `QueueRunner` class is used to create a number of threads cooperating to enqueue tensors in the same queue. diff --git a/docs/source/_static/FIFOQueue-example.jpg b/docs/source/_static/FIFOQueue-example.jpg new file mode 100644 index 00000000..ac274934 Binary files /dev/null and b/docs/source/_static/FIFOQueue-example.jpg differ diff --git a/docs/source/index.rst b/docs/source/index.rst index 20126917..61f0d752 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -22,6 +22,7 @@ Welcome to TensorFlow.NET's documentation! Graph Session Operation + Queue Gradient Train EagerMode diff --git a/redist/SciSharp.TensorFlow-Cpu.Redist/SciSharp.TensorFlow-Cpu.Redist.csproj b/redist/SciSharp.TensorFlow-Cpu.Redist/SciSharp.TensorFlow-Cpu.Redist.csproj new file mode 100644 index 00000000..4d0fa1f0 --- /dev/null +++ b/redist/SciSharp.TensorFlow-Cpu.Redist/SciSharp.TensorFlow-Cpu.Redist.csproj @@ -0,0 +1,63 @@ + + + + netstandard2.0 + win-x64;linux-x64 + SciSharp.Tensorflow-Cpu.Redist + + SciSharp.Tensorflow-Cpu.Redist + 1.0.0 + SciSharp team + SciSharp STACK + https://github.com/SciSharp/TensorFlow.NET + git + + Meta-package for CPU TensorFlow library runtime distribution. 
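The Queue.md chapter above describes the producer/consumer input architecture and mentions `tf.Coordinator` and `tf.QueueRunner` without showing code. Below is a minimal sketch of that architecture, assuming only the queue and session members already shown in the chapter's examples (`tf.FIFOQueue`, `enqueue`, `dequeue`, `tf.Session`, `sess.run`) plus standard .NET tasks; it deliberately avoids `Coordinator`/`QueueRunner`, whose TensorFlow.NET surface is not shown in this change.

```csharp
// A rough illustrative sketch, not the chapter's own implementation:
// one background task keeps feeding the queue while the main thread
// plays the role of the training loop.
using System.Threading.Tasks;
using static Tensorflow.Binding;

public class ProducerConsumerQueueSketch
{
    public void Run()
    {
        var value = tf.placeholder(tf.int32);
        var queue = tf.FIFOQueue(100, tf.int32);   // capacity assumed large enough for the demo
        var enqueue = queue.enqueue(value);
        var dequeue = queue.dequeue();

        using (var sess = tf.Session())
        {
            // Producer: prepares "training examples" and pushes them into the queue.
            var producer = Task.Run(() =>
            {
                for (var i = 0; i < 20; i++)
                    sess.run(enqueue, (value, i));
            });

            // Consumer: the training thread dequeues items
            // (mini-batches in a real input pipeline).
            for (var step = 0; step < 10; step++)
            {
                var item = (int)sess.run(dequeue);
                // ... run the training op on `item` here ...
            }

            producer.Wait();
        }
    }
}
```

In a real pipeline the `Coordinator`/`QueueRunner` pair would own the producer threads and handle clean shutdown; the point of the sketch is only the division of labour between filler threads and the training loop.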
+ Libraries can be directly downloaded from https://storage.googleapis.com/tensorflow/libtensorflow/ + + Apache-2.0 + + https://github.com/SciSharp/TensorFlow.NET + native;tensorflow;machine-learning;ML + ../../packages + false + + false + false + false + + + + + + + + + + + + + + + + + ../../packages;$(RestoreSources);https://api.nuget.org/v3/index.json + + + + + + + + + + + + + runtime.json + true + PreserveNewest + + + + diff --git a/redist/SciSharp.TensorFlow-Cpu.Redist/runtime.json b/redist/SciSharp.TensorFlow-Cpu.Redist/runtime.json new file mode 100644 index 00000000..a7a39cb5 --- /dev/null +++ b/redist/SciSharp.TensorFlow-Cpu.Redist/runtime.json @@ -0,0 +1,14 @@ +{ + "runtimes": { + "linux-x64": { + "SciSharp.TensorFlow-Gpu.Redist": { + "runtime.linux-x64.SciSharp.Tensorflow-Cpu.Redist": "1.0.0" + } + }, + "win-x64": { + "SciSharp.TensorFlow-Gpu.Redist": { + "runtime.win-x64.SciSharp.Tensorflow-Cpu.Redist": "1.0.0" + } + } + } +} diff --git a/redist/SciSharp.TensorFlow-Gpu.Redist/SciSharp.TensorFlow-Gpu.Redist.csproj b/redist/SciSharp.TensorFlow-Gpu.Redist/SciSharp.TensorFlow-Gpu.Redist.csproj new file mode 100644 index 00000000..61ea992e --- /dev/null +++ b/redist/SciSharp.TensorFlow-Gpu.Redist/SciSharp.TensorFlow-Gpu.Redist.csproj @@ -0,0 +1,81 @@ + + + + Library + netstandard2.0 + + win-x64;linux-x64 + SciSharp.Tensorflow-Gpu.Redist + + SciSharp.Tensorflow-Gpu.Redist + 1.0.0 + SciSharp team + SciSharp STACK + https://github.com/SciSharp/TensorFlow.NET + git + + Meta-package for GPU Tensoflow library runtime distribution. + Libraries can be directly downloaded from https://storage.googleapis.com/tensorflow/libtensorflow/ + + Apache-2.0 + + https://github.com/SciSharp/TensorFlow.NET + native;tensorflow;machine-learning;ML + ../../packages + false + + false + false + false + + + + + + + + + + + + + + + + + ../../packages;$(RestoreSources);https://api.nuget.org/v3/index.json + + + + + + + + + + + + + runtime.json + true + PreserveNewest + + + + diff --git a/redist/SciSharp.TensorFlow-Gpu.Redist/runtime.json b/redist/SciSharp.TensorFlow-Gpu.Redist/runtime.json new file mode 100644 index 00000000..392dc3cc --- /dev/null +++ b/redist/SciSharp.TensorFlow-Gpu.Redist/runtime.json @@ -0,0 +1,14 @@ +{ + "runtimes": { + "linux-x64": { + "SciSharp.TensorFlow-Gpu.Redist": { + "runtime.linux-x64.SciSharp.Tensorflow-Gpu.Redist": "1.0.0" + } + }, + "win-x64": { + "SciSharp.TensorFlow-Gpu.Redist": { + "runtime.win-x64.SciSharp.Tensorflow-Gpu.Redist": "1.0.0" + } + } + } +} diff --git a/redist/TensorFlow.NET.Redist.sln b/redist/TensorFlow.NET.Redist.sln new file mode 100644 index 00000000..a21dc9dc --- /dev/null +++ b/redist/TensorFlow.NET.Redist.sln @@ -0,0 +1,60 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 16 +VisualStudioVersion = 16.0.29102.190 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{1E65784D-C976-4DFF-991A-DD5C57FFC8E2}" + ProjectSection(SolutionItems) = preProject + scripts\Copy-NativeTensorFlowLibs.ps1 = scripts\Copy-NativeTensorFlowLibs.ps1 + EndProjectSection +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist", "runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist\runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist.csproj", "{9834D2B4-01BF-4D18-8DCF-F498AC481FE7}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist", 
"runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist\runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist.csproj", "{9D853997-3143-4F87-B995-7D7024CF4E1A}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist", "runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist\runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist.csproj", "{878C1EE4-B945-41BF-98DE-C4747C28022A}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist", "runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist\runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist.csproj", "{744A3D51-CEF6-4685-B4C3-718FA61143A0}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "SciSharp.TensorFlow-Cpu.Redist", "SciSharp.TensorFlow-Cpu.Redist\SciSharp.TensorFlow-Cpu.Redist.csproj", "{0A281E9C-6E3D-4172-84BA-2B5F6E9F4D5B}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "SciSharp.TensorFlow-Gpu.Redist", "SciSharp.TensorFlow-Gpu.Redist\SciSharp.TensorFlow-Gpu.Redist.csproj", "{1910BE36-82E3-4465-B3B1-788BFD252DB7}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {9834D2B4-01BF-4D18-8DCF-F498AC481FE7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {9834D2B4-01BF-4D18-8DCF-F498AC481FE7}.Debug|Any CPU.Build.0 = Debug|Any CPU + {9834D2B4-01BF-4D18-8DCF-F498AC481FE7}.Release|Any CPU.ActiveCfg = Release|Any CPU + {9834D2B4-01BF-4D18-8DCF-F498AC481FE7}.Release|Any CPU.Build.0 = Release|Any CPU + {9D853997-3143-4F87-B995-7D7024CF4E1A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {9D853997-3143-4F87-B995-7D7024CF4E1A}.Debug|Any CPU.Build.0 = Debug|Any CPU + {9D853997-3143-4F87-B995-7D7024CF4E1A}.Release|Any CPU.ActiveCfg = Release|Any CPU + {9D853997-3143-4F87-B995-7D7024CF4E1A}.Release|Any CPU.Build.0 = Release|Any CPU + {878C1EE4-B945-41BF-98DE-C4747C28022A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {878C1EE4-B945-41BF-98DE-C4747C28022A}.Debug|Any CPU.Build.0 = Debug|Any CPU + {878C1EE4-B945-41BF-98DE-C4747C28022A}.Release|Any CPU.ActiveCfg = Release|Any CPU + {878C1EE4-B945-41BF-98DE-C4747C28022A}.Release|Any CPU.Build.0 = Release|Any CPU + {744A3D51-CEF6-4685-B4C3-718FA61143A0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {744A3D51-CEF6-4685-B4C3-718FA61143A0}.Debug|Any CPU.Build.0 = Debug|Any CPU + {744A3D51-CEF6-4685-B4C3-718FA61143A0}.Release|Any CPU.ActiveCfg = Release|Any CPU + {744A3D51-CEF6-4685-B4C3-718FA61143A0}.Release|Any CPU.Build.0 = Release|Any CPU + {0A281E9C-6E3D-4172-84BA-2B5F6E9F4D5B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {0A281E9C-6E3D-4172-84BA-2B5F6E9F4D5B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {0A281E9C-6E3D-4172-84BA-2B5F6E9F4D5B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {0A281E9C-6E3D-4172-84BA-2B5F6E9F4D5B}.Release|Any CPU.Build.0 = Release|Any CPU + {1910BE36-82E3-4465-B3B1-788BFD252DB7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1910BE36-82E3-4465-B3B1-788BFD252DB7}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1910BE36-82E3-4465-B3B1-788BFD252DB7}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1910BE36-82E3-4465-B3B1-788BFD252DB7}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {CD7D5F34-42AE-4CCB-BDFA-1619B3A84708} + EndGlobalSection +EndGlobal diff 
--git a/redist/runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist/runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist.csproj b/redist/runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist/runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist.csproj new file mode 100644 index 00000000..ea6d4186 --- /dev/null +++ b/redist/runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist/runtime.linux-x64.SciSharp.TensorFlow-Cpu.Redist.csproj @@ -0,0 +1,39 @@ + + + + netstandard2.0 + linux-x64 + SciSharp.Tensorflow-Cpu.Redist + + runtime.linux-x64.SciSharp.Tensorflow-Cpu.Redist + 1.0.0 + SciSharp team + SciSharp STACK + https://github.com/SciSharp/TensorFlow.NET + git + + Distribution of the Linux CPU Tensoflow library. + The libraries can be directly downloaded from https://storage.googleapis.com/tensorflow/libtensorflow/ + + Apache-2.0 + + https://github.com/SciSharp/TensorFlow.NET + native;tensorflow;machine-learning;ML + ../../packages + false + + false + false + false + + + + + + runtimes/$(RuntimeIdentifier)/native/%(Filename)%(Extension) + true + PreserveNewest + + + + diff --git a/redist/runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist/runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist.csproj b/redist/runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist/runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist.csproj new file mode 100644 index 00000000..d680f38a --- /dev/null +++ b/redist/runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist/runtime.linux-x64.SciSharp.TensorFlow-Gpu.Redist.csproj @@ -0,0 +1,40 @@ + + + + Library + netstandard2.0 + linux-x64 + SciSharp.Tensorflow-Gpu.Redist + + runtime.linux-x64.SciSharp.Tensorflow-Gpu.Redist + 1.0.0 + SciSharp team + SciSharp STACK + https://github.com/SciSharp/TensorFlow.NET + git + + Distribution of the Linux GPU Tensoflow library. + Dll can be directly downloaded from https://storage.googleapis.com/tensorflow/libtensorflow/ + + Apache-2.0 + + https://github.com/SciSharp/TensorFlow.NET + native;tensorflow;machine-learning;ML + ../../packages + false + + false + false + false + + + + + + runtimes/$(RuntimeIdentifier)/native/%(Filename)%(Extension) + true + PreserveNewest + + + + \ No newline at end of file diff --git a/redist/runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist/runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist.csproj b/redist/runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist/runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist.csproj new file mode 100644 index 00000000..19e7854c --- /dev/null +++ b/redist/runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist/runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist.csproj @@ -0,0 +1,39 @@ + + + + netstandard2.0 + win-x64 + SciSharp.Tensorflow-Cpu.Redist + + runtime.win-x64.SciSharp.Tensorflow-Cpu.Redist + 1.0.0 + SciSharp team + SciSharp STACK + https://github.com/SciSharp/TensorFlow.NET + git + + Distribution of the windows GPU Tensoflow library. 
+ The libraries can be directly downloaded from https://storage.googleapis.com/tensorflow/libtensorflow/ + + Apache-2.0 + + https://github.com/SciSharp/TensorFlow.NET + native;tensorflow;machine-learning;ML + ../../packages + false + + false + false + false + + + + + + runtimes/$(RuntimeIdentifier)/native/%(Filename)%(Extension) + true + PreserveNewest + + + + diff --git a/redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist/.gitignore b/redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist/.gitignore new file mode 100644 index 00000000..fca132d9 --- /dev/null +++ b/redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist/.gitignore @@ -0,0 +1 @@ +tensorflow.dll diff --git a/redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist.csproj b/redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist.csproj new file mode 100644 index 00000000..915e0e2a --- /dev/null +++ b/redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist/runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist.csproj @@ -0,0 +1,40 @@ + + + + Library + netstandard2.0 + win-x64 + SciSharp.Tensorflow-Gpu.Redist + + runtime.win-x64.SciSharp.Tensorflow-Gpu.Redist + 1.0.0 + SciSharp team + SciSharp STACK + https://github.com/SciSharp/TensorFlow.NET + git + + Distribution of the windows GPU Tensoflow library. + Dll can be directly downloaded from https://storage.googleapis.com/tensorflow/libtensorflow/ + + Apache-2.0 + + https://github.com/SciSharp/TensorFlow.NET + native;tensorflow;machine-learning;ML + ../../packages + false + + false + false + false + + + + + + runtimes/$(RuntimeIdentifier)/native/%(Filename)%(Extension) + true + PreserveNewest + + + + \ No newline at end of file diff --git a/scripts/Copy-NativeTensorFlowLibs.ps1 b/scripts/Copy-NativeTensorFlowLibs.ps1 new file mode 100644 index 00000000..cf6521ae --- /dev/null +++ b/scripts/Copy-NativeTensorFlowLibs.ps1 @@ -0,0 +1,167 @@ +<# +.SYNOPSIS + Copy the native TensorFlow library to enable the packing a nuget to make + them available to TensorFlow.NET + +.DESCRIPTION + The TensorFlow libraries are copied for Windows and Linux and it becomes + possible to bundle a meta-package containing them. + +.PARAMETER SkipCpuLibraries + Setting this to true skips the downloading of the CPU version of the + TensorFlow libraries. + By default the CPU version of the libraries are downloaded and put in the + relevant projects. + +.PARAMETER SkipGpuLibraries + Setting this to tru skips the downloading of the GPU version of the + TensorFlow libraries. + By default the GPU version of the libraries are downloaded and put in the + releavant projects. + +#> +param( + [switch] $SkipCpuLibraries = $false, + [switch] $SkipGpuLibraries = $false +) + +function Expand-TarGzFiles { + <# + .SYNOPSIS + Expands the given list of files from the given archive into the given + target directory. + + .PARAMETER Archive + Path to the archive that should be considered. + + .PARAMETER Files + Files that should be extracted from the archive. + + .PARAMETER TargetDirectory + Directory into which the files should be expanded. 
+ + #> + param + ( + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] [string] $Archive, + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] [string []] $Files, + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] [string] $TargetDirectory + ) + + & 7z e $Archive -o"$TargetDirectory" + $TarArchive = Join-Path $TargetDirectory "libtensorflow.tar" + + & 7z e $TarArchive $Files -o"$TargetDirectory" + Remove-Item $TarArchive +} + +function Expand-ZipFiles { + <# + .SYNOPSIS + Expands the given list of files from the given archive into the given target directory. + + .PARAMETER Archive + Path to the archive that should be considered. + + .PARAMETER Files + Files that should be extracted from the archive. + + .PARAMETER TargetDirectory + Directory into which the files should be expanded. + #> + param( + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] [string] $Archive, + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] [string []] $Files, + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] [string] $TargetDirectory + ) + + & 7z e $Archive $Files -o"$TargetDirectory" +} + +function Split-ArchiveFromUrl { + <# + .SYNOPSIS + Extracts the archive name out of the given Url. + + .PARAMETER ArchiveUrl + Url of the archive that will be downloaded. + + #> + param( + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] [string] $ArchiveUrl + ) + + $uriParts = $ArchiveUrl.split("/") + $ArchivePath = $uriParts[$uriParts.Count - 1] + + return $ArchivePath +} + +function Copy-Archive { + <# + .SYNOPSIS + This function copies the given binary file to the given target location. + + .PARAMETER ArchiveUrl + Url where the archive should be downloaded from. + + .PARAMETER TargetDirectory + Target directory where the archive should be downloaded. +#> + param ( + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] + [string] $ArchiveUrl, + [Parameter(Mandatory = $true, ValueFromPipeline = $true, ValueFromPipelineByPropertyName = $true)] + [string] $TargetDirectory + ) + + $ArchiveName = Split-ArchiveFromUrl $ArchiveUrl + + $TargetPath = [IO.Path]::Combine($PSScriptRoot, "..", "packages", $ArchiveName) + + if (Test-Path $TargetPath -PathType Leaf) { + Write-Error "$TargetPath already exists, please remove to download againg." + return $TargetPath + } + + if (-not (Test-Path $TargetDirectory -PathType Container)) { + Write-Host "Creating missing $TargetDirectory" + New-Item -Path $TargetDirectory -ItemType Directory + } + Write-Host "Downloading $ArchiveUrl, this might take a while..." 
+ $wc = New-Object System.Net.WebClient + $wc.DownloadFile($ArchiveUrl, $TargetPath) + + return $TargetPath +} + +$LinuxGpuArchive = "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-1.14.0.tar.gz" +$LinuxCpuArchive = "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.14.0.tar.gz" +$LinuxFiles = @(".\libtensorflow.tar", ".\lib\libtensorflow.so", ".\lib\libtensorflow.so.1", ".\lib\libtensorflow.so.1.14.0", ` + ".\lib\libtensorflow_framework.so", ".\lib\libtensorflow_framework.so.1", ".\lib\libtensorflow_framework.so.1.14.0") +$WindowsGpuArchive = "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-1.14.0.zip" +$WindowsCpuArchive = "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-1.14.0.zip" +$WindowsFiles = @("lib\tensorflow.dll") +$PackagesDirectory = [IO.Path]::Combine($PSScriptRoot, "..", "packages") + + +if (-not $SkipGpuLibraries) { + $Archive = Copy-Archive -ArchiveUrl $WindowsGpuArchive -TargetDirectory $PackagesDirectory + $TargetDirectory = [IO.Path]::Combine($PSScriptRoot, "..", "redist", "runtime.win-x64.SciSharp.TensorFlow-Gpu.Redist") + Expand-ZipFiles $Archive $WindowsFiles $TargetDirectory + + $Archive = Copy-Archive -ArchiveUrl $LinuxGpuArchive -TargetDirectory $PackagesDirectory + $TargetDirectory = [IO.Path]::Combine($PSScriptRoot, "..", "redist", "runtime.linux-x64.SciSharp.Tensorflow-Gpu.Redist") + Expand-TarGzFiles $Archive $LinuxFiles $TargetDirectory +} + +if (-not $SkipCpuLibraries) { + $Archive = Copy-Archive -ArchiveUrl $WindowsCpuArchive -TargetDirectory $PackagesDirectory + $TargetDirectory = [IO.Path]::Combine($PSScriptRoot, "..", "redist", "runtime.win-x64.SciSharp.TensorFlow-Cpu.Redist") + Expand-ZipFiles $Archive $WindowsFiles $TargetDirectory + + $Archive = Copy-Archive -ArchiveUrl $LinuxCpuArchive -TargetDirectory $PackagesDirectory + $TargetDirectory = [IO.Path]::Combine($PSScriptRoot, "..", "redist", "runtime.linux-x64.SciSharp.Tensorflow-Cpu.Redist") + Expand-TarGzFiles $Archive $LinuxFiles $TargetDirectory +} + diff --git a/src/SciSharp.TensorFlow.Redist/README.md b/src/SciSharp.TensorFlow.Redist/README.md index 5bdf82a1..9101f422 100644 --- a/src/SciSharp.TensorFlow.Redist/README.md +++ b/src/SciSharp.TensorFlow.Redist/README.md @@ -10,7 +10,7 @@ PM> Install-Package SciSharp.TensorFlow.Redist * GPU version for Windows ```powershell -PM> Install-Package SciSharp.TensorFlow.Redist +PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU ``` https://www.nuget.org/packages/SciSharp.TensorFlow.Redist @@ -21,8 +21,7 @@ Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5 On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries. -1. Run `dotnet pack` under `src/SciSharp.TensorFlow.Redist` directory in Linux. +1. Run `dotnet pack SciSharp.TensorFlow.Redist-CPU.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux. 2. 
Run `dotnet nuget push SciSharp.TensorFlow.Redist.1.14.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json` - diff --git a/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-CPU.nupkgproj b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-CPU.nupkgproj index 6a225ede..5de76105 100644 --- a/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-CPU.nupkgproj +++ b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-CPU.nupkgproj @@ -71,7 +71,7 @@ DownloadShaFile="$(BinDir)%(FileName)%(Extension).sha" ExtractDirectory="$(BinDir)%(FileName)" ExtractSemaphore="$(BinDir)%(FileName)\.extracted" - LocalShaFile="$(MSBuildProjectDirectory)\%(FileName)%(Extension).sha"/> + LocalShaFile="$(BinDir)\%(FileName)%(Extension).sha"/> diff --git a/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Linux-GPU.nupkgproj b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Linux-GPU.nupkgproj new file mode 100644 index 00000000..7f227dd5 --- /dev/null +++ b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Linux-GPU.nupkgproj @@ -0,0 +1,174 @@ + + + + $(MSBuildThisFileDirectory) + $(ProjDir)bin\ + $(ProjDir)obj\ + + x64 + netstandard2.0 + 1.14.0 + 1 + + $(BinDir)packages\ + $(MSBuildProjectName) + $(TensorFlowVersion) + + true + false + + Redist-Windows-GPU.nuspec + packageId=$(PackageId);version=$(PackageVersion) + $(ProjDir) + + CopyFilesFromArchive + + win + linux + osx + $(PackageRid)-$(TargetArchitecture) + + + + + false + + + + + + + + + + + + + + + + + + + + + + <_downloadFiles Include="@(TensorFlowArchive);@(AdditionalDownloadFile)" Url="%(Identity)" DestinationFile="%(DownloadFile)" /> + + + + + + + + + + + + + + + + + + + + + + @(FilesWithHashes->'%(FileHash)') + $([System.IO.File]::ReadAllText('%(LocalShaFile)').Replace("%0A", "").Replace("%0D", "")) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + <_fileFromArchive Include="%(TensorFlowArchive.FilesFromArchive)" ExtractDirectory="%(TensorFlowArchive.ExtractDirectory)" Runtime="%(TensorFlowArchive.Runtime)" /> + <_fileFromArchive DestinationFile="%(FileName)%(Extension)"/> + <_fileFromArchive PackagePath="runtimes\%(_fileFromArchive.Runtime)\native\%(_fileFromArchive.DestinationFile)" /> + + + <_fileFromArchive Condition="'%(DestinationFile)' == 'LICENSE'" PackagePath="THIRD_PARTY_NOTICES.txt" Runtime="" /> + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Windows-GPU.nupkgproj b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Windows-GPU.nupkgproj index 08fd9386..e2b101fa 100644 --- a/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Windows-GPU.nupkgproj +++ b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Windows-GPU.nupkgproj @@ -41,21 +41,6 @@ include\tensorflow\c\LICENSE" Runtime="win-x64"/> - - - - @@ -71,7 +56,7 @@ DownloadShaFile="$(BinDir)%(FileName)%(Extension).sha" ExtractDirectory="$(BinDir)%(FileName)" ExtractSemaphore="$(BinDir)%(FileName)\.extracted" - LocalShaFile="$(MSBuildProjectDirectory)\%(FileName)%(Extension).sha"/> + LocalShaFile="$(BinDir)\%(FileName)%(Extension).sha"/> diff --git a/src/SciSharp.TensorFlow.Redist/libtensorflow-cpu-darwin-x86_64-1.14.0.tar.gz.sha b/src/SciSharp.TensorFlow.Redist/libtensorflow-cpu-darwin-x86_64-1.14.0.tar.gz.sha deleted file mode 100644 index 951c4556..00000000 --- a/src/SciSharp.TensorFlow.Redist/libtensorflow-cpu-darwin-x86_64-1.14.0.tar.gz.sha +++ /dev/null @@ -1 +0,0 @@ 
-7002EF701BD23C5EF5FF94192E935F0DDF960A21BE2531CEE158586830C00E0BA889900F7F6E8AB568BEE0ACF1F5A6A246BB43D11C4109E9DC782B46377D8142 \ No newline at end of file diff --git a/src/SciSharp.TensorFlow.Redist/libtensorflow-cpu-linux-x86_64-1.14.0.tar.gz.sha b/src/SciSharp.TensorFlow.Redist/libtensorflow-cpu-linux-x86_64-1.14.0.tar.gz.sha deleted file mode 100644 index 784640a0..00000000 --- a/src/SciSharp.TensorFlow.Redist/libtensorflow-cpu-linux-x86_64-1.14.0.tar.gz.sha +++ /dev/null @@ -1 +0,0 @@ -E3F6D0309117E9E45780ECF8BC4D0268B3FC9F12E3E38FFE58496789330A4ACD2DC8FF721F3B8900357F6155F8A54000E45B99495F823486B558E8B42532392D \ No newline at end of file diff --git a/src/SciSharp.TensorFlow.Redist/libtensorflow-cpu-windows-x86_64-1.14.0.zip.sha b/src/SciSharp.TensorFlow.Redist/libtensorflow-cpu-windows-x86_64-1.14.0.zip.sha deleted file mode 100644 index b7d6402c..00000000 --- a/src/SciSharp.TensorFlow.Redist/libtensorflow-cpu-windows-x86_64-1.14.0.zip.sha +++ /dev/null @@ -1 +0,0 @@ -59A2B80B441439B851202358CE4A65BA0DDDB319A8A29E87B135DCD9954BC5B0628F2C0C8E72D6942EA3CDCE172805C2BD5421815B3D0210B62BC0936DC59A08 \ No newline at end of file diff --git a/src/SciSharp.TensorFlow.Redist/libtensorflow-gpu-windows-x86_64-1.14.0.zip.sha b/src/SciSharp.TensorFlow.Redist/libtensorflow-gpu-windows-x86_64-1.14.0.zip.sha deleted file mode 100644 index 739129b1..00000000 --- a/src/SciSharp.TensorFlow.Redist/libtensorflow-gpu-windows-x86_64-1.14.0.zip.sha +++ /dev/null @@ -1 +0,0 @@ -850A27858FA951DF77A78CD1BD78B54F6EE2532DD5A49F0579A7B02C795C62F0212F20177EAEA2BD77BD451A57FBBD1348362492F9E14BFE5CA5028C71711293 diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs index adf0b86f..56672173 100644 --- a/src/TensorFlowNET.Core/APIs/c_api.cs +++ b/src/TensorFlowNET.Core/APIs/c_api.cs @@ -54,6 +54,15 @@ namespace Tensorflow public struct DeallocatorArgs { + internal static unsafe c_api.DeallocatorArgs* EmptyPtr; + internal static unsafe IntPtr Empty; + + static unsafe DeallocatorArgs() + { + Empty = new IntPtr(EmptyPtr = (DeallocatorArgs*) Marshal.AllocHGlobal(Marshal.SizeOf())); + *EmptyPtr = new DeallocatorArgs() {gc_handle = IntPtr.Zero, deallocator_called = false}; + } + public bool deallocator_called; public IntPtr gc_handle; } diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index 06d48555..3a674e83 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using NumSharp; using System; using System.Collections.Generic; using System.Linq; @@ -22,6 +23,39 @@ namespace Tensorflow { public partial class tensorflow { + /// + /// A convenient alias for None, useful for indexing arrays. + /// + public Slice newaxis = Slice.NewAxis; + + /// + /// BatchToSpace for N-D tensors of type T. + /// + /// + /// + /// + /// + /// + /// + public Tensor batch_to_space_nd(T input, int[] block_shape, int[,] crops, string name = null) + => gen_array_ops.batch_to_space_nd(input, block_shape, crops, name: name); + + /// + /// Apply boolean mask to tensor. + /// + /// + /// + /// N-D tensor. + /// K-D boolean tensor, K <= N and K must be known statically. + /// + /// A 0-D int Tensor representing the axis in tensor to mask from. + /// (N-K+1)-dimensional tensor populated by entries in tensor corresponding to True values in mask. 
+ public Tensor boolean_mask(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) + => array_ops.boolean_mask(tensor, mask, name: name, axis: axis); + + public Tensor check_numerics(Tensor tensor, string message, string name = null) + => gen_array_ops.check_numerics(tensor, message, name: name); + /// /// Concatenates tensors along one dimension. /// @@ -61,6 +95,26 @@ namespace Tensorflow public Tensor fill(Tensor dims, T value, string name = null) => gen_array_ops.fill(dims, value, name: name); + /// + /// Return a tensor with the same shape and contents as input. + /// + /// + /// + /// + public static Tensor identity(Tensor input, string name = null) + => array_ops.identity(input, name: name); + + /// + /// Gather slices from params axis axis according to indices. + /// + /// + /// + /// + /// + /// + public Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0) + => array_ops.gather(@params, indices, name: name, axis: axis); + /// /// Return the elements, either from `x` or `y`, depending on the `condition`. /// @@ -79,6 +133,39 @@ namespace Tensorflow public Tensor transpose(T1 a, int[] perm = null, string name = "transpose", bool conjugate = false) => array_ops.transpose(a, perm, name, conjugate); + /// + /// Reverses specific dimensions of a tensor. + /// + /// + /// + /// + /// + public static Tensor reverse(Tensor tensor, int[] axis, string name = null) + => gen_array_ops.reverse(tensor, axis, name: name); + + public static Tensor reverse(Tensor tensor, Tensor axis, string name = null) + => gen_array_ops.reverse(tensor, axis, name: name); + + /// + /// Returns the rank of a tensor. + /// + /// + /// + /// Returns a 0-D `int32` `Tensor` representing the rank of `input`. + public Tensor rank(Tensor input, string name = null) + => array_ops.rank(input, name: name); + + /// + /// Extracts a slice from a tensor. + /// + /// A `Tensor`. + /// An `int32` or `int64` `Tensor`. + /// An `int32` or `int64` `Tensor`. + /// A name for the operation (optional). + /// A `Tensor` the same type as `input`. + public Tensor slice(Tensor input, Tb[] begin, Ts[] size, string name = null) + => array_ops.slice(input, begin, size, name: name); + public Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1) => gen_array_ops.squeeze(input, axis, name); @@ -92,6 +179,20 @@ namespace Tensorflow public Tensor stack(object values, int axis = 0, string name = "stack") => array_ops.stack(values, axis, name: name); + /// + /// Creates a tensor with all elements set to 1. + /// + /// + /// + /// A name for the operation (optional). + /// + /// if true, attempt to statically determine the shape of 'tensor' and + /// encode it as a constant. + /// + /// A `Tensor` with all elements set to 1. 
+ public Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + => array_ops.ones_like(tensor, dtype: dtype, name: name, optimize: optimize); + public Tensor one_hot(Tensor indices, int depth, Tensor on_value = null, Tensor off_value = null, @@ -99,6 +200,18 @@ namespace Tensorflow int axis = -1, string name = null) => array_ops.one_hot(indices, depth, dtype: dtype, axis: axis, name: name); + /// + /// Pads a tensor + /// + /// + /// + /// + /// + /// + /// + public Tensor pad(Tensor tensor, Tensor paddings, string mode = "CONSTANT", string name = null, int constant_values = 0) + => array_ops.pad(tensor, paddings, mode: mode, name: name, constant_values: constant_values); + /// /// A placeholder op that passes through `input` when its output is not fed. /// @@ -112,5 +225,47 @@ namespace Tensorflow /// A `Tensor`. Has the same type as `input`. public Tensor placeholder_with_default(T input, int[] shape, string name = null) => gen_array_ops.placeholder_with_default(input, shape, name: name); + + /// + /// Returns the shape of a tensor. + /// + /// + /// + /// + /// + public Tensor shape(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) + => array_ops.shape_internal(input, name, optimize: true, out_type: out_type); + + /// + /// Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor. + /// + /// + /// + /// + /// A stacked `Tensor` with the same type as `values`. + public Tensor stack(Tensor[] values, int axis = 0, string name = "stack") + => array_ops.stack(values, axis: axis, name: name); + + /// + /// Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors. + /// + /// + /// + /// + /// + /// + public Tensor[] unstack(Tensor value, int? num = null, int axis = 0, string name = "unstack") + => array_ops.unstack(value, num: num, axis: axis, name: name); + + /// + /// Creates a tensor with all elements set to zero. + /// + /// + /// + /// + /// + /// A `Tensor` with all elements set to zero. + public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + => array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.control.cs b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs similarity index 69% rename from src/TensorFlowNET.Core/APIs/tf.control.cs rename to src/TensorFlowNET.Core/APIs/tf.control_flow.cs index dcccb6fe..6ed475a9 100644 --- a/src/TensorFlowNET.Core/APIs/tf.control.cs +++ b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs @@ -20,6 +20,23 @@ namespace Tensorflow { public partial class tensorflow { + public Tensor cond(Tensor pred, + Func true_fn = null, + Func false_fn = null, + bool strict = false, + string name = null) + => control_flow_ops.cond(pred, true_fn, false_fn, strict: strict, name: name); + + /// + /// Create an op that groups multiple operations. + /// + /// + /// + /// + /// An Operation that executes all its inputs. 
+ public Operation group(T[] inputs, string name = null) where T : ITensorOrOperation + => control_flow_ops.group(inputs, name: name); + public Tensor while_loop(Func cond, Func body, Tensor[] loop_vars, TensorShape shape_invariants = null, int parallel_iterations = 10, @@ -37,7 +54,7 @@ namespace Tensorflow maximum_iterations: maximum_iterations, return_same_structure: return_same_structure); - public _ControlDependenciesController control_dependencies(Operation[] control_inputs) + public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) => ops.control_dependencies(control_inputs); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.data_flow.cs b/src/TensorFlowNET.Core/APIs/tf.data_flow.cs new file mode 100644 index 00000000..593596ff --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.data_flow.cs @@ -0,0 +1,33 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; + +namespace Tensorflow +{ + public partial class tensorflow + { + /// + /// Interleave the values from the data tensors into a single tensor. + /// + /// + /// + /// + /// + public static Tensor dynamic_stitch(Tensor[] indices, Tensor[] data, string name = null) + => gen_data_flow_ops.dynamic_stitch(indices, data, name: name); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.debugging.cs b/src/TensorFlowNET.Core/APIs/tf.debugging.cs new file mode 100644 index 00000000..8e220594 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.debugging.cs @@ -0,0 +1,44 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public partial class tensorflow + { + /// + /// Assert the condition `x == y` holds element-wise. 
+ /// + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor assert_equal(T1 t1, + T2 t2, + object[] data = null, + string message = null, + string name = null) + => check_ops.assert_equal(t1, + t2, + data: data, + message: message, + name: name); + + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.estimator.cs b/src/TensorFlowNET.Core/APIs/tf.estimator.cs new file mode 100644 index 00000000..3cabfdf4 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.estimator.cs @@ -0,0 +1,57 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using static Tensorflow.Binding; +using Tensorflow.Estimators; +using Tensorflow.Data; + +namespace Tensorflow +{ + public partial class tensorflow + { + public Estimator_Internal estimator { get; } = new Estimator_Internal(); + + public class Estimator_Internal + { + public Estimator Estimator(Action model_fn, RunConfig config) + => new Estimator(model_fn: model_fn, config: config); + + public RunConfig RunConfig(string model_dir) + => new RunConfig(model_dir: model_dir); + + public void train_and_evaluate(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec) + => Training.train_and_evaluate(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec); + + public TrainSpec TrainSpec(Func input_fn, int max_steps) + => new TrainSpec(input_fn: input_fn, max_steps: max_steps); + + /// + /// Create an `Exporter` to use with `tf.estimator.EvalSpec`. + /// + /// + /// + /// + /// + public FinalExporter FinalExporter(string name, Action serving_input_receiver_fn, bool as_text = false) + => new FinalExporter(name: name, serving_input_receiver_fn: serving_input_receiver_fn, + as_text: as_text); + + public EvalSpec EvalSpec(string name, Action input_fn, FinalExporter exporters) + => new EvalSpec(name: name, input_fn: input_fn, exporters: exporters); + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.graph.cs b/src/TensorFlowNET.Core/APIs/tf.graph.cs index cee941ed..1648cb70 100644 --- a/src/TensorFlowNET.Core/APIs/tf.graph.cs +++ b/src/TensorFlowNET.Core/APIs/tf.graph.cs @@ -29,7 +29,19 @@ namespace Tensorflow return ops.get_default_graph(); } - public Graph Graph() + /// + /// Equivalent to but does not create a new graph if it there is none. + /// + public Graph peak_default_graph() + { + return ops.default_graph_stack.peak_controller(); + } + + /// + /// Creates a new graph. + /// + ///Has no interaction with graph defaulting. 
Equivalent to new Graph(); + public Graph Graph() => new Graph(); } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/APIs/tf.image.cs b/src/TensorFlowNET.Core/APIs/tf.image.cs index e2e3206b..57b8b093 100644 --- a/src/TensorFlowNET.Core/APIs/tf.image.cs +++ b/src/TensorFlowNET.Core/APIs/tf.image.cs @@ -40,6 +40,11 @@ namespace Tensorflow public Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null) => gen_image_ops.resize_bilinear(images, size, align_corners: align_corners, name: name); + public Tensor resize_images(Tensor images, Tensor size, ResizeMethod method = ResizeMethod.BILINEAR, + bool align_corners = false, bool preserve_aspect_ratio = false, string name = null) + => image_ops_impl.resize_images(images, size, method: method, + align_corners: align_corners, preserve_aspect_ratio: preserve_aspect_ratio, name: name); + public Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name = null) => gen_image_ops.convert_image_dtype(image, dtype, saturate: saturate, name: name); @@ -54,8 +59,22 @@ namespace Tensorflow /// /// /// - public static Tensor is_jpeg(Tensor contents, string name = null) + public Tensor is_jpeg(Tensor contents, string name = null) => image_ops_impl.is_jpeg(contents, name: name); + + /// + /// Resize `images` to `size` using nearest neighbor interpolation. + /// + /// + /// + /// + /// + /// + /// + public Tensor resize_nearest_neighbor(Tensor images, Tsize size, bool align_corners = false, + string name = null, bool half_pixel_centers = false) + => image_ops_impl.resize_nearest_neighbor(images, size, align_corners: align_corners, + name: name, half_pixel_centers: half_pixel_centers); } } } diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs index c9294653..15bcd766 100644 --- a/src/TensorFlowNET.Core/APIs/tf.init.cs +++ b/src/TensorFlowNET.Core/APIs/tf.init.cs @@ -20,6 +20,8 @@ namespace Tensorflow { public partial class tensorflow { + public IInitializer constant_initializer(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) + => new Constant(value, dtype: dtype, verify_shape: verify_shape); public IInitializer zeros_initializer => new Zeros(); public IInitializer ones_initializer => new Ones(); public IInitializer glorot_uniform_initializer => new GlorotUniform(); @@ -60,5 +62,25 @@ namespace Tensorflow stddev: stddev, seed: seed, dtype: dtype); + + /// + /// Initializer capable of adapting its scale to the shape of weights tensors. + /// + /// + /// + /// + /// + /// + /// + public IInitializer variance_scaling_initializer(float scale = 1.0f, + string mode = "fan_in", + string distribution = "truncated_normal", + int? seed = null, + TF_DataType dtype = TF_DataType.TF_FLOAT) => new VarianceScaling( + scale: scale, + mode: mode, + distribution: distribution, + seed: seed, + dtype: dtype); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs index 53ddcf13..9f989bc5 100644 --- a/src/TensorFlowNET.Core/APIs/tf.layers.cs +++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs @@ -14,6 +14,9 @@ limitations under the License. 
******************************************************************************/ +using System.Collections.Generic; +using System.Linq; +using NumSharp; using Tensorflow.Keras.Layers; using Tensorflow.Operations.Activation; using static Tensorflow.Binding; @@ -143,6 +146,20 @@ namespace Tensorflow return layer.apply(inputs); } + /// + /// Densely-connected layer class. aka fully-connected
+ /// `outputs = activation(inputs * kernel + bias)` + ///
+ /// + /// Python integer, dimensionality of the output space. + /// + /// Boolean, whether the layer uses a bias. + /// + /// + /// + /// + /// + /// public Tensor dense(Tensor inputs, int units, IActivation activation = null, @@ -159,10 +176,60 @@ namespace Tensorflow var layer = new Dense(units, activation, use_bias: use_bias, bias_initializer: bias_initializer, - kernel_initializer: kernel_initializer); + kernel_initializer: kernel_initializer, + trainable: trainable); return layer.apply(inputs); } + + /// + /// Flattens an input tensor while preserving the batch axis (axis 0). + /// + /// Tensor input. + /// The name of the layer. + /// + /// A string, one of `channels_last` (default) or `channels_first`.
+ /// The ordering of the dimensions in the inputs.
+ /// `channels_last` corresponds to inputs with shape
+ /// `(batch, height, width, channels)` while `channels_first` corresponds to
+ /// inputs with shape `(batch, channels, height, width)`. + /// + /// + public Tensor flatten(Tensor inputs, + string name = null, + string data_format = "channels_last") + { + var input_shape = inputs.shape; + if (inputs.shape.Length == 0) + throw new ValueError($"Input 0 of layer flatten is incompatible with the layer: : expected min_ndim={1}, found ndim={0}. Full shape received: ()"); + + var premutation = new List() {0}; + if (data_format == "channels_first" && inputs.NDims > 1) + { + premutation.AddRange(Binding.range(2, inputs.NDims)); + premutation.Add(1); + inputs = array_ops.transpose(inputs, premutation.ToArray()); + } + + var ret = array_ops.reshape(inputs, compute_output_shape(input_shape)); + //ret.set_shape(compute_output_shape(ret.shape)); + return ret; + + int[] compute_output_shape(int[] inputshape) + { + if (inputshape == null || inputshape.Length == 0) + inputshape = new int[] {1}; + + if (inputshape.Skip(1).All(d => d > 0)) + { + int[] output_shape = new int[2]; + output_shape[0] = inputshape[0]; + output_shape[1] = inputshape.Skip(1).Aggregate(1, (acc, rhs) => acc*rhs); //calculate size of all the rest dimensions + return output_shape; + } else + return new int[] {inputshape[0], -1}; //-1 == Binding.None + } + } } } } diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index ddfa71ec..cba05d0e 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using Tensorflow.Operations; + namespace Tensorflow { public partial class tensorflow @@ -42,6 +44,15 @@ namespace Tensorflow public Tensor add(Tx a, Ty b, string name = null) => gen_math_ops.add(a, b, name: name); + /// + /// Adds all input tensors element-wise. + /// + /// + /// + /// A `Tensor` of same shape and type as the elements of `inputs`. + public Tensor add_n(Tensor[] inputs, string name = null) + => math_ops.add_n(inputs, name: name); + /// /// Computes atan of x element-wise. /// @@ -211,6 +222,36 @@ namespace Tensorflow /// public Tensor _clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = null) => gen_math_ops._clip_by_value(t, clip_value_min, clip_value_max); + + /// + /// Clips tensor values to a specified min and max. + /// + /// + /// A Tensor. + /// + /// + /// A 0-D (scalar) Tensor, or a Tensor with the same shape + /// as t. The minimum value to clip by. + /// + /// + /// A 0-D (scalar) Tensor, or a Tensor with the same shape + /// as t. The maximum value to clip by. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'ClipByValue'. + /// + /// + /// A clipped Tensor with the same shape as input 't'. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// Given a tensor t, this operation returns a tensor of the same type and + /// shape as t with its values clipped to clip_value_min and clip_value_max. + /// Any values less than clip_value_min are set to clip_value_min. Any values + /// greater than clip_value_max are set to clip_value_max. 
+ /// + public Tensor clip_by_value (Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = "ClipByValue") + => gen_ops.clip_by_value(t, clip_value_min, clip_value_max, name); public Tensor sub(Tensor a, Tensor b) => gen_math_ops.sub(a, b); @@ -299,6 +340,16 @@ namespace Tensorflow public Tensor negative(Tensor x, string name = null) => gen_math_ops.neg(x, name); + /// + /// Returns the truth value of (x != y) element-wise. + /// + /// + /// + /// + /// A `Tensor` of type bool with the same size as that of x or y. + public Tensor not_equal(Tx x, Ty y, string name = null) + => math_ops.not_equal(x, y, name: name); + /// /// Divides x / y elementwise (using Python 2 division operator semantics). /// @@ -315,21 +366,84 @@ namespace Tensorflow public Tensor pow(T1 x, T2 y) => gen_math_ops.pow(x, y); + /// + /// Divides `x / y` elementwise, rounding toward the most negative integer. + /// + /// + /// + /// + /// `x / y` rounded down. + public Tensor floordiv(Tensor x, Tensor y, string name = null) + => math_ops.floordiv(x, y, name: name); + + /// + /// Divides x / y elementwise (using Python 3 division operator semantics). + /// + /// + /// + /// + /// `x / y` evaluated in floating point. + public static Tensor truediv(Tensor x, Tensor y, string name = null) + => math_ops.truediv(x, y, name: name); + + public Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range") + => math_ops.range(start, limit: limit, delta: delta, dtype: dtype, name: name); + + /// + /// Computes the "logical or" of elements across dimensions of a tensor. + /// + /// The boolean tensor to reduce. + /// The dimensions to reduce. + /// If true, retains reduced dimensions with length 1. + /// + /// The reduced tensor. + public Tensor reduce_any(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + => math_ops.reduce_any(input_tensor, axis: axis, keepdims: keepdims, name: name); + + public Tensor reduce_any(Tensor input_tensor, int axis = 0, bool keepdims = false, string name = null) + => math_ops.reduce_any(input_tensor, axis: new[] { axis }, keepdims: keepdims, name: name); + + /// + /// Computes the "logical and" of elements across dimensions of a tensor. + /// + /// + /// + /// + /// + /// The reduced tensor. + public Tensor reduce_all(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + => math_ops.reduce_all(input_tensor, axis: axis, keepdims: keepdims, name: name); + + /// + /// Computes the product of elements across dimensions of a tensor. + /// + /// + /// + /// + /// + /// + public Tensor reduce_prod(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + => math_ops.reduce_prod(input_tensor, axis: axis, keepdims: keepdims, name: name); + /// /// Computes the sum of elements across dimensions of a tensor. /// /// /// /// - public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null) + public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null, + bool keepdims = false, string name = null) { if(!axis.HasValue && reduction_indices.HasValue) return math_ops.reduce_sum(input, reduction_indices.Value); - return math_ops.reduce_sum(input); + else if (axis.HasValue && !reduction_indices.HasValue) + return math_ops.reduce_sum(input, axis.Value); + return math_ops.reduce_sum(input, keepdims: keepdims, name: name); } - public Tensor reduce_sum(Tensor input, int axis, int? 
reduction_indices = null) - => math_ops.reduce_sum(input, axis); + public Tensor reduce_sum(Tensor input, int[] axis, int? reduction_indices = null, + bool keepdims = false, string name = null) + => math_ops.reduce_sum(input, axis, keepdims: keepdims, name: name); /// /// Computes the maximum of elements across dimensions of a tensor. diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index 01789617..ea52ab57 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -168,6 +168,9 @@ namespace Tensorflow public rnn_cell_impl rnn_cell => new rnn_cell_impl(); + public Tensor sigmoid_cross_entropy_with_logits(Tensor labels, Tensor logits, string name = null) + => nn_impl.sigmoid_cross_entropy_with_logits(labels: labels, logits: logits, name: name); + public Tensor softmax(Tensor logits, int axis = -1, string name = null) => gen_nn_ops.softmax(logits, name); diff --git a/src/TensorFlowNET.Core/APIs/tf.ops.cs b/src/TensorFlowNET.Core/APIs/tf.ops.cs index ec533af4..571d57b2 100644 --- a/src/TensorFlowNET.Core/APIs/tf.ops.cs +++ b/src/TensorFlowNET.Core/APIs/tf.ops.cs @@ -14,15 +14,36 @@ limitations under the License. ******************************************************************************/ +using System; +using System.Collections.Generic; + namespace Tensorflow { public partial class tensorflow { + public void add_to_collection(string name, T value) + => get_default_graph().add_to_collection(name, value); + + public void add_to_collections(List names, T value) + => get_default_graph().add_to_collections(names, value); + public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) => state_ops.assign(@ref, value, validate_shape, use_locking, name); - public object get_collection(string key, string scope = "") - => get_default_graph().get_collection(key, scope: scope); + public Tensor assign(RefVariable @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null) + => state_ops.assign(@ref, value, validate_shape, use_locking, name); + + public void device(string device_name) + => get_default_graph().device(device_name); + + public List get_collection(string key, string scope = "") + => get_default_graph().get_collection(key, scope: scope); + + /// + /// A context manager that lifts ops out of control-flow scopes and function-building graphs. + /// + public void init_scope() + => ops.init_scope(); /// /// Returns a context manager that creates hierarchical names for operations. @@ -39,7 +60,36 @@ namespace Tensorflow /// /// /// - public Tensor no_op(string name = null) + public Operation no_op(string name = null) => gen_control_flow_ops.no_op(name: name); + + /// + /// map on the list of tensors unpacked from `elems` on dimension 0. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// A tensor or (possibly nested) sequence of tensors. 
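A statement-level sketch of the element-wise and reduction wrappers declared above (`clip_by_value`, `not_equal`, `reduce_any`, `reduce_sum`); the values in the comments follow the standard TensorFlow semantics, and the usual `using static Tensorflow.Binding;` is assumed:

```cs
var x = tf.constant(new float[] { -2.0f, 0.5f, 3.0f });

// The clip bounds are scalar tensors, per the signature above.
var clipped = tf.clip_by_value(x, tf.constant(0.0f), tf.constant(1.0f));   // [0, 0.5, 1]

var mask  = tf.not_equal(x, tf.constant(0.5f));        // [true, false, true]
var any   = tf.reduce_any(mask, axis: new[] { 0 });    // true
var total = tf.reduce_sum(x);                          // 1.5
```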
+ public Tensor map_fn(Func fn, + Tensor elems, + TF_DataType dtype = TF_DataType.DtInvalid, + int parallel_iterations = -1, + bool back_prop = true, + bool swap_memory = false, + bool infer_shape = true, + string name = null) + => Operation.map_fn(fn, + elems, + dtype, + parallel_iterations: parallel_iterations, + back_prop: back_prop, + swap_memory: swap_memory, + infer_shape: infer_shape, + name: name); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.queue.cs b/src/TensorFlowNET.Core/APIs/tf.queue.cs new file mode 100644 index 00000000..91947e5b --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.queue.cs @@ -0,0 +1,127 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Queues; + +namespace Tensorflow +{ + public partial class tensorflow + { + /// + /// A FIFOQueue that supports batching variable-sized tensors by padding. + /// + /// + /// + /// + /// + /// + /// + /// + public PaddingFIFOQueue PaddingFIFOQueue(int capacity, + TF_DataType[] dtypes, + TensorShape[] shapes, + string[] names = null, + string shared_name = null, + string name = "padding_fifo_queue") + => new PaddingFIFOQueue(capacity, + dtypes, + shapes, + names, + shared_name: shared_name, + name: name); + + public PaddingFIFOQueue PaddingFIFOQueue(int capacity, + TF_DataType dtype, + TensorShape shape, + string shared_name = null, + string name = "padding_fifo_queue") + => new PaddingFIFOQueue(capacity, + new[] { dtype }, + new[] { shape }, + shared_name: shared_name, + name: name); + + /// + /// A queue implementation that dequeues elements in first-in first-out order. + /// + /// + /// + /// + /// + /// + /// + /// + public FIFOQueue FIFOQueue(int capacity, + TF_DataType[] dtypes, + TensorShape[] shapes = null, + string[] names = null, + string shared_name = null, + string name = "fifo_queue") + => new FIFOQueue(capacity, + dtypes, + shapes, + names, + shared_name: shared_name, + name: name); + + public FIFOQueue FIFOQueue(int capacity, + TF_DataType dtype, + TensorShape shape = null, + string shared_name = null, + string name = "fifo_queue") + => new FIFOQueue(capacity, + new[] { dtype }, + new[] { shape ?? new TensorShape() }, + shared_name: shared_name, + name: name); + + /// + /// Creates a queue that dequeues elements in a first-in first-out order. + /// + /// + /// + /// + /// + /// + /// + public PriorityQueue PriorityQueue(int capacity, + TF_DataType dtype, + TensorShape shape = null, + string shared_name = null, + string name = "priority_queue") + => new PriorityQueue(capacity, + new[] { dtype }, + new[] { shape ?? new TensorShape() }, + shared_name: shared_name, + name: name); + + public RandomShuffleQueue RandomShuffleQueue(int capacity, + int min_after_dequeue, + TF_DataType dtype, + TensorShape shape = null, + int? 
seed = null, + string shared_name = null, + string name = "random_shuffle_queue") + => new RandomShuffleQueue(capacity, + min_after_dequeue: min_after_dequeue, + new[] { dtype }, + new[] { shape ?? new TensorShape() }, + seed: seed, + shared_name: shared_name, + name: name); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.random.cs b/src/TensorFlowNET.Core/APIs/tf.random.cs index ae115872..c331eb7f 100644 --- a/src/TensorFlowNET.Core/APIs/tf.random.cs +++ b/src/TensorFlowNET.Core/APIs/tf.random.cs @@ -49,5 +49,18 @@ namespace Tensorflow int? seed = null, string name = null) => random_ops.truncated_normal(shape, mean, stddev, dtype, seed, name); + + /// + /// Randomly shuffles a tensor along its first dimension. + /// + /// + /// + /// + /// + /// A tensor of same shape and type as value, shuffled along its + /// first dimension. + /// + public Tensor random_shuffle(Tensor value, int? seed = null, string name = null) + => random_ops.random_shuffle(value, seed: seed, name: name); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.reshape.cs b/src/TensorFlowNET.Core/APIs/tf.reshape.cs index 78a00432..b6924709 100644 --- a/src/TensorFlowNET.Core/APIs/tf.reshape.cs +++ b/src/TensorFlowNET.Core/APIs/tf.reshape.cs @@ -18,8 +18,8 @@ namespace Tensorflow { public partial class tensorflow { - public Tensor reshape(Tensor tensor, - Tensor shape, + public Tensor reshape(T1 tensor, + T2 shape, string name = null) => gen_array_ops.reshape(tensor, shape, name); public Tensor reshape(Tensor tensor, diff --git a/src/TensorFlowNET.Core/APIs/tf.sparse.cs b/src/TensorFlowNET.Core/APIs/tf.sparse.cs new file mode 100644 index 00000000..c615a614 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.sparse.cs @@ -0,0 +1,61 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using Tensorflow.Framework; + +namespace Tensorflow +{ + public partial class tensorflow + { + public SparseTensor SparseTensor(long[,] indices, T[] values, long[] dense_shape) + => new SparseTensor(indices, values, dense_shape); + + public Tensor sparse_tensor_to_dense(SparseTensor sp_input, + T default_value = default, + bool validate_indices = true, + string name = null) + => gen_sparse_ops.sparse_to_dense(sp_input.indices, + sp_input.dense_shape, + sp_input.values, + default_value: default_value, + validate_indices: validate_indices, + name: name); + + /// + /// Converts a sparse representation into a dense tensor. + /// + /// + /// + /// + /// + /// + /// + /// + /// Dense `Tensor` of shape `output_shape`. Has the same type as `sparse_values`. 
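The sparse helpers introduced in `tf.sparse.cs` (the `SparseTensor` factory and `sparse_tensor_to_dense` above, plus the `sparse_to_dense` overload declared just below) can be exercised as in this sketch; generic type arguments are stripped in the flattened diff, so it assumes they are inferred from the value array, as in the Python API:

```cs
// Two non-zero entries of a 3x4 matrix: (0, 0) -> 1 and (1, 2) -> 2.
var st = tf.SparseTensor(new long[,] { { 0, 0 }, { 1, 2 } },
                         new[] { 1, 2 },
                         new long[] { 3, 4 });

// Densify; every unspecified position receives default_value.
var dense = tf.sparse_tensor_to_dense(st, default_value: 0);
```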
+ public Tensor sparse_to_dense(Tensor sparse_indices, + TensorShape output_shape, + T sparse_values, + T default_value = default, + bool validate_indices = true, + string name = null) + => gen_sparse_ops.sparse_to_dense(sparse_indices, + output_shape, + sparse_values, + default_value: default_value, + validate_indices: validate_indices, + name: name); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.state.cs b/src/TensorFlowNET.Core/APIs/tf.state.cs new file mode 100644 index 00000000..c57d03c6 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.state.cs @@ -0,0 +1,25 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow +{ + public partial class tensorflow + { + public Tensor assign_add(RefVariable @ref, T value, + bool use_locking = false, string name = null) + => state_ops.assign_add(@ref, value, use_locking: use_locking, name: name); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.tensor.cs b/src/TensorFlowNET.Core/APIs/tf.tensor.cs index b553095e..8ba78f42 100644 --- a/src/TensorFlowNET.Core/APIs/tf.tensor.cs +++ b/src/TensorFlowNET.Core/APIs/tf.tensor.cs @@ -18,8 +18,8 @@ namespace Tensorflow { public partial class tensorflow { - public Tensor convert_to_tensor(object value, - string name = null) => ops.convert_to_tensor(value, name: name); + public Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid) + => ops.convert_to_tensor(value, dtype, name, preferred_dtype); public Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides = null, int begin_mask = 0, @@ -54,5 +54,23 @@ namespace Tensorflow new_axis_mask: new_axis_mask, shrink_axis_mask: shrink_axis_mask, name: name); + + /// + /// Splits a tensor into sub tensors. + /// + /// The Tensor to split. + /// Either an integer indicating the number of splits along split_dim or a 1-D integer + /// Tensor or Python list containing the sizes of each output tensor along split_dim. + /// If a scalar then it must evenly divide value.shape[axis]; otherwise the sum of sizes along the split dimension must match that of the value. + /// An integer or scalar int32 Tensor. The dimension along which to split. Must be in the range [-rank(value), rank(value)). Defaults to 0. + /// A name for the operation (optional) + /// if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects; + /// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value. 
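The returns note above belongs to the `split` overload declared immediately below; a short usage sketch, assuming `tf.ones` accepts a `TensorShape` as the binding does elsewhere:

```cs
// Split a (3, 4) tensor into two (3, 2) tensors along axis 1.
var value = tf.ones(new TensorShape(3, 4));
var parts = tf.split(value, num_split: 2, axis: tf.constant(1));   // parts.Length == 2
```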
+ public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) => gen_array_ops.split( + value: value, + axis: axis, + num_split: num_split, + name: name + ); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.tile.cs b/src/TensorFlowNET.Core/APIs/tf.tile.cs index 21017a17..0995dc27 100644 --- a/src/TensorFlowNET.Core/APIs/tf.tile.cs +++ b/src/TensorFlowNET.Core/APIs/tf.tile.cs @@ -20,12 +20,8 @@ namespace Tensorflow { public partial class tensorflow { - public Tensor tile(Tensor input, - Tensor multiples, + public Tensor tile(Tensor input, + T multiples, string name = null) => gen_array_ops.tile(input, multiples, name); - public Tensor tile(NDArray input, - int[] multiples, - string name = null) => gen_array_ops.tile(input, multiples, name); - } } diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs index a943308b..3a790327 100644 --- a/src/TensorFlowNET.Core/APIs/tf.train.cs +++ b/src/TensorFlowNET.Core/APIs/tf.train.cs @@ -25,16 +25,26 @@ namespace Tensorflow public class train_internal { + public RefVariable create_global_step(Graph graph) + => TrainingUtil.create_global_step(graph); + + public RefVariable get_global_step(Graph graph) + => TrainingUtil.get_global_step(graph); + public Optimizer GradientDescentOptimizer(float learning_rate) => new GradientDescentOptimizer(learning_rate); public Optimizer AdamOptimizer(float learning_rate, string name = "Adam") => new AdamOptimizer(learning_rate, name: name); + public Optimizer AdamOptimizer(Tensor learning_rate, string name = "Adam") + => new AdamOptimizer(learning_rate, name: name); + public ExponentialMovingAverage ExponentialMovingAverage(float decay) => new ExponentialMovingAverage(decay); - public Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list); + public Saver Saver(VariableV1[] var_list = null, int max_to_keep = 5) + => new Saver(var_list: var_list, max_to_keep: max_to_keep); public string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text); @@ -45,7 +55,7 @@ namespace Tensorflow clear_devices, import_scope).Item1; - public (MetaGraphDef, Dictionary) export_meta_graph(string filename = "", + public (MetaGraphDef, Dictionary) export_meta_graph(string filename = "", bool as_text = false, bool clear_devices = false, bool clear_extraneous_savers = false, @@ -54,6 +64,12 @@ namespace Tensorflow clear_devices: clear_devices, clear_extraneous_savers: clear_extraneous_savers, strip_default_attrs: strip_default_attrs); + + public string latest_checkpoint(string checkpoint_dir, string latest_filename = null) + => checkpoint_management.latest_checkpoint(checkpoint_dir, latest_filename: latest_filename); + + public CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null) + => checkpoint_management.get_checkpoint_state(checkpoint_dir, latest_filename: latest_filename); } } } diff --git a/src/TensorFlowNET.Core/APIs/tf.variable.cs b/src/TensorFlowNET.Core/APIs/tf.variable.cs index b3c5bf43..179cedee 100644 --- a/src/TensorFlowNET.Core/APIs/tf.variable.cs +++ b/src/TensorFlowNET.Core/APIs/tf.variable.cs @@ -46,6 +46,7 @@ namespace Tensorflow TF_DataType dtype = TF_DataType.DtInvalid, object initializer = null, // IInitializer or Tensor bool? trainable = null, + List collections = null, bool? 
use_resource = null, bool validate_shape = true, VariableSynchronization synchronization = VariableSynchronization.Auto, @@ -60,7 +61,11 @@ namespace Tensorflow use_resource: use_resource, validate_shape: validate_shape, initializer: initializer, - trainable: trainable); + trainable: trainable, + collections: collections); } + + public VariableScope get_variable_scope() + => Tensorflow.variable_scope.get_variable_scope(); } } diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index d723283f..dcf191ed 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -21,6 +21,7 @@ using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics; using System.Linq; +using NumSharp.Utilities; namespace Tensorflow { @@ -29,9 +30,37 @@ namespace Tensorflow /// public static partial class Binding { + private static string _tostring(object obj) + { + switch (obj) + { + case NDArray nd: + return nd.ToString(false); + case Array arr: + if (arr.Rank!=1 || arr.GetType().GetElementType()?.IsArray == true) + arr = Arrays.Flatten(arr); + var objs = toObjectArray(arr); + return $"[{string.Join(", ", objs.Select(_tostring))}]"; + default: + return obj?.ToString() ?? "null"; + } + + object[] toObjectArray(Array arr) + { + var len = arr.LongLength; + var ret = new object[len]; + for (long i = 0; i < len; i++) + { + ret[i] = arr.GetValue(i); + } + + return ret; + } + } + public static void print(object obj) { - Console.WriteLine(obj.ToString()); + Console.WriteLine(_tostring(obj)); } public static int len(object a) @@ -77,11 +106,6 @@ namespace Tensorflow py.__enter__(); action(py); } - catch (Exception ex) - { - Console.WriteLine(ex.ToString()); - throw; - } finally { py.__exit__(); @@ -97,11 +121,6 @@ namespace Tensorflow py.__enter__(); action(py); } - catch (Exception ex) - { - Console.WriteLine(ex.ToString()); - throw; - } finally { py.__exit__(); @@ -117,11 +136,6 @@ namespace Tensorflow py.__enter__(); return action(py); } - catch (Exception ex) - { - Console.WriteLine(ex.ToString()); - return default(TOut); - } finally { py.__exit__(); @@ -139,7 +153,7 @@ namespace Tensorflow { var a = t1.AsIterator(); var b = t2.AsIterator(); - while (a.HasNext()) + while (a.HasNext() && b.HasNext()) yield return (a.MoveNext(), b.MoveNext()); } @@ -155,19 +169,13 @@ namespace Tensorflow { var a = t1.AsIterator(); var b = t2.AsIterator(); - while(a.HasNext()) + while(a.HasNext() && b.HasNext()) yield return (a.MoveNext(), b.MoveNext()); } public static IEnumerable<(T1, T2)> zip(IEnumerable e1, IEnumerable e2) { - var iter2 = e2.GetEnumerator(); - foreach (var v1 in e1) - { - iter2.MoveNext(); - var v2 = iter2.Current; - yield return (v1, v2); - } + return e1.Zip(e2, (t1, t2) => (t1, t2)); } public static IEnumerable<(TKey, TValue)> enumerate(Dictionary values) @@ -277,39 +285,6 @@ namespace Tensorflow return (__memberobject__.Length > 0) ? true : false; } - public delegate object __object__(params object[] args); - - public static __object__ getattr(object obj, string key, params Type[] ___parameter_type__) - { - var __dyn_obj__ = obj.GetType().GetMember(key); - if (__dyn_obj__.Length == 0) - throw new Exception("The object \"" + nameof(obj) + "\" doesnot have a defination \"" + key + "\""); - var __type__ = __dyn_obj__[0]; - if (__type__.MemberType == System.Reflection.MemberTypes.Method) - { - try - { - var __method__ = (___parameter_type__.Length > 0) ? 
obj.GetType().GetMethod(key, ___parameter_type__) : obj.GetType().GetMethod(key); - return (object[] args) => __method__.Invoke(obj, args); - } - catch (System.Reflection.AmbiguousMatchException ex) - { - throw new Exception("AmbigousFunctionMatchFound : (Probable cause : Function Overloading) Please add parameter types of the function."); - } - } - else if (__type__.MemberType == System.Reflection.MemberTypes.Field) - { - var __field__ = obj.GetType().GetField(key).GetValue(obj); - return (object[] args) => { return __field__; }; - } - else if (__type__.MemberType == System.Reflection.MemberTypes.Property) - { - var __property__ = obj.GetType().GetProperty(key).GetValue(obj); - return (object[] args) => { return __property__; }; - } - return (object[] args) => { return "NaN"; }; - } - public static IEnumerable TupleToEnumerable(object tuple) { Type t = tuple.GetType(); @@ -338,5 +313,15 @@ namespace Tensorflow return true; return false; } + + public static Func partial(Func func, Tin1 args) + { + Func newfunc = (args1) => + { + return func(args1); + }; + + return newfunc; + } } } diff --git a/src/TensorFlowNET.Core/Binding.cs b/src/TensorFlowNET.Core/Binding.cs index f443f2eb..34dfbbdb 100644 --- a/src/TensorFlowNET.Core/Binding.cs +++ b/src/TensorFlowNET.Core/Binding.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Dynamic; using System.Text; namespace Tensorflow @@ -7,5 +8,17 @@ namespace Tensorflow public static partial class Binding { public static tensorflow tf { get; } = New(); + + /// + /// Alias to null, similar to python's None. + /// For TensorShape, please use Unknown + /// + public static readonly object None = null; + + /// + /// Used for TensorShape None + /// + /// + public static readonly int Unknown = -1; } } diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs index c08d3175..ad5dbc44 100644 --- a/src/TensorFlowNET.Core/Buffers/Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs @@ -70,7 +70,7 @@ namespace Tensorflow public Buffer() => _handle = TF_NewBuffer(); - internal Buffer(IntPtr handle) + public Buffer(IntPtr handle) { if (handle == IntPtr.Zero) throw new ArgumentException("Handle (IntPtr) can't be zero.", nameof(handle)); diff --git a/src/TensorFlowNET.Core/Contrib/Train/HParams.cs b/src/TensorFlowNET.Core/Contrib/Train/HParams.cs new file mode 100644 index 00000000..bd85ad4c --- /dev/null +++ b/src/TensorFlowNET.Core/Contrib/Train/HParams.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Contrib.Train +{ + /// + /// Class to hold a set of hyperparameters as name-value pairs. + /// + public class HParams + { + public bool load_pretrained { get; set; } + + public HParams(bool load_pretrained) + { + this.load_pretrained = load_pretrained; + } + } +} diff --git a/src/TensorFlowNET.Core/Estimator/EstimatorV2.cs b/src/TensorFlowNET.Core/Estimator/EstimatorV2.cs deleted file mode 100644 index 10c2e16e..00000000 --- a/src/TensorFlowNET.Core/Estimator/EstimatorV2.cs +++ /dev/null @@ -1,36 +0,0 @@ -using System; -using Tensorflow.Data; - -namespace Tensorflow.Estimator -{ - /// - /// Estimator class to train and evaluate TensorFlow models. - /// - /// - public class EstimatorV2 : IEstimator - { - public EstimatorV2(string model_dir = null) - { - - } - - /// - /// Calls the input function. 
- /// - /// - public void call_input_fn(string mode = null) - { - - } - - public void train_model_default(Func input_fn) - { - - } - - public void get_features_and_labels_from_input_fn() - { - - } - } -} diff --git a/src/TensorFlowNET.Core/Estimator/IEstimator.cs b/src/TensorFlowNET.Core/Estimator/IEstimator.cs deleted file mode 100644 index 4443f3c3..00000000 --- a/src/TensorFlowNET.Core/Estimator/IEstimator.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace Tensorflow.Estimator -{ - public interface IEstimator - { - - } -} diff --git a/src/TensorFlowNET.Core/Estimator/TrainingExecutor.cs b/src/TensorFlowNET.Core/Estimator/TrainingExecutor.cs deleted file mode 100644 index 9e281811..00000000 --- a/src/TensorFlowNET.Core/Estimator/TrainingExecutor.cs +++ /dev/null @@ -1,15 +0,0 @@ -namespace Tensorflow.Estimator -{ - /// - /// The executor to run `Estimator` training and evaluation. - /// - /// - public class TrainingExecutor - { - private IEstimator _estimator; - public TrainingExecutor(IEstimator estimator) - { - _estimator = estimator; - } - } -} diff --git a/src/TensorFlowNET.Core/Estimators/Estimator.cs b/src/TensorFlowNET.Core/Estimators/Estimator.cs new file mode 100644 index 00000000..5ba7a9c3 --- /dev/null +++ b/src/TensorFlowNET.Core/Estimators/Estimator.cs @@ -0,0 +1,138 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Text; +using Tensorflow.Data; +using Tensorflow.Train; +using static Tensorflow.Binding; + +namespace Tensorflow.Estimators +{ + /// + /// Estimator class to train and evaluate TensorFlow models. + /// + public class Estimator : IObjectLife + { + RunConfig _config; + public RunConfig config => _config; + + ConfigProto _session_config; + public ConfigProto session_config => _session_config; + + string _model_dir; + + Action _model_fn; + + public Estimator(Action model_fn, RunConfig config) + { + _config = config; + _model_dir = _config.model_dir; + _session_config = _config.session_config; + _model_fn = model_fn; + } + + public Estimator train(Func input_fn, int max_steps = 1, Action[] hooks = null, + _NewCheckpointListenerForEvaluate[] saving_listeners = null) + { + if(max_steps > 0) + { + var start_step = _load_global_step_from_checkpoint_dir(_model_dir); + if (max_steps <= start_step) + { + Console.WriteLine("Skipping training since max_steps has already saved."); + return this; + } + } + + var loss = _train_model(input_fn); + print($"Loss for final step: {loss}."); + return this; + } + + private int _load_global_step_from_checkpoint_dir(string checkpoint_dir) + { + // var cp = tf.train.latest_checkpoint(checkpoint_dir); + // should use NewCheckpointReader (not implemented) + var cp = tf.train.get_checkpoint_state(checkpoint_dir); + + return cp.AllModelCheckpointPaths.Count - 1; + } + + private Tensor _train_model(Func input_fn) + { + return _train_model_default(input_fn); + } + + private Tensor _train_model_default(Func input_fn) + { + using (var g = tf.Graph().as_default()) + { + var global_step_tensor = _create_and_assert_global_step(g); + + // Skip creating a read variable if _create_and_assert_global_step + // returns None (e.g. tf.contrib.estimator.SavedModelEstimator). 
+ if (global_step_tensor != null) + TrainingUtil._get_or_create_global_step_read(g); + + var (features, labels) = _get_features_and_labels_from_input_fn(input_fn, "train"); + } + + throw new NotImplementedException(""); + } + + private (Dictionary, Dictionary) _get_features_and_labels_from_input_fn(Func input_fn, string mode) + { + var result = _call_input_fn(input_fn, mode); + return EstimatorUtil.parse_input_fn_result(result); + } + + /// + /// Calls the input function. + /// + /// + /// + private DatasetV1Adapter _call_input_fn(Func input_fn, string mode) + { + return input_fn(); + } + + private Tensor _create_and_assert_global_step(Graph graph) + { + var step = _create_global_step(graph); + Debug.Assert(step == tf.train.get_global_step(graph)); + Debug.Assert(step.dtype.is_integer()); + return step; + } + + private RefVariable _create_global_step(Graph graph) + { + return tf.train.create_global_step(graph); + } + + public void __init__() + { + throw new NotImplementedException(); + } + + public void __enter__() + { + throw new NotImplementedException(); + } + + public void __del__() + { + throw new NotImplementedException(); + } + + public void __exit__() + { + throw new NotImplementedException(); + } + + public void Dispose() + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Core/Estimators/EstimatorUtil.cs b/src/TensorFlowNET.Core/Estimators/EstimatorUtil.cs new file mode 100644 index 00000000..df1fb38b --- /dev/null +++ b/src/TensorFlowNET.Core/Estimators/EstimatorUtil.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Data; + +namespace Tensorflow.Estimators +{ + public class EstimatorUtil + { + public static (Dictionary, Dictionary) parse_input_fn_result(DatasetV1Adapter result) + { + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Estimators/EvalSpec.cs b/src/TensorFlowNET.Core/Estimators/EvalSpec.cs new file mode 100644 index 00000000..c5a5820e --- /dev/null +++ b/src/TensorFlowNET.Core/Estimators/EvalSpec.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Estimators +{ + public class EvalSpec + { + string _name; + + public EvalSpec(string name, Action input_fn, FinalExporter exporters) + { + _name = name; + } + } +} diff --git a/src/TensorFlowNET.Core/Estimators/Exporter.cs b/src/TensorFlowNET.Core/Estimators/Exporter.cs new file mode 100644 index 00000000..2610458e --- /dev/null +++ b/src/TensorFlowNET.Core/Estimators/Exporter.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Estimators +{ + public abstract class Exporter + { + public abstract void export(Estimator estimator, string export_path, string checkpoint_path); + } +} diff --git a/src/TensorFlowNET.Core/Estimators/FinalExporter.cs b/src/TensorFlowNET.Core/Estimators/FinalExporter.cs new file mode 100644 index 00000000..af78c2bf --- /dev/null +++ b/src/TensorFlowNET.Core/Estimators/FinalExporter.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Estimators +{ + public class FinalExporter + { + public FinalExporter(string name, Action serving_input_receiver_fn, bool as_text = false) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Estimator/HyperParams.cs b/src/TensorFlowNET.Core/Estimators/HyperParams.cs similarity index 98% rename from src/TensorFlowNET.Core/Estimator/HyperParams.cs rename to 
src/TensorFlowNET.Core/Estimators/HyperParams.cs index 706f00c6..0c177cae 100644 --- a/src/TensorFlowNET.Core/Estimator/HyperParams.cs +++ b/src/TensorFlowNET.Core/Estimators/HyperParams.cs @@ -1,6 +1,6 @@ using System.IO; -namespace Tensorflow.Estimator +namespace Tensorflow.Estimators { public class HyperParams { diff --git a/src/TensorFlowNET.Core/Estimator/README.md b/src/TensorFlowNET.Core/Estimators/README.md similarity index 71% rename from src/TensorFlowNET.Core/Estimator/README.md rename to src/TensorFlowNET.Core/Estimators/README.md index b39e5650..2dffdfbb 100644 --- a/src/TensorFlowNET.Core/Estimator/README.md +++ b/src/TensorFlowNET.Core/Estimators/README.md @@ -1,6 +1,7 @@ -### TensorFlow Estimator +# TensorFlow Estimator TensorFlow Estimator is a high-level TensorFlow API that greatly simplifies machine learning programming. Estimators encapsulate training, evaluation, prediction, and exporting for your model. -Guide: + +https://github.com/tensorflow/estimator \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Estimators/RunConfig.cs b/src/TensorFlowNET.Core/Estimators/RunConfig.cs new file mode 100644 index 00000000..37c7d4d1 --- /dev/null +++ b/src/TensorFlowNET.Core/Estimators/RunConfig.cs @@ -0,0 +1,101 @@ +using System; + +namespace Tensorflow.Estimators +{ + public class RunConfig + { + // A list of the property names in RunConfig that the user is allowed to change. + private static readonly string[] _DEFAULT_REPLACEABLE_LIST = new [] + { + "model_dir", + "tf_random_seed", + "save_summary_steps", + "save_checkpoints_steps", + "save_checkpoints_secs", + "session_config", + "keep_checkpoint_max", + "keep_checkpoint_every_n_hours", + "log_step_count_steps", + "train_distribute", + "device_fn", + "protocol", + "eval_distribute", + "experimental_distribute", + "experimental_max_worker_delay_secs", + "session_creation_timeout_secs" + }; + + + #region const values + + private const string _SAVE_CKPT_ERR = "`save_checkpoints_steps` and `save_checkpoints_secs` cannot be both set."; + private const string _TF_CONFIG_ENV = "TF_CONFIG"; + private const string _TASK_ENV_KEY = "task"; + private const string _TASK_TYPE_KEY = "type"; + private const string _TASK_ID_KEY = "index"; + private const string _CLUSTER_KEY = "cluster"; + private const string _SERVICE_KEY = "service"; + private const string _SESSION_MASTER_KEY = "session_master"; + private const string _EVAL_SESSION_MASTER_KEY = "eval_session_master"; + private const string _MODEL_DIR_KEY = "model_dir"; + private const string _LOCAL_MASTER = ""; + private const string _GRPC_SCHEME = "grpc://"; + + #endregion + + public string model_dir { get; set; } + public ConfigProto session_config { get; set; } + public int? 
tf_random_seed { get; set; } + public int save_summary_steps { get; set; } = 100; + public int save_checkpoints_steps { get; set; } + public int save_checkpoints_secs { get; set; } = 600; + public int keep_checkpoint_max { get; set; } = 5; + public int keep_checkpoint_every_n_hours { get; set; } = 10000; + public int log_step_count_steps{ get; set; } = 100; + public object train_distribute { get; set; } + public object device_fn { get; set; } + public object protocol { get; set; } + public object eval_distribute { get; set; } + public object experimental_distribute { get; set; } + public object experimental_max_worker_delay_secs { get; set; } + public int session_creation_timeout_secs { get; set; } = 7200; + public object service { get; set; } + + public RunConfig() + { + Initialize(); + } + + public RunConfig(string model_dir) + { + this.model_dir = model_dir; + Initialize(); + } + + public RunConfig( + string model_dir = null, + int? tf_random_seed = null, + int save_summary_steps=100, + object save_checkpoints_steps = null, // _USE_DEFAULT + object save_checkpoints_secs = null, // _USE_DEFAULT + object session_config = null, + int keep_checkpoint_max = 5, + int keep_checkpoint_every_n_hours = 10000, + int log_step_count_steps = 100, + object train_distribute = null, + object device_fn = null, + object protocol = null, + object eval_distribute = null, + object experimental_distribute = null, + object experimental_max_worker_delay_secs = null, + int session_creation_timeout_secs = 7200) + { + this.model_dir = model_dir; + Initialize(); + } + + private void Initialize() + { + } + } +} diff --git a/src/TensorFlowNET.Core/Estimators/TrainSpec.cs b/src/TensorFlowNET.Core/Estimators/TrainSpec.cs new file mode 100644 index 00000000..c2993684 --- /dev/null +++ b/src/TensorFlowNET.Core/Estimators/TrainSpec.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Data; + +namespace Tensorflow.Estimators +{ + public class TrainSpec + { + int _max_steps; + public int max_steps => _max_steps; + + Func _input_fn; + public Func input_fn => _input_fn; + + public TrainSpec(Func input_fn, int max_steps) + { + _max_steps = max_steps; + _input_fn = input_fn; + } + } +} diff --git a/src/TensorFlowNET.Core/Estimators/Training.cs b/src/TensorFlowNET.Core/Estimators/Training.cs new file mode 100644 index 00000000..930d57c9 --- /dev/null +++ b/src/TensorFlowNET.Core/Estimators/Training.cs @@ -0,0 +1,17 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Estimators +{ + public class Training + { + public static void train_and_evaluate(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec) + { + var executor = new _TrainingExecutor(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec); + var config = estimator.config; + + executor.run(); + } + } +} diff --git a/src/TensorFlowNET.Core/Estimators/_Evaluator.cs b/src/TensorFlowNET.Core/Estimators/_Evaluator.cs new file mode 100644 index 00000000..bd0aa4be --- /dev/null +++ b/src/TensorFlowNET.Core/Estimators/_Evaluator.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Estimators +{ + public class _Evaluator + { + public _Evaluator(Estimator estimator, EvalSpec eval_spec, int max_training_steps) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Estimators/_NewCheckpointListenerForEvaluate.cs b/src/TensorFlowNET.Core/Estimators/_NewCheckpointListenerForEvaluate.cs new file mode 100644 
index 00000000..71464562 --- /dev/null +++ b/src/TensorFlowNET.Core/Estimators/_NewCheckpointListenerForEvaluate.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Estimators +{ + public class _NewCheckpointListenerForEvaluate + { + _Evaluator _evaluator; + + public _NewCheckpointListenerForEvaluate(_Evaluator evaluator, int eval_throttle_secs) + { + _evaluator = evaluator; + } + } +} diff --git a/src/TensorFlowNET.Core/Estimators/_SavedModelExporter.cs b/src/TensorFlowNET.Core/Estimators/_SavedModelExporter.cs new file mode 100644 index 00000000..9beb35ea --- /dev/null +++ b/src/TensorFlowNET.Core/Estimators/_SavedModelExporter.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Estimators +{ + public class _SavedModelExporter : Exporter + { + public override void export(Estimator estimator, string export_path, string checkpoint_path) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs b/src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs new file mode 100644 index 00000000..e7ad6905 --- /dev/null +++ b/src/TensorFlowNET.Core/Estimators/_TrainingExecutor.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Estimators +{ + /// + /// The executor to run `Estimator` training and evaluation. + /// + internal class _TrainingExecutor + { + Estimator _estimator; + EvalSpec _eval_spec; + TrainSpec _train_spec; + + public _TrainingExecutor(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec) + { + _estimator = estimator; + _train_spec = train_spec; + _eval_spec = eval_spec; + } + + public void run() + { + var config = _estimator.config; + Console.WriteLine("Running training and evaluation locally (non-distributed)."); + run_local(); + } + + /// + /// Runs training and evaluation locally (non-distributed). + /// + private void run_local() + { + var train_hooks = new Action[0]; + Console.WriteLine("Start train and evaluate loop. The evaluate will happen " + + "after every checkpoint. 
Checkpoint frequency is determined " + + $"based on RunConfig arguments: save_checkpoints_steps {_estimator.config.save_checkpoints_steps} or " + + $"save_checkpoints_secs {_estimator.config.save_checkpoints_secs}."); + var evaluator = new _Evaluator(_estimator, _eval_spec, _train_spec.max_steps); + var saving_listeners = new _NewCheckpointListenerForEvaluate[0]; + _estimator.train(input_fn: _train_spec.input_fn, + max_steps: _train_spec.max_steps, + hooks: train_hooks, + saving_listeners: saving_listeners); + } + } +} diff --git a/src/TensorFlowNET.Core/Exceptions/LookupError.cs b/src/TensorFlowNET.Core/Exceptions/LookupError.cs new file mode 100644 index 00000000..ebbaa526 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/LookupError.cs @@ -0,0 +1,17 @@ +using System; + +namespace Tensorflow +{ + public class LookupError : TensorflowException + { + public LookupError() : base() + { + + } + + public LookupError(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Exceptions/StopIteration.cs b/src/TensorFlowNET.Core/Exceptions/StopIteration.cs new file mode 100644 index 00000000..d91408a2 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/StopIteration.cs @@ -0,0 +1,17 @@ +using System; + +namespace Tensorflow +{ + public class StopIteration : TensorflowException + { + public StopIteration() : base() + { + + } + + public StopIteration(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs similarity index 97% rename from src/TensorFlowNET.Core/Framework/meta_graph.py.cs rename to src/TensorFlowNET.Core/Framework/meta_graph.cs index d7d7ef7e..3f5a2777 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.cs @@ -167,7 +167,7 @@ namespace Tensorflow /// /// /// - public static (MetaGraphDef, Dictionary) export_scoped_meta_graph(string filename = "", + public static (MetaGraphDef, Dictionary) export_scoped_meta_graph(string filename = "", GraphDef graph_def = null, bool as_text = false, string unbound_inputs_col_name = "unbound_inputs", @@ -179,8 +179,8 @@ namespace Tensorflow { var graph = ops.get_default_graph(); - var var_list = new Dictionary(); - var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) as List; + var var_list = new Dictionary(); + var variables = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES); if (variables != null) { diff --git a/src/TensorFlowNET.Core/Framework/smart_module.cs b/src/TensorFlowNET.Core/Framework/smart_module.cs index 908acb75..67102cab 100644 --- a/src/TensorFlowNET.Core/Framework/smart_module.cs +++ b/src/TensorFlowNET.Core/Framework/smart_module.cs @@ -20,15 +20,24 @@ namespace Tensorflow.Framework { public class smart_module { - public static Tensor[] smart_cond(Tensor pred, - Func true_fn = null, - Func false_fn = null, + public static Tensor[] smart_cond(Tensor pred, + Func true_fn = null, + Func false_fn = null, string name = null) { - return control_flow_ops.cond(pred, - true_fn: true_fn, - false_fn: false_fn, - name: name); + var pred_value = smart_constant_value(pred); + if (pred_value.HasValue) + { + if (pred_value.Value) + return true_fn() as Tensor[]; + else + return false_fn() as Tensor[]; + } + else + return control_flow_ops.cond(pred, + true_fn: true_fn, + false_fn: false_fn, + name: name); } public static bool? 
smart_constant_value(Tensor pred) diff --git a/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs b/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs index 8d0ea53b..b03ce2de 100644 --- a/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs +++ b/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs @@ -1,19 +1,63 @@ -namespace Tensorflow.Framework -{ - public interface _TensorLike - { } +using System; +using System.Linq; +using static Tensorflow.Binding; - public class SparseTensor : CompositeTensor, _TensorLike +namespace Tensorflow.Framework +{ + /// + /// Represents a sparse tensor. + /// + public class SparseTensor : CompositeTensor, _TensorLike { - private static Tensor _dense_shape { get; set; } + long[,] _indices; + public Tensor indices; + + T[] _values; + public Tensor values; + + long[] _dense_shape; + public Tensor dense_shape; + + TensorShape _shape; + public TensorShape shape => _shape; + + public TF_DataType dtype => dtypes.as_dtype(typeof(T)); + + public SparseTensor(long[,] indices_, T[] values_, long[] dense_shape_) + { + tf_with(ops.name_scope(null, "SparseTensor", new { }), delegate + { + indices = ops.convert_to_tensor( + indices_, name: "indices", dtype: dtypes.int64); + values = ops.internal_convert_to_tensor(values_, name: "values"); + dense_shape = ops.convert_to_tensor( + dense_shape_, name: "dense_shape", dtype: dtypes.int64); + }); + _indices = indices_; + _values = values_; + _dense_shape = dense_shape_; + + var indices_shape = indices.TensorShape.with_rank(2); + var values_shape = values.TensorShape.with_rank(1); + var dense_shape_shape = dense_shape.TensorShape.with_rank(1); + + indices_shape[0].merge_with(values_shape.dims[0]); + indices_shape[1].merge_with(dense_shape_shape.dims[0]); + + _shape = new TensorShape(_dense_shape.Select(x => Convert.ToInt32(x)).ToArray()); + } + } + + public interface _TensorLike + { } - public static class sparse_tensor + public static class sparse_tensor_extension { public static bool is_sparse(this _TensorLike x) { - return x is SparseTensor; + return x.GetType().Name.Contains("SparseTensor"); } } } diff --git a/src/TensorFlowNET.Core/Gradients/RegisterNoGradient.cs b/src/TensorFlowNET.Core/Gradients/RegisterNoGradient.cs new file mode 100644 index 00000000..d573e317 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/RegisterNoGradient.cs @@ -0,0 +1,33 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+******************************************************************************/ + +using System; + +namespace Tensorflow.Gradients +{ + /// + /// REGISTER_NO_GRADIENT_OP(""); + /// + public class RegisterNoGradient : Attribute + { + public string Name { get; set; } + + public RegisterNoGradient(string name) + { + Name = name; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index e98ec21e..f07d2825 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -190,6 +190,26 @@ namespace Tensorflow.Gradients return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null }; } + [RegisterGradient("Pad")] + public static Tensor[] _PadGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + var a = op.inputs[1]; + var size = array_ops.stack(new object[] { array_ops.rank(x), 1 }); + var pad_before = array_ops.slice(a, new[] { 0, 0 }, size); + + // Make it a 1-D tensor. + var begin = array_ops.reshape(pad_before, new[] { -1 }); + var sizes = array_ops.shape(x); + var x_grad = array_ops.slice(grad, begin, sizes); + + if (len(op.inputs) == 3) + return new Tensor[] { x_grad, null, null }; + else + return new Tensor[] { x_grad, null }; + } + [RegisterGradient("Squeeze")] public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads) { diff --git a/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs similarity index 72% rename from src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs rename to src/TensorFlowNET.Core/Gradients/control_flow_grad.cs index 22a73374..76b6a7b5 100644 --- a/src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs +++ b/src/TensorFlowNET.Core/Gradients/control_flow_grad.cs @@ -1,28 +1,30 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
******************************************************************************/ using System; using System.Linq; using Tensorflow.Operations; +using static Tensorflow.Binding; namespace Tensorflow.Gradients { /// /// Gradients for operators defined in control_flow_ops.py.cs /// + [RegisterGradient("control_flow_grad")] public class control_flow_grad { /// @@ -33,118 +35,117 @@ namespace Tensorflow.Gradients /// on the second visit. A next_iteration is also added on second visit. /// /// - public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads) + [RegisterGradient("Switch")] + public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads) { + var grad = grads[0]; + var graph = ops.get_default_graph(); + var op_ctxt = op._get_control_flow_context(); + var grad_ctxt = graph._get_control_flow_context(); + switch (op_ctxt) + { + case WhileContext cwhile: + throw new NotImplementedException("_SwitchGrad WhileContext"); + case CondContext ccond: + { + var zero_grad = grads[1 - op_ctxt.branch]; + // At this point, we have created zero_grad guarded by the right switch. + // Unfortunately, we may still get None here for not trainable data types. + if(zero_grad == null) + { + throw new NotImplementedException("_SwitchGrad CondContext zero_grad"); + } + + return new Tensor[] + { + merge(grads, name: "cond_grad")[0], + null + }; + } + default: + throw new NotImplementedException("_SwitchGrad WhileContext"); + } throw new NotImplementedException("_SwitchGrad"); - //graph = ops.get_default_graph() - //# pylint: disable=protected-access - //op_ctxt = op._get_control_flow_context() - //grad_ctxt = graph._get_control_flow_context() - //# pylint: enable=protected-access - //if isinstance(op_ctxt, WhileContext): - // merge_grad = grad_ctxt.grad_state.switch_map.get(op) - // if merge_grad is not None: - // # This is the second time this Switch is visited. It comes from - // # the non-exit branch of the Switch, so update the second input - // # to the Merge. - // # TODO(yuanbyu): Perform shape inference with this new input. - // if grad[1] is not None: - // # pylint: disable=protected-access - // control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1], - // enforce_shape_invariant=False) - // # pylint: enable=protected-access - // return None, None - // elif grad[0] is not None: - // # This is the first time this Switch is visited. It comes from - // # the Exit branch, which is grad[0]. grad[1] is empty at this point. - // # Use grad[0] for both inputs to merge for now, but update the second - // # input of merge when we see this Switch the second time. - // merge_grad = merge([grad[0], grad[0]], name="b_switch")[0] - // grad_ctxt.grad_state.switch_map[op] = merge_grad - // return merge_grad, None - // else: - // # This is the first time this Switch is visited. It comes from the - // # Identity branch. Such a Switch has `None` gradient for the Exit branch, - // # meaning the output is not differentiable. - // return None, None - //elif isinstance(op_ctxt, CondContext): - // zero_grad = grad[1 - op_ctxt.branch] - // # At this point, we have created zero_grad guarded by the right switch. - // # Unfortunately, we may still get None here for not trainable data types. - // if zero_grad is None: - // # For resource variables we get None always on the other branch, so bypass - // # this. 
- // if op.inputs[0].dtype == dtypes.resource: - // return merge( - // [grad[op_ctxt.branch]] * 2, name="cond_resource_grad")[0], None - // return None, None - // return merge(grad, name="cond_grad")[0], None - //else: - // false_grad = switch(grad[0], op.inputs[1])[0] - // true_grad = switch(grad[1], op.inputs[1])[1] - // return merge([false_grad, true_grad])[0], None - } - + } + + /// + /// Returns the value of an available element of `inputs`. + /// + /// + /// + /// + internal static Tensor[] merge(Tensor[] inputs, string name = null) + { + return tf_with(ops.name_scope(name, "Merge", inputs), scope => + { + name = scope; + if (inputs.Count(x => x.dtype.is_ref_dtype()) == inputs.Length) + return gen_control_flow_ops.ref_merge(inputs, name: name); + else + return gen_control_flow_ops.merge(inputs, name: name); + }); + } + /// /// Gradients for a Merge op are calculated using a Switch op. /// - [RegisterGradient("Merge")] + [RegisterGradient("Merge")] public static Tensor[] _MergeGrad(Operation op, Tensor[] grads) { var grad = grads[0]; - var _ = grads[1]; var input_op = op.inputs[0].op; var graph = ops.get_default_graph(); var op_ctxt = control_flow_util.GetOutputContext(input_op); var grad_ctxt = graph._get_control_flow_context(); switch (op_ctxt) { - case WhileContext cwhile: - { - return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot); + case WhileContext cwhile: + { + return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot); } - case CondContext ccond: - { - var pred = ccond.pred; - if (grad_ctxt != null && grad_ctxt.grad_state != null) - { - //# This Merge node is part of a cond within a loop. - //# The backprop needs to have the value of this predicate for every - //# iteration. So we must have its values accumulated in the forward, and - //# use the accumulated values as the predicate for this backprop switch. - var grad_state = grad_ctxt.grad_state; - var real_pred = grad_state.history_map[pred.name] as Tensor; - if (real_pred == null) - { - //# Remember the value of pred for every iteration. - grad_ctxt = grad_state.grad_context; - grad_ctxt.Exit(); - var history_pred = grad_state.AddForwardAccumulator(pred); - grad_ctxt.Enter(); - - //# Add the stack pop op. If pred.op is in a (outer) CondContext, - //# the stack pop will be guarded with a switch. - real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred); - grad_state.history_map[pred.name] = real_pred; - } - pred = real_pred; - } - var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); - return results; + case CondContext ccond: + { + var pred = ccond.pred; + if (grad_ctxt != null && grad_ctxt.grad_state != null) + { + //# This Merge node is part of a cond within a loop. + //# The backprop needs to have the value of this predicate for every + //# iteration. So we must have its values accumulated in the forward, and + //# use the accumulated values as the predicate for this backprop switch. + var grad_state = grad_ctxt.grad_state; + var real_pred = grad_state.history_map[pred.name] as Tensor; + if (real_pred == null) + { + //# Remember the value of pred for every iteration. + grad_ctxt = grad_state.grad_context; + grad_ctxt.Exit(); + var history_pred = grad_state.AddForwardAccumulator(pred); + grad_ctxt.Enter(); + + //# Add the stack pop op. If pred.op is in a (outer) CondContext, + //# the stack pop will be guarded with a switch. 
+ real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred); + grad_state.history_map[pred.name] = real_pred; + } + pred = real_pred; + } + var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); + return results; } - default: - { - var num_inputs = op.inputs.Length; - var cond = new Tensor[num_inputs]; - for (int i = 0; i < num_inputs; i++) - cond[i] = math_ops.equal(op.outputs[1], i); - var result = cond.Select(t => control_flow_ops._SwitchRefOrTensor(grad, t)[1]).ToArray(); - return result; + default: + { + var num_inputs = op.inputs.Length; + var cond = new Tensor[num_inputs]; + for (int i = 0; i < num_inputs; i++) + cond[i] = math_ops.equal(op.outputs[1], i); + var result = cond.Select(t => control_flow_ops._SwitchRefOrTensor(grad, t)[1]).ToArray(); + return result; } } } + [RegisterGradient("RefMerge")] public Tensor[] _RefMergeGrad(Operation op, Tensor[] grads) { return _MergeGrad(op, grads); @@ -153,6 +154,7 @@ namespace Tensorflow.Gradients /// /// Gradients for an exit op are calculated using an Enter op. /// + [RegisterGradient("Exit")] public Tensor[] _ExitGrad(Operation op, Tensor[] grads) { throw new NotImplementedException("_ExitGrad"); @@ -197,14 +199,16 @@ namespace Tensorflow.Gradients /// /// Note that the backprop next_iteration is added in switch grad. /// - public (object, Tensor[]) _NextIterationGrad(object _, Tensor[] grad) + [RegisterGradient("NextIteration")] + public Tensor[] _NextIterationGrad(object _, Tensor[] grad) { - return (_, grad); + return grad; } - public (object, Tensor[]) _RefNextIterationGrad(object _, Tensor[] grad) + [RegisterGradient("RefNextIteration")] + public Tensor[] _RefNextIterationGrad(object _, Tensor[] grad) { - return (_, grad); + return grad; } /// @@ -213,7 +217,8 @@ namespace Tensorflow.Gradients /// For loop variables, grad is the gradient so just add an exit. /// For loop invariants, we need to add an accumulator loop. /// - public (object, Tensor[]) _EnterGrad(Tensor op, Tensor[] grad) + [RegisterGradient("Enter")] + public Tensor[] _EnterGrad(Tensor op, Tensor[] grad) { throw new NotImplementedException("_EnterGrad"); // graph = ops.get_default_graph() @@ -242,7 +247,9 @@ namespace Tensorflow.Gradients // return result } - public (object, Tensor[]) _RefEnterGrad(Tensor op, Tensor[] grad) + + [RegisterGradient("RefEnter")] + public Tensor[] _RefEnterGrad(Tensor op, Tensor[] grad) { return _EnterGrad(op, grad); } @@ -250,10 +257,11 @@ namespace Tensorflow.Gradients /// /// Stop backprop for the predicate of a while loop. /// - public object _LoopCondGrad(object _) + [RegisterGradient("LoopCond")] + public Tensor[] _LoopCondGrad(Tensor op, Tensor[] grad) { return null; - } - - } + } + + } } diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index 3b6d0eea..15ad511b 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -61,7 +61,7 @@ namespace Tensorflow string grad_scope = scope; // Get a uid for this call to gradients that can be used to help // cluster ops for compilation. 
- var gradient_uid = ops.get_default_graph().unique_name("uid"); + var gradient_uid = curr_graph.unique_name("uid"); ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y"); xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true); grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid); @@ -80,7 +80,7 @@ namespace Tensorflow var to_ops = ys.Select(x => x.op).ToList(); var from_ops = xs.Select(x => x.op).ToList(); var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); - (var reachable_to_ops, var pending_count, var loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List(), xs); + var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List(), xs); foreach (var (y, grad_y) in zip(ys, grad_ys)) _SetGrad(grads, y, grad_y); @@ -117,23 +117,52 @@ namespace Tensorflow Tensor[] in_grads = null; var is_partitioned_call = _IsPartitionedCall(op); var is_func_call = false; - var has_out_grads = true; + var has_out_grads = out_grads.Exists(x => x != null); if (has_out_grads && !stop_ops.Contains(op)) { // A grad_fn must be defined, either as a function or as None // for ops that do not have gradients. - var grad_fn = ops.get_gradient_function(op); - if (is_func_call) + Func grad_fn = null; + try { + grad_fn = ops.get_gradient_function(op); + } + catch (LookupError) + { + if (is_func_call) + { + if (is_partitioned_call) + { + + } + else + { + } + } + else + { + throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})"); + } } - else + + // if (loop_state) + //loop_state.EnterGradWhileContext(op, before: false); + + if ((is_func_call || grad_fn != null) && has_out_grads) { + // NOTE: If _AggregatedGrads didn't compute a value for the i'th + // output, it means that the cost does not depend on output[i], + // therefore dC/doutput[i] is 0. foreach (var (i, out_grad) in enumerate(out_grads)) { - if (out_grad == null) + if (out_grad == null && + (grad_fn == null || _IsTrainable(op.outputs[i]))) { + // Only trainable outputs or outputs for a function call that + // will use SymbolicGradient get a zero gradient. Gradient + // functions should ignore the gradient for other outputs. if (loop_state != null) ; else @@ -143,13 +172,19 @@ namespace Tensorflow tf_with(ops.name_scope(op.name + "_grad"), scope1 => { - string name1 = scope1; if (grad_fn != null) { - in_grads = _MaybeCompile(grad_scope, op, out_grads.Select(x => x[0]).ToArray(), null, grad_fn); - _VerifyGeneratedGradients(in_grads, op); + in_grads = _MaybeCompile(grad_scope, + op, + out_grads.Where(x => x != null).Select(x => x[0]).ToArray(), + null, + grad_fn); } - + else + { + throw new NotImplementedException("lambda: _SymGrad(op, out_grads)"); + } + _VerifyGeneratedGradients(in_grads, op); if (gate_gradients && in_grads.Count(x => x != null) > 1) { ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true); @@ -157,6 +192,12 @@ namespace Tensorflow } }); } + else + { + // If no grad_fn is defined or none of out_grads is available, + // just propagate a list of None backwards. 
+ in_grads = new Tensor[_NonEagerInputs(op, xs).Count()]; + } } else { @@ -168,11 +209,11 @@ namespace Tensorflow { if (in_grad != null) { - if (in_grad is Tensor && + if (!(in_grad is null) && in_grad.Tag == null && // maybe a IndexedSlice t_in.dtype != TF_DataType.TF_RESOURCE) { - in_grad.shape = t_in.shape; + in_grad.set_shape(t_in.TensorShape); } _SetGrad(grads, t_in, in_grad); diff --git a/src/TensorFlowNET.Core/Gradients/image_grad.cs b/src/TensorFlowNET.Core/Gradients/image_grad.cs new file mode 100644 index 00000000..23b19de9 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/image_grad.cs @@ -0,0 +1,54 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Framework; +using static Tensorflow.Binding; + +namespace Tensorflow.Gradients +{ + [RegisterGradient("image_grad")] + public class image_grad + { + [RegisterGradient("ResizeNearestNeighbor")] + public static Tensor[] _ResizeNearestNeighborGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var image = op.inputs[0]; + var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray()); + Tensor image_shape = null; + if (shape.is_fully_defined()) + throw new NotImplementedException("_ResizeNearestNeighborGrad shape.is_fully_defined"); + else + image_shape = array_ops.shape(image)["1:3"]; + + grad = gen_image_ops.resize_nearest_neighbor_grad( + grad, + image_shape, + align_corners: op.get_attr("align_corners"), + half_pixel_centers: op.get_attr("half_pixel_centers")); + + return new Tensor[] + { + grad, + null + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 672270b9..49dcbc45 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -42,7 +42,8 @@ namespace Tensorflow.Gradients var x = op.inputs[0]; var y = op.inputs[1]; var grad = grads[0]; - if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad)) + if (grad is Tensor && + _ShapesFullySpecifiedAndEqual(x, y, grad)) return new Tensor[] { grad, grad }; var sx = array_ops.shape(x); @@ -96,6 +97,12 @@ namespace Tensorflow.Gradients }); } + [RegisterNoGradient("GreaterEqual")] + public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) => null; + + [RegisterNoGradient("ZerosLike")] + public static Tensor[] _ZerosLike(Operation op, Tensor[] grads) => null; + [RegisterGradient("Identity")] public static Tensor[] _IdGrad(Operation op, Tensor[] grads) { @@ -124,6 +131,17 @@ namespace Tensorflow.Gradients }); } + [RegisterGradient("Log1p")] + public static Tensor[] _Log1pGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + return tf_with(ops.control_dependencies(new 
Operation[] { grad }), dp => { + x = math_ops.conj(x); + return new Tensor[] { grad * math_ops.reciprocal(1 + x) }; + }); + } + [RegisterGradient("Mul")] public static Tensor[] _MulGrad(Operation op, Tensor[] grads) { @@ -332,6 +350,21 @@ namespace Tensorflow.Gradients return new Tensor[] { -grads[0] }; } + [RegisterGradient("Select")] + public static Tensor[] _SelectGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var c = op.inputs[0]; + var x = op.inputs[1]; + var zeros = array_ops.zeros_like(x); + return new Tensor[] + { + null, + array_ops.where(c, grad, zeros), + array_ops.where(c, zeros, grad) + }; + } + private static Tensor _safe_shape_div(Tensor x, Tensor y) { return math_ops.floordiv(x, gen_math_ops.maximum(y, 1)); @@ -343,7 +376,8 @@ namespace Tensorflow.Gradients var grad = grads[0]; var x = op.inputs[0]; var y = op.inputs[1]; - if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad)) + if (grad is Tensor && + _ShapesFullySpecifiedAndEqual(x, y, grad)) return new Tensor[] { grad, -grad }; var sx = array_ops.shape(x); @@ -361,9 +395,10 @@ namespace Tensorflow.Gradients var x_shape = x._shape_tuple(); var y_shape = y._shape_tuple(); var grad_shape = grad._shape_tuple(); - return Enumerable.SequenceEqual(x_shape, y_shape) && + return x_shape != null && + y_shape != null && + Enumerable.SequenceEqual(x_shape, y_shape) && Enumerable.SequenceEqual(y_shape, grad_shape) && - x.NDims != -1 && !x_shape.Contains(-1); } @@ -382,7 +417,9 @@ namespace Tensorflow.Gradients var rank = input_0_shape.Length; if (Enumerable.SequenceEqual(Enumerable.Range(0, rank), axes.Data())) { - grad = array_ops.reshape(grad, new int[] { 1 }); + var new_shape = range(rank).Select(x => 1).ToArray(); + grad = array_ops.reshape(grad, new_shape); + // If shape is not fully defined (but rank is), we use Shape. if (!input_0_shape.Contains(-1)) input_shape = constant_op.constant(input_0_shape); else diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs index 53ee9699..967b3c21 100644 --- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs @@ -47,6 +47,15 @@ namespace Tensorflow.Gradients return new Tensor[] { gen_nn_ops.relu_grad(grads[0], op.outputs[0]) }; } + [RegisterGradient("LeakyRelu")] + public static Tensor[] _LeakyReluGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var x = op.inputs[0]; + var alpha = (float)op.get_attr("alpha"); + return new Tensor[] { gen_nn_ops.leaky_relu_grad(grad, x, alpha: alpha)}; + } + /// /// The derivative of the softmax nonlinearity. /// @@ -157,6 +166,94 @@ namespace Tensorflow.Gradients }; } + [RegisterGradient("FusedBatchNorm")] + public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads) + => _BaseFusedBatchNormGrad(op, 0, grads); + + /// + /// Return the gradients for the 3 inputs of BatchNorm. 
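With the "LeakyRelu" gradient registered above, backprop through `nn_ops.leaky_relu` should yield the slope `alpha` for negative inputs and `1` for positive ones. A minimal check, assuming the usual `tf.gradients` and `Session` entry points:

```cs
using Tensorflow;
using static Tensorflow.Binding;

var x = tf.constant(new float[] { -2f, 3f });
var y = nn_ops.leaky_relu(x, 0.2f);
var grads = tf.gradients(new Tensor[] { y }, new Tensor[] { x });

using (var sess = tf.Session())
{
    // Expected gradient: 0.2 for x = -2, 1.0 for x = 3.
    var g = sess.run(grads[0]);
}
```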
+ /// + /// + /// + /// + /// + public static Tensor[] _BaseFusedBatchNormGrad(Operation op, int version, Tensor[] grads) + { + var x = op.inputs[0]; + var grad_y = grads[0]; + var scale = op.inputs[1]; + var epsilon = op.get_attr("epsilon"); + var data_format = op.get_attr("data_format"); + var is_training = op.get_attr("is_training"); + Func grad_fun = null; + + switch (version) + { + case 2: + throw new NotImplementedException(""); + case 1: + throw new NotImplementedException(""); + default: + grad_fun = gen_nn_ops.fused_batch_norm_grad; + break; + } + + if (is_training) + { + return grad_fun(new FusedBatchNormParams + { + YBackprop = grad_y, + X = x, + Scale = scale, + ReserveSpace1 = op.outputs[3], + ReserveSpace2 = op.outputs[4], + ReserveSpace3 = version == 2 ? op.outputs[5] : null, + Epsilon = epsilon, + DataFormat = data_format, + IsTraining = is_training + }); + } + else + { + var pop_mean = op.inputs[3]; + var pop_var = op.inputs[4]; + if (data_format == "NCHW") + throw new NotImplementedException(""); + + var results = grad_fun(new FusedBatchNormParams + { + YBackprop = grad_y, + X = x, + Scale = scale, + ReserveSpace1 = op.outputs[3], + ReserveSpace2 = op.outputs[4], + ReserveSpace3 = version == 2 ? op.outputs[5] : null, + Epsilon = epsilon, + DataFormat = data_format, + IsTraining = is_training + }); + + var (dx, dscale, doffset) = (results[0], results[1], results[2]); + if (data_format == "NCHW") + throw new NotImplementedException(""); + + return new Tensor[] + { + dx, + dscale, + doffset, + null, + null + }; + } + } + + [RegisterGradient("BatchNormWithGlobalNormalization")] + public static Tensor _BatchNormWithGlobalNormalizationGrad(Operation op, Tensor[] grads) + { + throw new NotImplementedException("BatchNormWithGlobalNormalization"); + } + private static bool IsZero(Tensor g) { if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type)) diff --git a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs index 0b624ba1..4891fcbb 100644 --- a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs +++ b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs @@ -39,6 +39,14 @@ namespace Tensorflow gradientFunctions[name] = func; } + public static void RegisterNoGradientFunction(string name) + { + if (gradientFunctions == null) + gradientFunctions = new Dictionary>(); + + gradientFunctions[name] = null; + } + public static Func get_gradient_function(Operation op) { if (op.inputs == null) return null; @@ -68,11 +76,18 @@ namespace Tensorflow args: new object[] { oper, out_grads }) as Tensor[] ); } + + // REGISTER_NO_GRADIENT_OP + methods = g.GetMethods().Where(x => x.GetCustomAttribute() != null) + .ToArray(); + + foreach (var m in methods) + RegisterNoGradientFunction(m.GetCustomAttribute().Name); } } if (!gradientFunctions.ContainsKey(op.type)) - throw new NotImplementedException($"can't get graident function through get_gradient_function {op.type}"); + throw new LookupError($"can't get graident function through get_gradient_function {op.type}"); return gradientFunctions[op.type]; } diff --git a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs index 66419b3e..3dc77859 100644 --- a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs +++ b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs @@ -21,11 +21,10 @@ using static Tensorflow.Binding; namespace Tensorflow { - /// /// Serves as a stack for determining current 
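Besides the attribute scan, gradients can be wired up programmatically; `RegisterNoGradientFunction` stores a `null` entry so lookups succeed without producing a gradient, while unknown ops now raise `LookupError`. A hedged sketch (the op names are hypothetical, and the `RegisterGradientFunction` counterpart is assumed from the surrounding registry code):

```cs
using Tensorflow;

// Register a backward function for a hypothetical "MyCustomSquare" op.
ops.RegisterGradientFunction("MyCustomSquare",
    (op, grads) => new Tensor[] { grads[0] * 2.0f * op.inputs[0] });

// Declare a hypothetical "MyCustomArgMax" op as non-differentiable:
// get_gradient_function returns null for it instead of throwing LookupError.
ops.RegisterNoGradientFunction("MyCustomArgMax");
```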
default graph. /// - public class DefaultGraphStack + public class DefaultGraphStack { private readonly List _stack = new List(); @@ -40,7 +39,7 @@ namespace Tensorflow public Graph get_controller() { - if (_stack.Count(x => x.IsDefault) == 0) + if (_stack.Count == 0 || _stack.Count(x => x.IsDefault) == 0) _stack.Add(new StackModel {Graph = tf.Graph(), IsDefault = true}); for (var i = _stack.Count - 1; i >= 0; i--) { @@ -52,6 +51,20 @@ namespace Tensorflow throw new TensorflowException("Unable to find a default graph"); } + public Graph peak_controller() + { + if (_stack.Count == 0 || _stack.Count(x => x.IsDefault) == 0) + return null; + for (var i = _stack.Count - 1; i >= 0; i--) + { + var x = _stack[i]; + if (x.IsDefault) + return x.Graph; + } + + return null; + } + public bool remove(Graph g) { if (_stack.Count == 0) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs index 4a7e0ed8..2a0d939e 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using Google.Protobuf; using System.IO; using Tensorflow.Util; @@ -37,7 +38,9 @@ namespace Tensorflow using (var buffer = ToGraphDef(status)) { status.Check(true); - def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + // limit size to 250M, recursion to max 100 + var inputStream = CodedInputStream.CreateWithLimits(buffer.MemoryBlock.Stream(), 250 * 1024 * 1024, 100); + def = GraphDef.Parser.ParseFrom(inputStream); } // Strip the experimental library field iff it's empty. diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs index 0e28dd9a..75f46a59 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -54,19 +54,35 @@ namespace Tensorflow var handle = return_oper_handle.node + tf_op_size * i; return_opers[i] = new Operation(*(IntPtr*)handle); } - } - + } + return return_opers; } + /// + /// Get operation with given + /// + /// When is not found current graph. + /// When tf.get_default_graph() is not current graph. + /// + /// graph.GetOperationByName("CustomInputName"); + /// public Operation OperationByName(string operName) { + if (operName == null) + throw new ArgumentNullException(nameof(operName)); + var handle = c_api.TF_GraphOperationByName(_handle, operName); - if(graph_key != tf.get_default_graph().graph_key) - { - Console.WriteLine($"Current graph is not default graph."); - // throw new ValueError($"Current graph is not default graph."); + if (handle == IntPtr.Zero) + throw new ValueError($"Could not find operation \"{operName}\" inside graph \"{_graph_key}\"."); + + var defaultKey = tf.get_default_graph().graph_key; + if (graph_key != defaultKey) + { + //Console.WriteLine($"Current graph is not default graph."); + throw new RuntimeError($"Current graph is not default graph. 
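The lookup now fails loudly instead of printing a warning. A minimal sketch, assuming a graph that already contains an operation named "Input":

```cs
using Tensorflow;
using static Tensorflow.Binding;

var graph = tf.get_default_graph();

try
{
    var input = graph.OperationByName("Input");
}
catch (ValueError)
{
    // No operation with that name exists in this graph.
}
catch (RuntimeError)
{
    // This graph is not the current default graph.
}
```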
Default Graph Key: {defaultKey}, Current Graph Key: {graph_key}"); } + return new Operation(handle, g: this); } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 0dfb68db..7119a4ad 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -175,6 +175,10 @@ namespace Tensorflow if (_nodes_by_name.ContainsKey(op_name)) return _nodes_by_name[op_name].outputs[out_n]; + else + throw new KeyError($"The name {name} refers to a Tensor which does not " + + $"exist. The operation, {op_name}, does not exist in the " + + "graph."); } else if (!name.Contains(":") & allow_operation) { @@ -223,6 +227,10 @@ namespace Tensorflow public void add_to_collection(string name, T value) { + if(name == "update_ops") + { + + } _check_not_finalized(); if (_collections.ContainsKey(name)) (_collections[name] as List).Add(value); @@ -242,7 +250,7 @@ namespace Tensorflow throw new RuntimeError("Graph is finalized and cannot be modified."); } - public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, + public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary attrs = null, OpDef op_def = null) { @@ -284,6 +292,11 @@ namespace Tensorflow return op; } + public void device(string device_name) + { + throw new NotImplementedException(""); + } + private void _create_op_helper(Operation op, bool compute_device = true) { _record_op_seen_by_control_dependencies(op); @@ -420,14 +433,36 @@ namespace Tensorflow public List get_collection(string name, string scope = null) { - return _collections.ContainsKey(name) ? _collections[name] as List : new List(); + List t = default; + var collection = _collections.ContainsKey(name) ? 
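With the typed overloads, collections round-trip as `List<T>` rather than bare objects. A minimal sketch:

```cs
using Tensorflow;
using static Tensorflow.Binding;

var graph = tf.get_default_graph();

var step = tf.constant(0, name: "my_step");
graph.add_to_collection("my_counters", step);

// Materialises the stored items as a List<Tensor>.
var counters = graph.get_collection<Tensor>("my_counters");

// Returns the live backing list, so callers can append to it directly.
var liveList = graph.get_collection_ref<Tensor>("my_counters");
```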
_collections[name] : new List(); + switch (collection) + { + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; + default: + throw new NotImplementedException($"get_collection<{typeof(T).FullName}>"); + } + return t; } - public object get_collection_ref(string name) + public List get_collection_ref(string name) { if (!_collections.ContainsKey(name)) - _collections[name] = new List(); - return _collections[name]; + _collections[name] = new List(); + return _collections[name] as List; } public void prevent_feeding(Tensor tensor) diff --git a/src/TensorFlowNET.Core/Keras/backend.cs b/src/TensorFlowNET.Core/Keras/backend.cs index 5ac04f1a..46769bd8 100644 --- a/src/TensorFlowNET.Core/Keras/backend.cs +++ b/src/TensorFlowNET.Core/Keras/backend.cs @@ -28,7 +28,7 @@ namespace Tensorflow.Keras //Func py_any = any; //Func> py_slice = slice; - public static Session _SESSION = tf.defaultSession; + public static Session _SESSION = ops.get_default_session(); public static Graph _GRAPH = null; public static Dictionary _GRAPH_LEARNING_PHASES; //Dictionary> PER_GRAPH_LAYER_NAME_UIDS; diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index a3ae3356..444c2dd4 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -90,7 +90,7 @@ namespace Tensorflow.Layers { foreach(var name in collection_list) { - var collection = ops.get_collection_ref(name) as List; + var collection = ops.get_collection_ref(name); foreach (var element in elements) if (!collection.Contains(element)) diff --git a/src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.activations.cs b/src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.activations.cs index 80e1c305..788adda4 100644 --- a/src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.activations.cs +++ b/src/TensorFlowNET.Core/Operations/Activation/gen_nn_ops.activations.cs @@ -14,20 +14,192 @@ limitations under the License. ******************************************************************************/ +using System; +using static Tensorflow.Binding; + namespace Tensorflow.Operations.Activation { + public class sigmoid : IActivation + { + public Tensor Activate(Tensor x, string name = null) + { + return tf.sigmoid(x); + } + } + + public class tanh : IActivation + { + public Tensor Activate(Tensor x, string name = null) + { + return tf.tanh(x); + } + } + + public class leakyrelu : IActivation + { + private readonly float _alpha; + + public leakyrelu(float alpha = 0.3f) { + _alpha = alpha; + } + + public Tensor Activate(Tensor x, string name = null) + { + return nn_ops.leaky_relu(x, _alpha); + } + } + + public class elu : IActivation + { + private readonly float _alpha; + + public elu(float alpha = 0.1f) + { + _alpha = alpha; + } + + public Tensor Activate(Tensor x, string name = null) + { + var res = gen_ops.elu(x); + if (Math.Abs(_alpha - 0.1f) < 0.00001f) + { + return res; + } + + return array_ops.@where(x > 0, res, _alpha * res); + } + } + + public class softmax : IActivation + { + private readonly int _axis; + + /// Initializes a new instance of the class. 
+ public softmax(int axis = -1) + { + _axis = axis; + } + + public Tensor Activate(Tensor x, string name = null) + { + return nn_ops.softmax(x, _axis); + } + } + + public class softplus : IActivation + { + public Tensor Activate(Tensor x, string name = null) + { + return gen_ops.softplus(x); + } + } + + public class softsign : IActivation + { + public Tensor Activate(Tensor x, string name = null) + { + return gen_ops.softsign(x); + } + } + + public class linear : IActivation + { + public Tensor Activate(Tensor x, string name = null) + { + return x; + } + } + + + public class exponential : IActivation + { + public Tensor Activate(Tensor x, string name = null) + { + return tf.exp(x, name: name); + } + } + + public class relu : IActivation { - public Tensor Activate(Tensor features, string name = null) + private readonly float _threshold; + private readonly float _alpha; + private readonly float? _maxValue; + + public relu(float threshold = 0f, float alpha = 0.2f, float? max_value = null) + { + _threshold = threshold; + _alpha = alpha; + _maxValue = max_value; + } + + public Tensor Activate(Tensor x, string name = null) { - OpDefLibrary _op_def_lib = new OpDefLibrary(); + //based on keras/backend.py + if (Math.Abs(_alpha) > 0.000001f) + { + if (!_maxValue.HasValue && Math.Abs(_threshold) < 0.0001) + { + return nn_ops.leaky_relu(x, _alpha); + } + } + + Tensor negative_part; + if (Math.Abs(_threshold) > 0.000001f) + { + negative_part = gen_ops.relu(-x + _threshold); + } else + { + negative_part = gen_ops.relu(-x + _threshold); + } + + if (Math.Abs(_threshold) > 0.000001f) + { + x = x * math_ops.cast(tf.greater(x, _threshold), TF_DataType.TF_FLOAT); + } else if (Math.Abs(_maxValue.Value - 6f) < 0.0001f) + { + x = gen_ops.relu6(x); + } else + { + x = gen_ops.relu(x); + } + + bool clip_max = _maxValue.HasValue; + if (clip_max) + { + Tensor maxval = constant_op.constant(_maxValue, x.dtype.as_base_dtype()); + var zero = constant_op.constant(0.0f, x.dtype.as_base_dtype()); + x = gen_ops.clip_by_value(x, zero, maxval); + } - var _op = _op_def_lib._apply_op_helper("Relu", name: name, args: new + if (Math.Abs(_alpha) > 0.00001) { - features - }); + var a = constant_op.constant(_alpha, x.dtype.as_base_dtype()); + x -= a * negative_part; + } - return _op.outputs[0]; + return x; + } + } + + public class selu : IActivation + { + public Tensor Activate(Tensor x, string name = null) + { + const float alpha = 1.6732632423543772848170429916717f; + const float scale = 1.0507009873554804934193349852946f; + return scale * new elu(alpha).Activate(x, name); + } + } + + public class hard_sigmoid : IActivation + { + public Tensor Activate(Tensor x, string name = null) + { + x = (0.2 * x) + 0.5; + var zero = tf.convert_to_tensor(0.0f, x.dtype.as_base_dtype()); + var one = tf.convert_to_tensor(1.0f, x.dtype.as_base_dtype()); + return tf.clip_by_value(x, zero, one); } } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index aa314efb..ce2295c8 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -27,20 +27,6 @@ namespace Tensorflow.Operations /// public class CondContext : ControlFlowContext, IProtoBuf { - - - /// - /// The boolean tensor for the cond predicate - /// - private Tensor _pred; - - public Tensor pred => _pred; - - /// - /// 0 or 1 representing this branch - /// - private int _branch; - private 
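The activation objects all implement `IActivation` and can be applied directly to tensors. A minimal sketch using the constructor parameters shown above:

```cs
using Tensorflow;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;

var x = tf.constant(new float[] { -3f, -0.5f, 0.5f, 8f });

// Leaky ReLU with slope 0.3 on the negative side.
var lrelu = new leakyrelu(alpha: 0.3f).Activate(x);

// With alpha: 0f and max_value: 6f this reduces to a capped ReLU (relu6).
var capped = new relu(alpha: 0f, max_value: 6f).Activate(x);

// Piecewise-linear approximation of a sigmoid, clipped to [0, 1].
var hs = new hard_sigmoid().Activate(x);
```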
Dictionary _external_values = new Dictionary(); /// diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index 2a76c52c..c076cbc7 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -45,10 +45,19 @@ namespace Tensorflow.Operations /// The predicate tensor in this branch /// protected Tensor _pivot; - public Tensor pivot - { - get => _pivot; - } + public Tensor pivot => _pivot; + + /// + /// The boolean tensor for the cond predicate + /// + protected Tensor _pred; + public Tensor pred => _pred; + + /// + /// 0 or 1 representing this branch + /// + protected int _branch; + public int branch => _branch; protected Stack _context_stack; protected ControlFlowContext _outer_context; diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs new file mode 100644 index 00000000..708d9db6 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs @@ -0,0 +1,55 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +namespace Tensorflow.Operations.Initializers +{ + public class Constant : IInitializer + { + TF_DataType dtype; + T value; + bool _verify_shape; + + public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) + { + this.value = value; + this.dtype = dtype; + _verify_shape = verify_shape; + } + + public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) + { + if (dtype == TF_DataType.DtInvalid) + dtype = this.dtype; + + if (!verify_shape.HasValue) + verify_shape = _verify_shape; + + return constant_op._constant_impl(value, dtype, shape, + name: "Const", + verify_shape: verify_shape.Value, + allow_broadcast: false); + } + + public object get_config() + { + return new + { + value, + dtype = dtype.name() + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs index 4ca6c140..0ac0865f 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs @@ -18,7 +18,7 @@ namespace Tensorflow { public interface IInitializer { - Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid); + Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? 
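The new `Constant` initializer plugs into the same `IInitializer` interface, whose `call` now accepts an optional `verify_shape` flag. A minimal sketch, assuming the initializer is generic over its value type:

```cs
using Tensorflow;
using Tensorflow.Operations.Initializers;

IInitializer init = new Constant<float>(0.1f);

// Fills a 2x3 tensor with 0.1f; dtype falls back to the initializer's own
// dtype when left as DtInvalid, and verify_shape defaults to the ctor value.
var t = init.call(new TensorShape(2, 3), TF_DataType.TF_FLOAT);
```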
verify_shape = null); object get_config(); } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs index 6fb4feb6..83e5b57d 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs @@ -25,7 +25,7 @@ namespace Tensorflow.Operations.Initializers this.dtype = dtype; } - public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) + public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) { if (dtype == TF_DataType.DtInvalid) dtype = this.dtype; diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs index f553d45b..a3e2063f 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs @@ -38,7 +38,7 @@ namespace Tensorflow.Operations.Initializers this.dtype = dtype; } - public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) + public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) { if (dtype == TF_DataType.DtInvalid) dtype = this.dtype; diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs index 98595edc..59333c84 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs @@ -28,7 +28,7 @@ namespace Tensorflow.Operations.Initializers } - public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) + public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) { return random_ops.random_uniform(shape, minval: minval, diff --git a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs index 0611c0e9..7d635f0c 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs @@ -34,7 +34,7 @@ namespace Tensorflow.Operations.Initializers this.dtype = dtype; } - public Tensor call(TensorShape shape, TF_DataType dtype) + public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null) { return random_ops.truncated_normal(shape, mean, stddev, dtype : dtype, seed: seed); } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs index e88033b5..e2b2a0d6 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs @@ -45,7 +45,7 @@ namespace Tensorflow.Operations.Initializers _dtype = dtype; } - public Tensor call(TensorShape shape, TF_DataType dtype) + public Tensor call(TensorShape shape, TF_DataType dtype, bool? 
verify_shape = null) { var (fan_in, fan_out) = _compute_fans(shape); if (_mode == "fan_in") diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs index b9d4f746..bea9cf71 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs @@ -25,7 +25,7 @@ namespace Tensorflow.Operations.Initializers this.dtype = dtype; } - public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) + public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) { if (dtype == TF_DataType.DtInvalid) dtype = this.dtype; diff --git a/src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs b/src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs new file mode 100644 index 00000000..689fa5fe --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/FusedBatchNormParams.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations +{ + public class FusedBatchNormParams + { + public string Name { get; set; } + public Tensor YBackprop { get; set; } + public Tensor X { get; set; } + public Tensor Scale { get; set; } + public Tensor ReserveSpace1 { get; set; } + public Tensor ReserveSpace2 { get; set; } + public Tensor ReserveSpace3 { get; set; } + public float Epsilon { get; set; } + public string DataFormat { get; set; } + public bool IsTraining { get; set; } + + public FusedBatchNormParams() + { + Epsilon = 0.0001f; + DataFormat = "NHWC"; + IsTraining = true; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index 49d504ab..4e376d19 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -156,7 +156,36 @@ namespace Tensorflow.Operations return op.output; } - public static Tensor[] _fused_batch_norm(Tensor x, + /// + /// Gradient for batch normalization. 
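`FusedBatchNormParams` bundles the many arguments of the batch-norm gradient kernel into one object, with NHWC, epsilon 1e-4 and training mode as defaults; it is consumed by the `fused_batch_norm_grad` helper added below. A minimal sketch (the placeholder names and shapes are illustrative only):

```cs
using Tensorflow;
using Tensorflow.Operations;
using static Tensorflow.Binding;

var x        = tf.placeholder(TF_DataType.TF_FLOAT, new TensorShape(-1, 8, 8, 16));
var grad_y   = tf.placeholder(TF_DataType.TF_FLOAT, new TensorShape(-1, 8, 8, 16));
var scale    = tf.placeholder(TF_DataType.TF_FLOAT, new TensorShape(16));
var mean     = tf.placeholder(TF_DataType.TF_FLOAT, new TensorShape(16));
var variance = tf.placeholder(TF_DataType.TF_FLOAT, new TensorShape(16));

var grads = gen_nn_ops.fused_batch_norm_grad(new FusedBatchNormParams
{
    YBackprop = grad_y,
    X = x,
    Scale = scale,
    ReserveSpace1 = mean,
    ReserveSpace2 = variance
    // Epsilon = 0.0001f, DataFormat = "NHWC", IsTraining = true are the defaults.
});
// grads[0..2] are the gradients w.r.t. x, scale and offset.
```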
+ /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] fused_batch_norm_grad(FusedBatchNormParams @params) + { + var op = _op_def_lib._apply_op_helper("FusedBatchNormGrad", name: @params.Name, args: new + { + y_backprop = @params.YBackprop, + x = @params.X, + scale = @params.Scale, + reserve_space_1 = @params.ReserveSpace1, + reserve_space_2 = @params.ReserveSpace2, + epsilon = @params.Epsilon, + data_format = @params.DataFormat, + is_training = @params.IsTraining + }); + return op.outputs; + } + + public static Tensor[] fused_batch_norm(Tensor x, Tensor scale, Tensor offset, Tensor mean, @@ -284,6 +313,18 @@ namespace Tensorflow.Operations }); return _op.outputs[0]; + } + + public static Tensor leaky_relu_grad(Tensor gradients, Tensor features, float alpha = 0.2f, string name = null) + { + var _op = _op_def_lib._apply_op_helper("LeakyReluGrad", name: name, args: new + { + gradients, + features, + alpha + }); + + return _op.output; } public static Tensor softmax(Tensor logits, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 1b68d1cd..8e7425e5 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -145,7 +145,7 @@ namespace Tensorflow.Operations { var ta = new TensorArray(dtype: dtype_, size: time_steps, - element_shape: element_shape, + element_shape: new[] { element_shape }, tensor_array_name: base_name + name); return ta; }; diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index e6799f61..89ddebdb 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Linq; using static Tensorflow.OpDef.Types; using static Tensorflow.Binding; +using Google.Protobuf; namespace Tensorflow { @@ -194,7 +195,9 @@ namespace Tensorflow if (attrs.ContainsKey(key)) { attr_protos[key] = SetAttrValue(op_def, attr_def, attrs[key]); - } else { + } + else + { if (attr_def.DefaultValue == null) { throw new TypeError("Missing required positional argument " + key); @@ -311,6 +314,16 @@ namespace Tensorflow input_types.AddRange(base_types); } + public ByteString _MakeStr(string value, AttrDef attr_def) + { + return ByteString.CopyFromUtf8(value ?? 
string.Empty); + } + + public TensorShapeProto _MakeShape(TensorShape shape, AttrDef attr_def) + { + return shape.as_proto(); + } + public DataType _MakeType(TF_DataType v, AttrDef attr_def) { return v.as_base_dtype().as_datatype_enum(); @@ -330,7 +343,7 @@ namespace Tensorflow switch (attr_def.Type) { case "string": - attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value); + attr_value.S = _MakeStr((string)value, attr_def); break; case "type": attr_value.Type = _MakeType((TF_DataType)value, attr_def); @@ -363,6 +376,9 @@ namespace Tensorflow else if (value is int[] val3) attr_value.Shape = tensor_util.as_shape(val3); + break; + case "list(shape)": + attr_value.List.Shape.AddRange((value as TensorShape[]).Select(x => _MakeShape(x, attr_def))); break; default: throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs index 2f61f954..8e317df9 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -54,6 +54,10 @@ namespace Tensorflow public void _set_control_flow_context(ControlFlowContext ctx) { + if(name == "define_loss/conv_sobj_branch/batch_normalization/cond/FusedBatchNorm_1") + { + + } _control_flow_context = ctx; } diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs index 6d6403c9..c80e99f6 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs @@ -53,7 +53,7 @@ namespace Tensorflow for (int i = 0; i < NumInputs; i++) { var tf_output = Input(i); - var op = new Operation(tf_output.oper); + var op = GetOperation(tf_output.oper); retval[i] = op.outputs[tf_output.index]; } diff --git a/src/TensorFlowNET.Core/Operations/Operation.Instance.cs b/src/TensorFlowNET.Core/Operations/Operation.Instance.cs new file mode 100644 index 00000000..6f6c8226 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Operation.Instance.cs @@ -0,0 +1,41 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; + +namespace Tensorflow +{ + public partial class Operation + { + // cache the mapping between managed and unmanaged op + // some data is stored in managed instance, so when + // create Operation by IntPtr, it will lost some data. + private static Dictionary OpInstances = new Dictionary(); + + /// + /// Get operation by handle + /// + /// + /// + public Operation GetOperation(IntPtr handle) + { + return OpInstances.ContainsKey(handle) ? 
+ OpInstances[handle] : + new Operation(handle); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index 62c8f378..6844c892 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -17,13 +17,14 @@ using System; using System.Linq; using System.Runtime.InteropServices; +using static Tensorflow.Binding; namespace Tensorflow { public partial class Operation { public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); - public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index)); + public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(_tf_output(index)); public int OutputListLength(string name) { @@ -48,6 +49,20 @@ namespace Tensorflow public TF_Output this[int index] => _tf_output(index); + /// + /// List this operation's output types. + /// + public TF_DataType[] _output_types + { + get + { + var output_types = range(NumOutputs) + .Select(i => OutputType(i)) + .ToArray(); + return output_types; + } + } + public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) { var handle = Marshal.AllocHGlobal(Marshal.SizeOf()); diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 5fff9ade..6118602c 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -22,58 +22,53 @@ using System.Linq; using Tensorflow.Util; namespace Tensorflow -{ - - /// - /// Represents a graph node that performs computation on tensors. - /// - /// An `Operation` is a node in a TensorFlow `Graph` that takes zero or - /// more `Tensor` objects as input, and produces zero or more `Tensor` - /// objects as output. Objects of type `Operation` are created by - /// calling an op constructor(such as `tf.matmul`) - /// or `tf.Graph.create_op`. - /// - /// For example `c = tf.matmul(a, b)` creates an `Operation` of type - /// "MatMul" that takes tensors `a` and `b` as input, and produces `c` - /// as output. - /// - /// After the graph has been launched in a session, an `Operation` can - /// be executed by passing it to - /// `tf.Session.run`. - /// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. +{ + /// + /// Represents a graph node that performs computation on tensors. + /// + /// An `Operation` is a node in a TensorFlow `Graph` that takes zero or + /// more `Tensor` objects as input, and produces zero or more `Tensor` + /// objects as output. Objects of type `Operation` are created by + /// calling an op constructor(such as `tf.matmul`) + /// or `tf.Graph.create_op`. + /// + /// For example `c = tf.matmul(a, b)` creates an `Operation` of type + /// "MatMul" that takes tensors `a` and `b` as input, and produces `c` + /// as output. + /// + /// After the graph has been launched in a session, an `Operation` can + /// be executed by passing it to + /// `tf.Session.run`. + /// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. 
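As the summary above describes, an `Operation` is a graph node created by an op constructor and runnable through a session. A minimal sketch of that flow, assuming the usual `tf.matmul` and `Session` bindings:

```cs
using Tensorflow;
using static Tensorflow.Binding;

var a = tf.constant(new float[,] { { 1f, 2f } });
var b = tf.constant(new float[,] { { 3f }, { 4f } });
var c = tf.matmul(a, b);          // c.op is an Operation of type "MatMul"

using (var sess = tf.Session())
{
    // op.run() is shorthand for running the op in the given (or default) session.
    c.op.run(session: sess);

    // _output_types lists the dtype of every output of the op.
    var outTypes = c.op._output_types;   // [TF_DataType.TF_FLOAT]
}
```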
/// public partial class Operation : ITensorOrOperation { private readonly IntPtr _handle; // _c_op in python - private readonly IntPtr _operDesc; + private readonly Graph _graph; + private NodeDef _node_def; - private Graph _graph; public string type => OpType; - public Graph graph => _graph; public int _id => _id_value; public int _id_value; public Operation op => this; - public TF_DataType dtype => TF_DataType.DtInvalid; - public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); - private NodeDef _node_def; public NodeDef node_def { get { - if(_node_def == null) + if (_node_def == null) _node_def = GetNodeDef(); return _node_def; } } - public Operation(IntPtr handle, Graph g=null) + public Operation(IntPtr handle, Graph g = null) { if (handle == IntPtr.Zero) return; @@ -89,24 +84,26 @@ namespace Tensorflow _control_flow_context = _graph._get_control_flow_context(); // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. + OpInstances[_handle] = this; } - public Operation(Graph g, string opType, string oper_name) + /*public Operation(Graph g, string opType, string oper_name) { _graph = g; - _operDesc = c_api.TF_NewOperation(g, opType, oper_name); + var _operDesc = c_api.TF_NewOperation(g, opType, oper_name); c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); - using (var status = new Status()) - { - _handle = c_api.TF_FinishOperation(_operDesc, status); - status.Check(true); - } - - // Dict mapping op name to file and line information for op colocation - // context managers. + lock (Locks.ProcessWide) + using (var status = new Status()) + { + _handle = c_api.TF_FinishOperation(_operDesc, status); + status.Check(true); + } + + // Dict mapping op name to file and line information for op colocation + // context managers. _control_flow_context = graph._get_control_flow_context(); - } + }*/ /// /// Creates an `Operation`. @@ -133,9 +130,9 @@ namespace Tensorflow // Build the list of control inputs. var control_input_ops = new List(); - if(control_inputs != null) + if (control_inputs != null) { - foreach(var c in control_inputs) + foreach (var c in control_inputs) { switch (c) { @@ -163,8 +160,8 @@ namespace Tensorflow if (op_def == null) op_def = g.GetOpDef(node_def.Op); - var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); - (_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); + var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); + _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); // Initialize self._outputs. 
output_types = new TF_DataType[NumOutputs]; @@ -173,12 +170,14 @@ namespace Tensorflow _outputs = new Tensor[NumOutputs]; for (int i = 0; i < NumOutputs; i++) - _outputs[i] = new Tensor(this, i, OutputType(i)); + _outputs[i] = new Tensor(this, i, output_types[i]); graph._add_op(this); if (_handle != IntPtr.Zero) _control_flow_post_processing(); + + OpInstances[_handle] = this; } public void run(FeedItem[] feed_dict = null, Session session = null) @@ -196,15 +195,13 @@ namespace Tensorflow { if (!string.IsNullOrEmpty(input_arg.NumberAttr)) { - input_len = (int)attrs[input_arg.NumberAttr].I; + input_len = (int) attrs[input_arg.NumberAttr].I; is_sequence = true; - } - else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) + } else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) { input_len = attrs[input_arg.TypeListAttr].List.Type.Count; is_sequence = true; - } - else + } else { input_len = 1; is_sequence = false; @@ -221,26 +218,28 @@ namespace Tensorflow return grouped_inputs.ToArray(); } + public T get_attr(string name) + => (T)get_attr(name); + public object get_attr(string name) { AttrValue x = null; - using (var status = new Status()) - using (var buf = new Buffer()) - { - unsafe + lock (Locks.ProcessWide) + using (var status = new Status()) + using (var buf = new Buffer()) { c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); status.Check(true); + x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); } - } string oneof_value = x.ValueCase.ToString(); if (string.IsNullOrEmpty(oneof_value)) return null; - if(oneof_value == "list") + if (oneof_value == "list") throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); if (oneof_value == "type") @@ -259,60 +258,63 @@ namespace Tensorflow private NodeDef GetNodeDef() { - using (var s = new Status()) - using (var buffer = new Buffer()) - { - c_api.TF_OperationToNodeDef(_handle, buffer, s); - s.Check(); - return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); - } - } - - /// - /// Update the input to this operation at the given index. - /// - /// NOTE: This is for TF internal use only.Please don't use it. - /// - /// the index of the input to update. - /// the Tensor to be used as the input at the given index. - public void _update_input(int index, Tensor tensor) - { - _assert_same_graph(tensor); - - var input = _tf_input(index); - var output = tensor._as_tf_output(); - - // Reset cached inputs. - _inputs = null; - // after the c_api call next time _inputs is accessed - // the updated inputs are reloaded from the c_api - using (var status = new Status()) - { - c_api.UpdateEdge(_graph, output, input, status); - //var updated_inputs = inputs; - status.Check(); - } - } - - private void _assert_same_graph(Tensor tensor) - { - //TODO: implement - } - - /// - /// Create and return a new TF_Output for output_idx'th output of this op. - /// - public TF_Output _tf_output(int output_idx) - { - return new TF_Output(op, output_idx); - } - - /// - /// Create and return a new TF_Input for input_idx'th input of this op. - /// - public TF_Input _tf_input(int input_idx) - { - return new TF_Input(op, input_idx); - } - } -} + lock (Locks.ProcessWide) + using (var s = new Status()) + using (var buffer = new Buffer()) + { + c_api.TF_OperationToNodeDef(_handle, buffer, s); + s.Check(); + + return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + } + } + + /// + /// Update the input to this operation at the given index. + /// + /// NOTE: This is for TF internal use only.Please don't use it. 
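The typed `get_attr<T>` wrapper simply casts the untyped result, matching how the gradient code reads attributes such as `alpha`. A minimal sketch with a hypothetical node name:

```cs
using Tensorflow;
using static Tensorflow.Binding;

// Hypothetical node name; any LeakyRelu node in the graph would do.
var op = tf.get_default_graph().OperationByName("model/LeakyRelu");

var alpha = op.get_attr<float>("alpha");   // same as (float)op.get_attr("alpha")
```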
+ /// + /// the index of the input to update. + /// the Tensor to be used as the input at the given index. + public void _update_input(int index, Tensor tensor) + { + _assert_same_graph(tensor); + + var input = _tf_input(index); + var output = tensor._as_tf_output(); + + // Reset cached inputs. + _inputs = null; + // after the c_api call next time _inputs is accessed + // the updated inputs are reloaded from the c_api + lock (Locks.ProcessWide) + using (var status = new Status()) + { + c_api.UpdateEdge(_graph, output, input, status); + //var updated_inputs = inputs; + status.Check(); + } + } + + private void _assert_same_graph(Tensor tensor) + { + //TODO: implement + } + + /// + /// Create and return a new TF_Output for output_idx'th output of this op. + /// + public TF_Output _tf_output(int output_idx) + { + return new TF_Output(op, output_idx); + } + + /// + /// Create and return a new TF_Input for input_idx'th input of this op. + /// + public TF_Input _tf_input(int input_idx) + { + return new TF_Input(op, input_idx); + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Operations/Queues/FIFOQueue.cs b/src/TensorFlowNET.Core/Operations/Queues/FIFOQueue.cs new file mode 100644 index 00000000..b4d2e638 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Queues/FIFOQueue.cs @@ -0,0 +1,44 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow.Queues +{ + public class FIFOQueue : QueueBase + { + public FIFOQueue(int capacity, + TF_DataType[] dtypes, + TensorShape[] shapes, + string[] names = null, + string shared_name = null, + string name = "fifo_queue") + : base(dtypes: dtypes, shapes: shapes, names: names) + { + _queue_ref = gen_data_flow_ops.fifo_queue_v2( + component_types: dtypes, + shapes: shapes, + capacity: capacity, + shared_name: shared_name, + name: name); + + _name = _queue_ref.op.name.Split('/').Last(); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Queues/PaddingFIFOQueue.cs b/src/TensorFlowNET.Core/Operations/Queues/PaddingFIFOQueue.cs new file mode 100644 index 00000000..d8b93ff2 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Queues/PaddingFIFOQueue.cs @@ -0,0 +1,49 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Framework; +using static Tensorflow.Binding; + +namespace Tensorflow.Queues +{ + /// + /// A FIFOQueue that supports batching variable-sized tensors by padding. + /// + public class PaddingFIFOQueue : QueueBase + { + public PaddingFIFOQueue(int capacity, + TF_DataType[] dtypes, + TensorShape[] shapes, + string[] names = null, + string shared_name = null, + string name = "padding_fifo_queue") + : base(dtypes: dtypes, shapes: shapes, names: names) + { + _queue_ref = gen_data_flow_ops.padding_fifo_queue_v2( + component_types: dtypes, + shapes: shapes, + capacity: capacity, + shared_name: shared_name, + name: name); + + _name = _queue_ref.op.name.Split('/').Last(); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs b/src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs new file mode 100644 index 00000000..7420c017 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Queues/PriorityQueue.cs @@ -0,0 +1,82 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Queues +{ + public class PriorityQueue : QueueBase + { + public PriorityQueue(int capacity, + TF_DataType[] dtypes, + TensorShape[] shapes, + string[] names = null, + string shared_name = null, + string name = "priority_queue") + : base(dtypes: dtypes, shapes: shapes, names: names) + { + _queue_ref = gen_data_flow_ops.priority_queue_v2( + component_types: dtypes, + shapes: shapes, + capacity: capacity, + shared_name: shared_name, + name: name); + + _name = _queue_ref.op.name.Split('/').Last(); + + var dtypes1 = dtypes.ToList(); + dtypes1.Insert(0, TF_DataType.TF_INT64); + _dtypes = dtypes1.ToArray(); + + var shapes1 = shapes.ToList(); + shapes1.Insert(0, new TensorShape()); + _shapes = shapes1.ToArray(); + } + + public Operation enqueue_many(long[] indexes, T[] vals, string name = null) + { + return tf_with(ops.name_scope(name, $"{_name}_EnqueueMany", vals), scope => + { + var vals_tensor1 = _check_enqueue_dtypes(indexes); + var vals_tensor2 = _check_enqueue_dtypes(vals); + + var tensors = new List(); + tensors.AddRange(vals_tensor1); + tensors.AddRange(vals_tensor2); + + return gen_data_flow_ops.queue_enqueue_many_v2(_queue_ref, tensors.ToArray(), name: scope); + }); + } + + public Tensor[] dequeue(string name = null) + { + Tensor[] ret; + if (name == null) + name = $"{_name}_Dequeue"; + + if (_queue_ref.dtype == TF_DataType.TF_RESOURCE) + ret = gen_data_flow_ops.queue_dequeue_v2(_queue_ref, _dtypes, name: name); + else + ret = gen_data_flow_ops.queue_dequeue(_queue_ref, _dtypes, name: name); + + return ret; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs b/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs new file mode 100644 index 00000000..b420d2c9 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs @@ -0,0 +1,123 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Queues +{ + public class QueueBase + { + protected TF_DataType[] _dtypes; + protected TensorShape[] _shapes; + protected string[] _names; + protected Tensor _queue_ref; + protected string _name; + + public QueueBase(TF_DataType[] dtypes, TensorShape[] shapes, string[] names) + { + _dtypes = dtypes; + _shapes = shapes; + _names = names; + } + + public Operation enqueue(Tensor val, string name = null) + { + return tf_with(ops.name_scope(name, $"{_name}_enqueue", val), scope => + { + var vals = new[] { val }; + if (_queue_ref.dtype == TF_DataType.TF_RESOURCE) + return gen_data_flow_ops.queue_enqueue_v2(_queue_ref, vals, name: scope); + else + return gen_data_flow_ops.queue_enqueue(_queue_ref, vals, name: scope); + }); + } + + public Operation enqueue_many(T[] vals, string name = null) + { + return tf_with(ops.name_scope(name, $"{_name}_EnqueueMany", vals), scope => + { + var vals_tensor = _check_enqueue_dtypes(vals); + return gen_data_flow_ops.queue_enqueue_many_v2(_queue_ref, vals_tensor, name: scope); + }); + } + + protected Tensor[] _check_enqueue_dtypes(object vals) + { + var tensors = new List(); + + switch (vals) + { + case int[][] vals1: + { + int i = 0; + foreach (var (val, dtype) in zip(vals1, _dtypes)) + tensors.Add(ops.convert_to_tensor(val, dtype: dtype, name: $"component_{i++}")); + } + break; + + default: + var dtype1 = GetType().Name == "PriorityQueue" ? _dtypes[1] : _dtypes[0]; + tensors.Add(ops.convert_to_tensor(vals, dtype: dtype1, name: $"component_0")); + break; + } + + return tensors.ToArray(); + } + + /// + /// Dequeues one element from this queue. + /// + /// + /// + public Tensor dequeue(string name = null) + { + Tensor ret; + if (name == null) + name = $"{_name}_Dequeue"; + + if (_queue_ref.dtype == TF_DataType.TF_RESOURCE) + ret = gen_data_flow_ops.queue_dequeue_v2(_queue_ref, _dtypes, name: name)[0]; + else + ret = gen_data_flow_ops.queue_dequeue(_queue_ref, _dtypes, name: name)[0]; + + return ret; + } + + public Tensor[] dequeue_many(int n, string name = null) + { + if (name == null) + name = $"{_name}_DequeueMany"; + + var ret = gen_data_flow_ops.queue_dequeue_many_v2(_queue_ref, n: n, component_types: _dtypes, name: name); + //var op = ret[0].op; + //var cv = tensor_util.constant_value(op.inputs[1]); + //var batch_dim = new Dimension(cv); + + return _dequeue_return_value(ret); + } + + public Tensor[] _dequeue_return_value(Tensor[] tensors) + { + if (_names != null) + throw new NotImplementedException(""); + return tensors; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs b/src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs new file mode 100644 index 00000000..6846f478 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Queues/RandomShuffleQueue.cs @@ -0,0 +1,57 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
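For orientation, here is a minimal usage sketch of the queue wrappers added above (FIFOQueue plus the shared QueueBase enqueue/dequeue helpers). The tf.constant, tf.Session and sess.run calls are assumed from the existing TensorFlow.NET binding surface and the scalar value is illustrative; only the Tensorflow.Queues types come from this patch.

```cs
using Tensorflow;
using Tensorflow.Queues;
using static Tensorflow.Binding;

// Build a FIFO queue holding scalar int32 elements.
var queue = new FIFOQueue(capacity: 10,
    dtypes: new[] { TF_DataType.TF_INT32 },
    shapes: new[] { new TensorShape() });

var enqueue = queue.enqueue(tf.constant(1));   // QueueEnqueueV2 op
var dequeued = queue.dequeue();                // QueueDequeueV2, first component

using (var sess = tf.Session())
{
    sess.run(enqueue);
    var value = sess.run(dequeued);            // NDArray wrapping 1
}
```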
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow.Queues +{ + /// + /// Create a queue that dequeues elements in a random order. + /// + public class RandomShuffleQueue : QueueBase + { + public RandomShuffleQueue(int capacity, + int min_after_dequeue, + TF_DataType[] dtypes, + TensorShape[] shapes, + string[] names = null, + int? seed = null, + string shared_name = null, + string name = "random_shuffle_queue") + : base(dtypes: dtypes, shapes: shapes, names: names) + { + var(seed1, seed2) = random_seed.get_seed(seed); + if (!seed1.HasValue && !seed2.HasValue) + (seed1, seed2) = (0, 0); + + + _queue_ref = gen_data_flow_ops.random_shuffle_queue_v2( + component_types: dtypes, + shapes: shapes, + capacity: capacity, + min_after_dequeue: min_after_dequeue, + seed: seed1.Value, + seed2: seed2.Value, + shared_name: shared_name, + name: name); + + _name = _queue_ref.op.name.Split('/').Last(); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/TensorArray.cs b/src/TensorFlowNET.Core/Operations/TensorArray.cs index 858dac47..7251bf85 100644 --- a/src/TensorFlowNET.Core/Operations/TensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/TensorArray.cs @@ -33,9 +33,13 @@ namespace Tensorflow.Operations { _GraphTensorArray _implementation; - public TensorArray(TF_DataType dtype, Tensor size = null, bool? clear_after_read = null, bool? dynamic_size = null, + public TF_DataType dtype => _implementation._dtype; + public Tensor handle => _implementation._handle; + public Tensor flow => _implementation._flow; + + public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? 
dynamic_size = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, - bool infer_shape = true, TensorShape element_shape = null, + bool infer_shape = true, TensorShape[] element_shape = null, bool colocate_with_first_write_call = true, string name = null) { _implementation = new _GraphTensorArray(dtype, @@ -50,5 +54,8 @@ namespace Tensorflow.Operations colocate_with_first_write_call: colocate_with_first_write_call, name: name); } + + public TensorArray unstack(Tensor value, string name = null) + => _implementation.unstack(value, name: name); } } diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs index 4c700a5f..bd919ad8 100644 --- a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs @@ -16,6 +16,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; using static Tensorflow.Binding; @@ -23,7 +24,7 @@ namespace Tensorflow.Operations { internal class _GraphTensorArray { - TF_DataType _dtype; + internal TF_DataType _dtype; /// /// Used to keep track of what tensors the TensorArray should be @@ -33,23 +34,27 @@ namespace Tensorflow.Operations bool _colocate_with_first_write_call; bool _infer_shape; + bool _dynamic_size; List _element_shape; - object _colocate_with; + List _colocate_with; + + internal Tensor _handle; + internal Tensor _flow; public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, - bool infer_shape = true, TensorShape element_shape = null, + bool infer_shape = true, TensorShape[] element_shape = null, bool colocate_with_first_write_call = true, string name = null) { clear_after_read = clear_after_read ?? true; dynamic_size = dynamic_size ?? false; - + _dynamic_size = dynamic_size.Value; _dtype = dtype; _colocate_with_first_write_call = colocate_with_first_write_call; if (colocate_with_first_write_call) - _colocate_with = new Tensor[0]; + _colocate_with = new List(); // Record the current static shape for the array elements. 
The element // shape is defined either by `element_shape` or the shape of the tensor @@ -66,11 +71,12 @@ namespace Tensorflow.Operations _element_shape = new List { }; } - tf_with(ops.name_scope(name, "", new { handle, size, flow }), scope => + tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope => { if(handle != null) { - + _handle = handle; + _flow = flow; } else { @@ -89,14 +95,65 @@ namespace Tensorflow.Operations if (colocate_with_first_write_call) { ops.colocate_with(ignore_existing: true); - create(); + (_handle, _flow) = create(); } else { - + (_handle, _flow) = create(); } } }); } + + public TensorArray unstack(Tensor value, string name = null) + { + return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate + { + var num_elements = array_ops.shape(value)[0]; + return scatter(indices: math_ops.range(0, num_elements), value: value, name: name); + }); + } + + public TensorArray scatter(Tensor indices, Tensor value, string name = null) + { + return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate + { + value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + if (_infer_shape) + { + var shape = new TensorShape(value.TensorShape.dims.Skip(1).ToArray()); + _merge_element_shape(shape); + } + + _maybe_colocate_with(value); + var flow_out = gen_data_flow_ops.tensor_array_scatter_v3( + handle: _handle, + indices: indices, + value: value, + flow_in: _flow, + name: name); + + var ta = new TensorArray(_dtype, + infer_shape:_infer_shape, + element_shape: _element_shape.ToArray(), + dynamic_size: _dynamic_size, + handle: _handle, + flow: flow_out, + colocate_with_first_write_call: _colocate_with_first_write_call); + + + return ta; + }); + } + + public void _merge_element_shape(TensorShape shape) + { + _element_shape.Add(shape); + } + + public void _maybe_colocate_with(Tensor value) + { + _colocate_with.Add(value); + } } } diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index 92f65906..12094e41 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -17,6 +17,8 @@ using NumSharp; using System; using System.Collections.Generic; +using System.Linq; +using Tensorflow.Framework; using static Tensorflow.Binding; namespace Tensorflow @@ -54,6 +56,8 @@ namespace Tensorflow return _constant_if_small(0.0D, shape, dtype, name); case TF_DataType.TF_FLOAT: return _constant_if_small(0.0F, shape, dtype, name); + case TF_DataType.TF_INT64: + return _constant_if_small(0l, shape, dtype, name); case TF_DataType.TF_INT32: return _constant_if_small(0, shape, dtype, name); case TF_DataType.TF_INT8: @@ -64,6 +68,44 @@ namespace Tensorflow }); } + public static Tensor boolean_mask(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) + { + return tf_with(ops.name_scope(name, values: new { tensor, mask }), delegate + { + var tensor_tensor = ops.convert_to_tensor(tensor, name: "tensor"); + var mask_tensor = ops.convert_to_tensor(mask, name: "mask"); + + var shape_mask = mask_tensor.TensorShape; + var ndims_mask = shape_mask.ndim; + var shape_tensor = tensor_tensor.TensorShape; + + if (ndims_mask < 1) + throw new ValueError("mask cannot be scalar."); + + var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], new[] { 0 }); + var shape1 = concat(new[] + { + shape(tensor_tensor)[$":{axis}"], + tf.expand_dims(leading_size, 
0), + shape(tensor_tensor)[$"{axis + ndims_mask}:"] + }, 0); + tensor_tensor = reshape(tensor, shape1); + var first_dim = shape_tensor.dims.Skip(axis).Take(ndims_mask).First(); + var s1 = tensor_shape.as_shape(shape_tensor.dims.Take(axis).ToArray()); + var s2 = s1.concatenate(new[] { first_dim }).concatenate(shape_tensor.dims.Skip(axis + ndims_mask).ToArray()); + tensor_tensor.set_shape(s2); + + mask_tensor = reshape(mask_tensor, new[] { -1 }); + return _apply_mask_1d(tensor_tensor, mask_tensor, axis); + }); + } + + private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0) + { + var indices = squeeze(where(mask), axis: new[] { 1 }); + return gather(reshaped_tensor, indices, axis: axis); + } + public static Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) { dtype = dtype.as_base_dtype(); @@ -306,11 +348,40 @@ namespace Tensorflow public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataType.TF_INT32, string name = null) => gen_array_ops.unique(x, out_idx: out_idx, name: name); + public static Tensor stack(Tensor[] values, int axis = 0, string name = "stack") + { + if (axis == 0) + { + return ops.convert_to_tensor(values, name: name); + } + + var value_shape = ops.convert_to_tensor(values[0], name: name).TensorShape; + + return gen_array_ops.pack(values, axis: axis, name: name); + } + + public static Tensor[] unstack(Tensor value, int? num = null, int axis = 0, string name = "unstack") + { + if(num == null) + { + value = ops.convert_to_tensor(value); + var value_shape = value.TensorShape; + num = value_shape.dims[axis]; + } + + return gen_array_ops.unpack(value, num: num.Value, axis: axis, name: name); + } + public static Tensor where(Tensor condition, object x = null, object y = null, string name = null) { if( x == null && y == null) { - throw new NotImplementedException("where"); + return tf_with(ops.name_scope(name, "Where", new { condition }), scope => + { + name = scope; + condition = ops.convert_to_tensor(condition, preferred_dtype: dtypes.@bool, name: "condition"); + return gen_array_ops.where(condition: condition, name: name); + }); } else if(x != null && y != null) { @@ -338,7 +409,7 @@ namespace Tensorflow public static Tensor size(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) => size_internal(input, name, optimize: optimize, out_type: out_type); - private static Tensor shape_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) + public static Tensor shape_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) { return tf_with(ops.name_scope(name, "Shape", new { input }), scope => { @@ -540,7 +611,7 @@ namespace Tensorflow }); } - public static Tensor slice(Tensor input, Tb[] begin, Ts[] size, string name = null) + public static Tensor slice(Tensor input, Tb begin, Ts size, string name = null) => gen_array_ops.slice(input, begin, size, name: name); public static Tensor stack(object values, int axis = 0, string name = "stack") @@ -552,6 +623,40 @@ namespace Tensorflow throw new NotImplementedException("array_ops.stack"); } + public static Tensor pad(Tensor tensor, Tensor paddings, string mode = "CONSTANT", string name = null, int constant_values = 0) + { + Tensor result = null; + mode = mode.ToUpper(); + if(mode == "CONSTANT") + { + if (constant_values != 0) + throw new NotImplementedException("gen_array_ops.pad_v2"); + else + 
result = gen_array_ops.pad(tensor, paddings, name: name); + } + + // Restore shape information where possible. + var paddings_constant = tensor_util.constant_value( + result.op.inputs[1], partial: true); + var input_shape = result.op.inputs[0].TensorShape; + if (input_shape.ndim > -1 && + !result.TensorShape.is_fully_defined() && + !(paddings_constant is null)) + { + var new_shape = new List(); + foreach((NDArray padding, int dim) in zip(paddings_constant.GetNDArrays(), np.array(input_shape.dims).GetNDArrays())) + { + if (padding is null || dim == -1 || padding.GetData().Contains(-1)) + new_shape.Add(-1); + else + new_shape.Add(np.sum(padding) + dim); + } + result.set_shape(new_shape.ToArray()); + } + + return result; + } + public static Tensor placeholder(TF_DataType dtype) { throw new NotImplementedException("array_ops.placeholder"); diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index fa24b0ef..a23cd406 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -198,7 +198,7 @@ namespace Tensorflow /// int /// [DllImport(TensorFlowLibName)] - public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, IntPtr consumers, int max_consumers); + public static extern int TF_OperationOutputConsumers(TF_Output oper_out, IntPtr consumers, int max_consumers); [DllImport(TensorFlowLibName)] public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); diff --git a/src/TensorFlowNET.Core/Operations/check_ops.cs b/src/TensorFlowNET.Core/Operations/check_ops.cs index 8b670640..ef2ea3b6 100644 --- a/src/TensorFlowNET.Core/Operations/check_ops.cs +++ b/src/TensorFlowNET.Core/Operations/check_ops.cs @@ -26,7 +26,7 @@ namespace Tensorflow /// /// /// - public static Operation assert_equal(object t1, object t2, object[] data = null, string message = null, string name = null) + public static Operation assert_equal(T1 t1, T2 t2, object[] data = null, string message = null, string name = null) { if (message == null) message = ""; diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs similarity index 94% rename from src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs rename to src/TensorFlowNET.Core/Operations/control_flow_ops.cs index 04ef54a7..e8b5f0eb 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -275,7 +275,7 @@ namespace Tensorflow /// public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") { - data = ops.convert_to_tensor_or_indexed_slices(data, name: "data"); + data = ops.convert_to_tensor_or_composite(data, name: "data"); // NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below // addresses the following scenario. 
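A short sketch of the array_ops.pad helper introduced earlier in this patch; the constant inputs are illustrative. Note how the shape-restoration branch recomputes the static output shape from the constant paddings.

```cs
using Tensorflow;
using static Tensorflow.Binding;

var t = tf.constant(new[,] { { 1, 2 }, { 3, 4 } });            // shape (2, 2)
var paddings = tf.constant(new[,] { { 1, 1 }, { 2, 2 } });     // one row above/below, two cols left/right

// CONSTANT mode with constant_values == 0 lowers to gen_array_ops.pad;
// the static shape is restored to (2 + 1 + 1, 2 + 2 + 2) = (4, 6).
var padded = array_ops.pad(t, paddings);
```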
// @@ -296,9 +296,8 @@ namespace Tensorflow { if (data is Tensor) { - // TODO: ref_switch - //if (data.dtype._is_ref_dtype) - // return control_flow_ops.ref_switch(data, pred, name = name); + if (data.dtype.is_ref_dtype()) + return gen_control_flow_ops.ref_switch(data, pred, name: name); } return @switch(data, pred, name: name); } @@ -519,7 +518,7 @@ namespace Tensorflow inputs = inputs.Select(inp => ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true)) .ToArray(); - return gen_control_flow_ops.merge(inputs, name).Item1; + return gen_control_flow_ops.merge(inputs, name)[0]; }); } @@ -558,8 +557,31 @@ namespace Tensorflow throw new NotImplementedException("ZerosLikeOutsideLoop"); return array_ops.zeros_like(val, optimize: false); } - - throw new NotImplementedException("ZerosLikeOutsideLoop"); + else + { + var op_ctxt = op._get_control_flow_context(); + if(op_ctxt != null) + { + // We are in a cond context. Use a switch to create zeros only when needed. + var pred = op_ctxt.pred; + var branch = op_ctxt.branch; + var switch_val = @switch(op.inputs[0], pred)[1 - branch]; + var pivot = array_ops.identity(switch_val); + if (val.dtype == dtypes.resource) + throw new NotImplementedException(""); + var zeros_shape = array_ops.shape_internal(switch_val, optimize: false); + // Ensure ops created within array_ops.zeros are dominated by switch in + // cond context. + return tf_with(ops.control_dependencies(new[] { pivot }), delegate + { + return array_ops.zeros(zeros_shape, dtype: val.dtype); + }); + } + else + { + return array_ops.zeros_like(val, optimize: false); + } + } } /// diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 092d152c..36837477 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -26,6 +26,20 @@ namespace Tensorflow public static OpDefLibrary _op_def_lib = new OpDefLibrary(); public static Execute _execute = new Execute(); + public static Tensor batch_to_space_nd(T input, int[] block_shape, int[,] crops, string name = null) + { + var _op = _op_def_lib._apply_op_helper("BatchToSpaceND", name: name, args: new { input, block_shape, crops }); + + return _op.output; + } + + public static Tensor check_numerics(Tensor tensor, string message, string name = null) + { + var _op = _op_def_lib._apply_op_helper("CheckNumerics", name: name, args: new { tensor, message }); + + return _op.output; + } + /// /// Concatenates tensors along one dimension. 
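As a usage sketch for the new array_ops.boolean_mask and the single-argument array_ops.where overload added in this patch (the constants are illustrative): boolean_mask flattens the masked axis and gathers by the indices that where produces.

```cs
using Tensorflow;
using static Tensorflow.Binding;

var tensor = tf.constant(new[] { 0, 1, 2, 3 });
var mask = tf.constant(new[] { true, false, true, false });

var masked = array_ops.boolean_mask(tensor, mask);   // -> [0, 2]
var indices = array_ops.where(mask);                 // int64 indices of true entries, shape (2, 1)
```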
/// @@ -99,6 +113,13 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor pad(Tensor input, Tensor paddings, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Pad", name: name, args: new { input, paddings }); + + return _op.output; + } + public static Tensor pack(Tensor[] values, int axis = 0, string name = null) { var _op = _op_def_lib._apply_op_helper("Pack", name: name, args: new { values, axis }); @@ -214,10 +235,16 @@ namespace Tensorflow return (_op.outputs[0], _op.outputs[1]); } + public static Tensor reverse(Tensor tensor, T axis, string name = null) + { + var _op = _op_def_lib._apply_op_helper("ReverseV2", name, new { tensor, axis }); + return _op.output; + } + public static Tensor reshape(T1 tensor, T2 shape, string name = null) { var _op = _op_def_lib._apply_op_helper("Reshape", name, new { tensor, shape }); - return _op.outputs[0]; + return _op.output; } public static Tensor reshape(Tensor tensor, int[] shape, string name = null) @@ -241,9 +268,16 @@ namespace Tensorflow return (_op.outputs[0], _op.outputs[1]); } - public static Tensor where() + public static Tensor[] unpack(Tensor value, int num, int axis = 0, string name = null) { - throw new NotImplementedException("where"); + var _op = _op_def_lib._apply_op_helper("Unpack", name, new { value, num, axis }); + return _op.outputs; + } + + public static Tensor where(Tensor condition, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Where", name, new { input = condition }); + return _op.output; } public static Tensor one_hot(Tensor indices, int depth, @@ -327,12 +361,7 @@ namespace Tensorflow return _op.outputs; } - public static Tensor tile(Tensor input, Tensor multiples, string name = null) - { - var _op = _op_def_lib._apply_op_helper("Tile", name, new { input, multiples }); - return _op.outputs[0]; - } - public static Tensor tile(NDArray input, int[] multiples, string name = null) + public static Tensor tile(Tensor input, T multiples, string name = null) { var _op = _op_def_lib._apply_op_helper("Tile", name, new { input, multiples }); return _op.outputs[0]; @@ -446,7 +475,7 @@ namespace Tensorflow return op.output; } - public static Tensor slice(Tensor input, Tb[] begin, Ts[] size, string name = null) + public static Tensor slice(Tensor input, Tb begin, Ts size, string name = null) { var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }); return _op.outputs[0]; diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs similarity index 90% rename from src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs rename to src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs index bfbf3413..5f0ceb48 100644 --- a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.cs @@ -114,6 +114,12 @@ namespace Tensorflow return _op; } + public static Tensor[] ref_switch(Tensor data, Tensor pred, string name = null) + { + var _op = _op_def_lib._apply_op_helper("RefSwitch", name, new { data, pred }); + return _op.outputs; + } + /// /// Forwards `data` to the output port determined by `pred`. 
/// @@ -142,11 +148,18 @@ namespace Tensorflow return new []{_op.outputs[0], _op.outputs[1]}; } - public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null) + public static Tensor[] ref_merge(Tensor[] inputs, string name = null) + { + var _op = _op_def_lib._apply_op_helper("RefMerge", name, new { inputs }); + + return _op.outputs; + } + + public static Tensor[] merge(Tensor[] inputs, string name = null) { var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs }); - return (_op.outputs[0], _op.outputs[1]); + return _op.outputs; } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs index 2cb9aac6..fa194934 100644 --- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs @@ -22,16 +22,18 @@ namespace Tensorflow public static Tensor dynamic_stitch(Tensor[] indices, Tensor[] data, string name = null) { - var _attr_N = indices.Length; var _op = _op_def_lib._apply_op_helper("DynamicStitch", name, new { indices, data }); - return _op.outputs[0]; + return _op.output; } - public static (Tensor, Tensor) tensor_array_v3(Tensor size, TF_DataType dtype = TF_DataType.DtInvalid, - int[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true, + public static (Tensor, Tensor) tensor_array_v3(T size, TF_DataType dtype = TF_DataType.DtInvalid, + TensorShape[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true, bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null) { + if (tensor_array_name == null) + tensor_array_name = string.Empty; + var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new { size, @@ -43,7 +45,161 @@ namespace Tensorflow tensor_array_name }); - return (null, null); + return (_op.outputs[0], _op.outputs[1]); + } + + public static Tensor tensor_array_scatter_v3(Tensor handle, Tensor indices, Tensor value, + Tensor flow_in, string name = null) + { + var _op = _op_def_lib._apply_op_helper("TensorArrayScatterV3", name, new + { + handle, + indices, + value, + flow_in + }); + + return _op.output; + } + + public static Tensor padding_fifo_queue_v2(TF_DataType[] component_types, TensorShape[] shapes, + int capacity = -1, string container = "", string shared_name = "", + string name = null) + { + var _op = _op_def_lib._apply_op_helper("PaddingFIFOQueueV2", name, new + { + component_types, + shapes, + capacity, + container, + shared_name + }); + + return _op.output; + } + + public static Tensor fifo_queue_v2(TF_DataType[] component_types, TensorShape[] shapes, + int capacity = -1, string container = "", string shared_name = "", + string name = null) + { + var _op = _op_def_lib._apply_op_helper("FIFOQueueV2", name, new + { + component_types, + shapes, + capacity, + container, + shared_name + }); + + return _op.output; + } + + public static Tensor priority_queue_v2(TF_DataType[] component_types, TensorShape[] shapes, + int capacity = -1, string container = "", string shared_name = "", + string name = null) + { + var _op = _op_def_lib._apply_op_helper("PriorityQueueV2", name, new + { + component_types, + shapes, + capacity, + container, + shared_name + }); + + return _op.output; + } + + public static Tensor random_shuffle_queue_v2(TF_DataType[] component_types, TensorShape[] shapes, + int capacity = -1, int min_after_dequeue = 0, int seed = 0, int seed2 = 0, + string container = "", string shared_name = "", string name = null) + { 
+ var _op = _op_def_lib._apply_op_helper("RandomShuffleQueueV2", name, new + { + component_types, + shapes, + capacity, + min_after_dequeue, + seed, + seed2, + container, + shared_name + }); + + return _op.output; + } + + public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) + { + var _op = _op_def_lib._apply_op_helper("QueueEnqueue", name, new + { + handle, + components, + timeout_ms + }); + + return _op; + } + + public static Operation queue_enqueue_v2(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) + { + var _op = _op_def_lib._apply_op_helper("QueueEnqueueV2", name, new + { + handle, + components, + timeout_ms + }); + + return _op; + } + + public static Tensor[] queue_dequeue_v2(Tensor handle, TF_DataType[] component_types, int timeout_ms = -1, string name = null) + { + var _op = _op_def_lib._apply_op_helper("QueueDequeueV2", name, new + { + handle, + component_types, + timeout_ms + }); + + return _op.outputs; + } + + public static Tensor[] queue_dequeue(Tensor handle, TF_DataType[] component_types, int timeout_ms = -1, string name = null) + { + var _op = _op_def_lib._apply_op_helper("QueueDequeue", name, new + { + handle, + component_types, + timeout_ms + }); + + return _op.outputs; + } + + public static Operation queue_enqueue_many_v2(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) + { + var _op = _op_def_lib._apply_op_helper("QueueEnqueueManyV2", name, new + { + handle, + components, + timeout_ms + }); + + return _op; + } + + public static Tensor[] queue_dequeue_many_v2(Tensor handle, int n, TF_DataType[] component_types, int timeout_ms = -1, string name = null) + { + var _op = _op_def_lib._apply_op_helper("QueueDequeueManyV2", name, new + { + handle, + n, + component_types, + timeout_ms + }); + + return _op.outputs; } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs index 90893815..143d4fe8 100644 --- a/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs @@ -169,5 +169,33 @@ namespace Tensorflow return _op.outputs[0]; } } + + public static Tensor resize_nearest_neighbor(Tensor images, Tsize size, bool align_corners = false, + bool half_pixel_centers = false, string name = null) + { + var op = _op_def_lib._apply_op_helper("ResizeNearestNeighbor", name: name, args: new + { + images, + size, + align_corners, + half_pixel_centers + }); + + return op.output; + } + + public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tsize size, bool align_corners = false, + bool half_pixel_centers = false, string name = null) + { + var op = _op_def_lib._apply_op_helper("ResizeNearestNeighborGrad", name: name, args: new + { + grads, + size, + align_corners, + half_pixel_centers + }); + + return op.output; + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index c1257e19..81870e5b 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -632,6 +632,13 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor _any(Tx input, Ty axis, bool keep_dims = false, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Any", name, new { input, reduction_indices = axis, keep_dims }); + + return _op.outputs[0]; + } + public static Tensor _max(Tx input, Ty axis, bool keep_dims=false, string 
name = null) { var _op = _op_def_lib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims }); diff --git a/src/TensorFlowNET.Core/Operations/gen_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ops.cs index e47002ef..6e91be02 100644 --- a/src/TensorFlowNET.Core/Operations/gen_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_ops.cs @@ -7730,7 +7730,7 @@ namespace Tensorflow.Operations /// /// /// RFC 4180 format is expected for the CSV records. - /// (https://tools.ietf.org/html/rfc4180) + /// (https://tools.ietensorflow.org/html/rfc4180) /// Note that we allow leading and trailing spaces with int or float field. /// public static Tensor[] decode_c_s_v (Tensor records, Tensor[] record_defaults, string field_delim = null, bool? use_quote_delim = null, string na_value = null, int[] select_cols = null, string name = "DecodeCSV") diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs index 06ae70a3..011b673f 100644 --- a/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs @@ -90,6 +90,23 @@ namespace Tensorflow return _op.outputs[0]; } + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, string name = null) + { + var _op = _op_def_lib._apply_op_helper("RandomShuffle", + name: name, + args: new { value, seed, seed2 }); + + return _op.output; + } + /// /// Outputs random values from a truncated normal distribution. /// diff --git a/src/TensorFlowNET.Core/Operations/gen_sparse_ops.cs b/src/TensorFlowNET.Core/Operations/gen_sparse_ops.cs new file mode 100644 index 00000000..d59afc88 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_sparse_ops.cs @@ -0,0 +1,74 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System.Collections.Generic; +using Tensorflow.Framework; + +namespace Tensorflow +{ + public class gen_sparse_ops + { + public static OpDefLibrary _op_def_lib = new OpDefLibrary(); + + /// + /// Converts a sparse representation into a dense tensor. 
+ /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor sparse_to_dense(Tensor sparse_indices, + int[] output_shape, + T sparse_values, + T default_value, + bool validate_indices = true, + string name = null) + { + var _op = _op_def_lib._apply_op_helper("SparseToDense", name, args: new + { + sparse_indices, + output_shape, + sparse_values, + default_value, + validate_indices + }); + + return _op.output; + } + + public static Tensor sparse_to_dense(Tensor sparse_indices, + Tensor output_shape, + Tensor sparse_values, + T default_value = default, + bool validate_indices = true, + string name = null) + { + var _op = _op_def_lib._apply_op_helper("SparseToDense", name, args: new + { + sparse_indices, + output_shape, + sparse_values, + default_value, + validate_indices + }); + + return _op.output; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs index 65ed8eb1..b7573b92 100644 --- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs +++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs @@ -128,5 +128,46 @@ namespace Tensorflow throw new NotImplementedException(""); } + + /// + /// Resize `images` to `size` using the specified `method`. + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor resize_images(Tensor images, Tensor size, ResizeMethod method = ResizeMethod.BILINEAR, + bool align_corners = false, bool preserve_aspect_ratio = false, string name = null) + { + throw new NotImplementedException(""); + } + + /// + /// Resize `images` to `size` using nearest neighbor interpolation. + /// + /// + /// + /// + /// + /// + /// + public static Tensor resize_nearest_neighbor(Tensor images, Tsize size, bool align_corners = false, + string name = null, bool half_pixel_centers = false) + => gen_image_ops.resize_nearest_neighbor(images: images, + size: size, + align_corners: align_corners, + half_pixel_centers: half_pixel_centers, + name: name); + } + + public enum ResizeMethod + { + BILINEAR = 0, + NEAREST_NEIGHBOR = 1, + BICUBIC = 2, + AREA = 3 } } diff --git a/src/TensorFlowNET.Core/Operations/map_fn.cs b/src/TensorFlowNET.Core/Operations/map_fn.cs new file mode 100644 index 00000000..1206d5b9 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/map_fn.cs @@ -0,0 +1,86 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class Operation + { + /// + /// map on the list of tensors unpacked from `elems` on dimension 0. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// A tensor or (possibly nested) sequence of tensors. + public static Tensor map_fn(Func fn, + Tensor elems, + TF_DataType dtype = TF_DataType.DtInvalid, + int parallel_iterations = 10, + bool back_prop = true, + bool swap_memory = false, + bool infer_shape = true, + string name = null) + { + var elems_flat = new[] { elems }; + tf_with(ops.name_scope(name, "map", elems_flat), delegate + { + var varscope = tf.get_variable_scope(); + elems_flat = elems_flat.Select(elem => ops.convert_to_tensor(elem, name: "elem")) + .ToArray(); + + dtype = elems_flat.Select(elem => elem.dtype).First(); + var dtype_flat = new[] { dtype }; + + // Convert elems to tensor array. n may be known statically. 
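The map_fn sketch below relies on the TensorArray.unstack path added earlier in this patch (TensorArray.unstack forwards to _GraphTensorArray.unstack, which scatters the unpacked elements). A minimal, illustrative construction-time example, assuming the usual tf.constant helper:

```cs
using Tensorflow;
using Tensorflow.Operations;
using static Tensorflow.Binding;

var values = tf.constant(new[] { 1f, 2f, 3f });
var ta = new TensorArray(TF_DataType.TF_FLOAT, size: tf.constant(3));

// unstack scatters values[0..2] into the array and returns the updated TensorArray.
ta = ta.unstack(values);
```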
+ var static_shape = elems_flat[0].shape; + + var n = static_shape[0]; + + // TensorArrays are always flat + var elems_ta = elems_flat.Select(elem => new TensorArray(dtype: elem.dtype, + size: ops.convert_to_tensor(n), + dynamic_size: false, + infer_shape: true)).ToArray(); + + // Unpack elements + var elems_ta_1 = new List(); + foreach (var (elem_ta, elem) in zip(elems_ta, elems_flat)) + elems_ta_1.Add(elem_ta.unstack(elem)); + + elems_ta = elems_ta_1.ToArray(); + + var i = constant_op.constant(0); + + var accs_ta = dtype_flat.Select(dt => new TensorArray(dtype: dt, + size: ops.convert_to_tensor(n), + dynamic_size: false, + infer_shape: infer_shape)).ToArray(); + + /*Func compute = (i, tas) => + { + throw new NotImplementedException(""); + }; + + var r_a = control_flow_ops.while_loop( + (i, _) => i < n, + compute, + new[] { i, accs_ta }, + parallel_iterations: parallel_iterations, + back_prop: back_prop, + swap_memory: swap_memory, + maximum_iterations: n);*/ + }); + + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index f5cfdb37..94c42ba2 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -361,6 +361,14 @@ namespace Tensorflow }); } + public static Tensor reduce_any(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + { + var r = _ReductionDims(input_tensor, axis); + var max = (axis != null) ? gen_math_ops._any(input_tensor, axis, keepdims, name) : + gen_math_ops._any(input_tensor, r, keepdims, name); + return _may_reduce_to_scalar(keepdims, axis, max); + } + public static Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) { var r = _ReductionDims(input_tensor, axis); @@ -422,6 +430,13 @@ namespace Tensorflow return _may_reduce_to_scalar(keepdims, axis, m); } + public static Tensor reduce_sum(Tensor input_tensor, int[] axis, bool keepdims = false, string name = null) + { + var r = _ReductionDims(input_tensor, axis); + var m = gen_math_ops._sum(input_tensor, r, keep_dims: keepdims, name: name); + return _may_reduce_to_scalar(keepdims, axis, m); + } + public static Tensor reduce_sum(Tensor input_tensor, int axis, bool keepdims = false, string name = null) { var m = gen_math_ops._sum(input_tensor, axis, keep_dims: keepdims, name: name); @@ -492,7 +507,7 @@ namespace Tensorflow public static Tensor rsqrt(Tensor x, string name = null) => gen_math_ops.rsqrt(x, name: name); - public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range" ) + public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range") { if(limit == null) { diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs index bd70c10a..bced0047 100644 --- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs +++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using System; using Tensorflow.Operations; using static Tensorflow.Binding; @@ -83,6 +84,19 @@ namespace Tensorflow }); } + /// + /// Batch normalization. 
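A small sketch of the new math_ops.reduce_any reduction (backed by gen_math_ops._any and the Any kernel); the boolean constant is illustrative.

```cs
using Tensorflow;
using static Tensorflow.Binding;

var x = tf.constant(new[,] { { true, false }, { false, false } });

var any_per_row = math_ops.reduce_any(x, axis: new[] { 1 });   // -> [true, false]
var any_at_all = math_ops.reduce_any(x);                       // reduce over all axes -> true
```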
+ /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// public static Tensor[] fused_batch_norm(Tensor x, RefVariable scale, RefVariable offset, @@ -103,7 +117,7 @@ namespace Tensorflow var min_epsilon = 1.001e-5f; epsilon = epsilon > min_epsilon ? epsilon : min_epsilon; - return gen_nn_ops._fused_batch_norm(x, + var results = gen_nn_ops.fused_batch_norm(x, scale_tensor, offset_tensor, mean, @@ -112,6 +126,12 @@ namespace Tensorflow data_format, is_training, name); + + var y = results[0]; + var batch_mean = results[1]; + var batch_var = results[2]; + + return new[] { y, batch_mean, batch_var }; } /// @@ -132,6 +152,27 @@ namespace Tensorflow }); } + public static Tensor sigmoid_cross_entropy_with_logits(Tensor labels, Tensor logits, string name = null) + { + return tf_with(ops.name_scope(name, "logistic_loss", new { logits, labels }), scope => + { + name = scope; + logits = ops.convert_to_tensor(logits, name: "logits"); + labels = ops.convert_to_tensor(labels, name: "labels"); + labels.TensorShape.merge_with(logits.TensorShape); + + var zeros = array_ops.zeros_like(logits, dtype: logits.dtype); + var cond = (logits >= zeros); + var relu_logits = array_ops.where(cond, logits, zeros); + var neg_abs_logits = array_ops.where(cond, -logits, logits); + + return math_ops.add( + relu_logits - logits * labels, + gen_math_ops.log1p(gen_math_ops.exp(neg_abs_logits)), + name: name); + }); + } + /// /// Returns the fraction of zeros in value. /// diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs index b189bb83..7ae1f3a9 100644 --- a/src/TensorFlowNET.Core/Operations/nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -116,6 +116,12 @@ namespace Tensorflow return _softmax(logits, gen_nn_ops.log_softmax, axis, name); } + /// equivalent to `dim` + public static Tensor softmax(Tensor logits, int axis = -1, string name = null) + { + return _softmax(logits, gen_nn_ops.softmax, axis, name); + } + public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) { return tf_with(ops.name_scope(name, "LeakyRelu", new { features, alpha }), scope => diff --git a/src/TensorFlowNET.Core/Operations/random_ops.py.cs b/src/TensorFlowNET.Core/Operations/random_ops.py.cs index 02e522bf..9251f867 100644 --- a/src/TensorFlowNET.Core/Operations/random_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/random_ops.py.cs @@ -104,6 +104,19 @@ namespace Tensorflow }); } + /// + /// Randomly shuffles a tensor along its first dimension. + /// + /// + /// + /// + /// + public static Tensor random_shuffle(Tensor value, int? seed = null, string name = null) + { + var (seed1, seed2) = random_seed.get_seed(seed); + return gen_random_ops.random_shuffle(value, seed: seed1.Value, seed2: seed2.Value, name: name); + } + public static Tensor truncated_normal(int[] shape, float mean = 0.0f, float stddev = 1.0f, diff --git a/src/TensorFlowNET.Core/Operations/sparse_ops.cs b/src/TensorFlowNET.Core/Operations/sparse_ops.cs new file mode 100644 index 00000000..6a30771c --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/sparse_ops.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class sparse_ops + { + /// + /// Converts a sparse representation into a dense tensor. + /// + /// + /// + /// + /// + /// + /// + /// + /// Dense `Tensor` of shape `output_shape`. Has the same type as `sparse_values`. 
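For the sigmoid_cross_entropy_with_logits helper added to nn_impl above, a minimal sketch (the label and logit values are illustrative). It computes the numerically stable form max(x, 0) - x * z + log(1 + exp(-|x|)) element-wise.

```cs
using Tensorflow;
using static Tensorflow.Binding;

var logits = tf.constant(new[] { 0.5f, -1.0f, 2.0f });
var labels = tf.constant(new[] { 1.0f, 0.0f, 1.0f });

// Element-wise loss; reduce with math_ops.reduce_sum or reduce_mean as needed.
var loss = nn_impl.sigmoid_cross_entropy_with_logits(labels: labels, logits: logits);
```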
+ public Tensor sparse_to_dense(Tensor sparse_indices, + int[] output_shape, + T sparse_values, + T default_value = default, + bool validate_indices = true, + string name = null) + => gen_sparse_ops.sparse_to_dense(sparse_indices, + output_shape, + sparse_values, + default_value: default_value, + validate_indices: validate_indices, + name: name); + } +} diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 58177df2..1701c625 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -21,6 +21,9 @@ using System.Collections.Generic; using System.Linq; using System.Numerics; using System.Text; +using Google.Protobuf; +using NumSharp.Backends; +using Tensorflow.Util; namespace Tensorflow { @@ -33,25 +36,20 @@ namespace Tensorflow protected byte[] _target; public Graph graph => _graph; - public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) + public BaseSession(string target = "", Graph g = null, SessionOptions opts = null, Status status = null) { - _graph = g is null ? ops.get_default_graph() : g; + _graph = g ?? ops.get_default_graph(); _graph.as_default(); - _target = UTF8Encoding.UTF8.GetBytes(target); + _target = Encoding.UTF8.GetBytes(target); - SessionOptions newOpts = null; - if (opts == null) - newOpts = new SessionOptions(); + SessionOptions lopts = opts ?? new SessionOptions(); - var status = new Status(); - - _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); - - // dispose newOpts - if (opts == null) - newOpts.Dispose(); - - status.Check(true); + lock (Locks.ProcessWide) + { + status = status ?? new Status(); + _handle = c_api.TF_NewSession(_graph, opts ?? lopts, status); + status.Check(true); + } } public virtual void run(Operation op, params FeedItem[] feed_dict) @@ -71,19 +69,19 @@ namespace Tensorflow public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) { - var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); + var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4}, feed_dict); return (results[0], results[1], results[2], results[3]); } public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) { - var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); + var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3}, feed_dict); return (results[0], results[1], results[2]); } public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) { - var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); + var results = _run(new object[] {fetches.Item1, fetches.Item2}, feed_dict); return (results[0], results[1]); } @@ -94,33 +92,24 @@ namespace Tensorflow public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) { - var feed_items = feed_dict == null ? new FeedItem[0] : - feed_dict.Keys.OfType().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); + var feed_items = feed_dict == null ? 
new FeedItem[0] : feed_dict.Keys.OfType().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); return _run(fetches, feed_items); } private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) { var feed_dict_tensor = new Dictionary(); - var feed_map = new Dictionary(); - - Func> feed_fn = (item) => - { - return new (object, object)[] { (item.Key, item.Value) }; - }; + //var feed_map = new Dictionary(); // Validate and process feed_dict. if (feed_dict != null) { - foreach (var feed in feed_dict) + foreach (var subfeed in feed_dict) { - foreach (var (subfeed, subfeed_val) in feed_fn(feed)) - { - var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); - //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used - feed_dict_tensor[subfeed_t] = subfeed_val; - feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); - } + var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); + //var target_dtype = subfeed_t.dtype.as_numpy_typecode(); // subfeed_dtype was never used + feed_dict_tensor[subfeed_t] = subfeed.Value; + //feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); } } @@ -137,7 +126,7 @@ namespace Tensorflow // We only want to really perform the run if fetches or targets are provided, // or if the call is a partial run that specifies feeds. - var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); + var results = _do_run(final_targets.Select(x => (Operation) x).ToList(), final_fetches, feed_dict_tensor); return fetch_handler.build_results(this, results); } @@ -157,88 +146,81 @@ namespace Tensorflow /// private NDArray[] _do_run(List target_list, List fetch_list, Dictionary feed_dict) { - var feeds = feed_dict.Select(x => + var feeds = new KeyValuePair[feed_dict.Count]; + int i = 0; + foreach (var x in feed_dict) { - if (x.Key is Tensor tensor) + if (x.Key is Tensor key) { switch (x.Value) { + case Tensor v: + if (v.dtype != key.dtype) + throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {v.dtype}"); + feeds[i++] = new KeyValuePair(key._as_tf_output(), v); + break; + case NDArray v: + feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); + break; + case IntPtr v: + var tensor = new Tensor(v); + if (tensor.dtype != key.dtype) + throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}"); + + feeds[i++] = new KeyValuePair(key._as_tf_output(), tensor); + break; #if _REGEN - %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] - %foreach types% - case #1 v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case #1[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - % + // @formatter:off — disable formatter after this line + %types = ["bool", "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] + %foreach types% + case #1 v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case #1[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + % + // @formatter:on — enable formatter after this line #else - case sbyte v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case sbyte[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); 
- case byte v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case byte[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case short v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case short[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case ushort v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case ushort[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case int v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case int[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case uint v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case uint[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case long v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case long[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case ulong v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case ulong[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case float v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case float[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case double v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case double[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case Complex v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case Complex[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + // @formatter:off — disable formatter after this line + case bool v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case bool[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case sbyte v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case sbyte[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case byte v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case byte[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case short v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case short[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case ushort v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case ushort[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case int v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case int[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case uint v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case uint[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case long v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case long[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); 
break; + case ulong v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case ulong[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case float v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case float[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case double v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case double[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case Complex v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + case Complex[] v: feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break; + // @formatter:on — enable formatter after this line #endif - case bool v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); + case string v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case IntPtr v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case Tensor v: - return new KeyValuePair(tensor._as_tf_output(), v); - case NDArray v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); + feeds[i++] = new KeyValuePair(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); + break; default: - throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "")}"); + throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? ""}"); } } - throw new NotImplementedException("_do_run.feed_dict"); - }).ToArray(); - var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); - var targets = target_list; + } + var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); + //var targets = target_list; return _call_tf_sessionrun(feeds, fetches, target_list); } + private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] feed_dict, TF_Output[] fetch_list, List target_list) { // Ensure any changes to the graph are reflected in the runtime. 
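The feeds loop above has to dispatch on the runtime type of each boxed feed value, with separate cases for scalars and arrays, before anything crosses the native boundary. Here is a stripped-down, standalone illustration of that dispatch pattern using plain .NET types only; the `Normalize` helper is hypothetical and is not TF.NET's `TensorConverter`.

```cs
using System;

static class BoxedFeedDispatchSketch
{
    // Mirrors the shape of the switch in _do_run: one case per supported runtime type,
    // with a NotImplementedException fallback for anything unrecognized.
    static double[] Normalize(object value)
    {
        switch (value)
        {
            case float v: return new[] { (double) v };
            case float[] v: return Array.ConvertAll(v, x => (double) x);
            case int v: return new[] { (double) v };
            case int[] v: return Array.ConvertAll(v, x => (double) x);
            case double v: return new[] { v };
            case double[] v: return v;
            default:
                throw new NotImplementedException(
                    $"feed_dict data type {value?.GetType().Name ?? "<null>"}");
        }
    }

    static void Main()
    {
        foreach (var feed in new object[] { 1, 2.5f, new[] { 3.0, 4.0 } })
            Console.WriteLine(string.Join(", ", Normalize(feed)));
    }
}
```

The real code additionally validates that a fed `Tensor`'s dtype matches the placeholder's dtype and throws `ValueError` on a mismatch, as the `case Tensor v` branch above shows.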
@@ -251,12 +233,12 @@ namespace Tensorflow c_api.TF_SessionRun(_handle, run_options: null, inputs: feed_dict.Select(f => f.Key).ToArray(), - input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), + input_values: feed_dict.Select(f => (IntPtr) f.Value).ToArray(), ninputs: feed_dict.Length, outputs: fetch_list, output_values: output_values, noutputs: fetch_list.Length, - target_opers: target_list.Select(f => (IntPtr)f).ToArray(), + target_opers: target_list.Select(f => (IntPtr) f).ToArray(), ntargets: target_list.Count, run_metadata: IntPtr.Zero, status: status); @@ -268,117 +250,184 @@ namespace Tensorflow for (int i = 0; i < fetch_list.Length; i++) result[i] = fetchValue(output_values[i]); - for (int i = 0; i < feed_dict.Length; i++) - feed_dict[i].Value.Dispose(); - return result; } - private unsafe NDArray fetchValue(IntPtr output) + private static unsafe NDArray fetchValue(IntPtr output) { - var tensor = new Tensor(output); - NDArray nd = null; - Type type = tensor.dtype.as_numpy_dtype(); - var ndims = tensor.shape; - var offset = c_api.TF_TensorData(output); - - if(ndims.Length == 0) + NDArray ret; + using (var tensor = new Tensor(output)) { - switch (tensor.dtype) + var ndims = tensor.shape; + var srcAddress = c_api.TF_TensorData(output).ToInt64(); + + if (ndims.Length == 0) { - case TF_DataType.TF_BOOL: - nd = NDArray.Scalar(*(bool*)offset); - break; - case TF_DataType.TF_STRING: - var bytes = tensor.BufferToArray(); - // wired, don't know why we have to start from offset 9. - // length in the begin - var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); - nd = NDArray.FromString(str); - break; - case TF_DataType.TF_UINT8: - nd = NDArray.Scalar(*(byte*)offset); - break; - case TF_DataType.TF_INT16: - nd = NDArray.Scalar(*(short*)offset); - break; - case TF_DataType.TF_INT32: - nd = NDArray.Scalar(*(int*)offset); - break; - case TF_DataType.TF_INT64: - nd = NDArray.Scalar(*(long*)offset); - break; - case TF_DataType.TF_FLOAT: - nd = NDArray.Scalar(*(float*)offset); - break; - case TF_DataType.TF_DOUBLE: - nd = NDArray.Scalar(*(double*)offset); - break; - default: - throw new NotImplementedException("can't fetch output"); - } - } - else - { - switch (tensor.dtype) + switch (tensor.dtype) + { + case TF_DataType.TF_BOOL: + ret = NDArray.Scalar(*(bool*) srcAddress); + break; + case TF_DataType.TF_STRING: + using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long) tensor.bytesize))) + ret = NDArray.FromString(reader.ReadString()); + break; + case TF_DataType.TF_UINT8: + ret = NDArray.Scalar(*(byte*) srcAddress); + break; + case TF_DataType.TF_INT16: + ret = NDArray.Scalar(*(short*) srcAddress); + break; + case TF_DataType.TF_INT32: + ret = NDArray.Scalar(*(int*) srcAddress); + break; + case TF_DataType.TF_INT64: + ret = NDArray.Scalar(*(long*) srcAddress); + break; + case TF_DataType.TF_UINT16: + ret = NDArray.Scalar(*(ushort*) srcAddress); + break; + case TF_DataType.TF_UINT32: + ret = NDArray.Scalar(*(uint*) srcAddress); + break; + case TF_DataType.TF_UINT64: + ret = NDArray.Scalar(*(ulong*) srcAddress); + break; + case TF_DataType.TF_FLOAT: + ret = NDArray.Scalar(*(float*) srcAddress); + break; + case TF_DataType.TF_DOUBLE: + ret = NDArray.Scalar(*(double*) srcAddress); + break; + default: + throw new NotImplementedException("can't fetch output"); + } + } else { - case TF_DataType.TF_BOOL: - var bools = new bool[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i)); - nd = 
np.array(bools).reshape(ndims); - break; - case TF_DataType.TF_STRING: - var bytes = tensor.BufferToArray(); - // wired, don't know why we have to start from offset 9. - // length in the begin - var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); - nd = np.array(str); - break; - case TF_DataType.TF_UINT8: - var _bytes = new byte[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(_bytes).reshape(ndims); - break; - case TF_DataType.TF_INT16: - var shorts = new short[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(shorts).reshape(ndims); - break; - case TF_DataType.TF_INT32: - var ints = new int[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(ints).reshape(ndims); - break; - case TF_DataType.TF_INT64: - var longs = new long[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(longs).reshape(ndims); - break; - case TF_DataType.TF_FLOAT: - var floats = new float[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - floats[i] = *(float*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(floats).reshape(ndims); - break; - case TF_DataType.TF_DOUBLE: - var doubles = new double[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(doubles).reshape(ndims); - break; - default: - throw new NotImplementedException("can't fetch output"); + //var size = (long) tensor.size; + //var itemsize = (long) tensor.itemsize; + var bytesize = (long) tensor.bytesize; + var src = (void*) srcAddress; + +#if _REGEN + #region Compute + switch (tensor.dtype) + { + %foreach except(supported_dtypes, "Char"),except(supported_dtypes_lowercase, "char"),except(supported_dtypes_TF_DataType,"TF_STRING")% + case TF_DataType.#3: + { + ret = new NDArray(NPTypeCode.#1, ndims, false); + System.Buffer.MemoryCopy(src, #(#3=="TF_STRING"|"(byte*)ret.Unsafe.Address + 8"|"ret.Unsafe.Address"), bytesize, bytesize); + break; + } + % + case TF_DataType.TF_STRING: + { + //TODO:! 
This is not the way to handle string[], it should be done with TF_DecodeString + using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) + ret = NDArray.FromString(reader.ReadString()); + break; + } + default: + throw new NotSupportedException(); + } + #endregion +#else + + #region Compute + + switch (tensor.dtype) + { + case TF_DataType.TF_BOOL: + { + ret = new NDArray(NPTypeCode.Boolean, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + + case TF_DataType.TF_UINT8: + { + ret = new NDArray(NPTypeCode.Byte, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + + case TF_DataType.TF_INT16: + { + ret = new NDArray(NPTypeCode.Int16, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + + case TF_DataType.TF_UINT16: + { + ret = new NDArray(NPTypeCode.UInt16, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + + case TF_DataType.TF_INT32: + { + ret = new NDArray(NPTypeCode.Int32, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + + case TF_DataType.TF_UINT32: + { + ret = new NDArray(NPTypeCode.UInt32, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + + case TF_DataType.TF_INT64: + { + ret = new NDArray(NPTypeCode.Int64, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + + case TF_DataType.TF_UINT64: + { + ret = new NDArray(NPTypeCode.UInt64, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + + case TF_DataType.TF_DOUBLE: + { + ret = new NDArray(NPTypeCode.Double, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + + case TF_DataType.TF_FLOAT: + { + ret = new NDArray(NPTypeCode.Single, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + + case TF_DataType.TF_STRING: + { + throw new NotImplementedException(); + //TODO:! 
This is not the way to handle string[], it should be done with TF_DecodeString + using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long) tensor.bytesize))) + ret = NDArray.FromString(reader.ReadString()); + break; + } + + default: + throw new NotSupportedException(); + } + + #endregion + +#endif } } - - tensor.Dispose(); - return nd; + return ret; } /// @@ -392,9 +441,7 @@ namespace Tensorflow } private void _extend_graph() - { - - } + { } public void close() { @@ -403,11 +450,12 @@ namespace Tensorflow protected override void DisposeUnmanagedResources(IntPtr handle) { - using (var status = new Status()) - { - c_api.TF_DeleteSession(handle, status); - status.Check(true); - } + lock (Locks.ProcessWide) + using (var status = new Status()) + { + c_api.TF_DeleteSession(handle, status); + status.Check(true); + } } } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index ec2e443f..a89d94dc 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -15,62 +15,77 @@ ******************************************************************************/ using System; +using System.IO; +using System.Runtime.CompilerServices; +using Tensorflow.Util; using static Tensorflow.Binding; namespace Tensorflow { public class Session : BaseSession, IObjectLife { - public Session(string target = "", Graph g = null) - : base(target, g, null) - { - - } + public Session(string target = "", Graph g = null) : base(target, g, null) + { } - public Session(IntPtr handle, Graph g = null) - : base("", g, null) + public Session(IntPtr handle, Graph g = null) : base("", g, null) { _handle = handle; } - public Session(Graph g, SessionOptions opts = null, Status s = null) - : base("", g, opts) - { - if (s == null) - s = new Status(); - } + public Session(Graph g, SessionOptions opts = null, Status s = null) : base("", g, opts, s) + { } public Session as_default() { - tf.defaultSession = this; + tf._defaultSessionFactory.Value = this; return this; } + [MethodImpl(MethodImplOptions.NoOptimization)] public static Session LoadFromSavedModel(string path) { - var graph = c_api.TF_NewGraph(); - var status = new Status(); - var opt = new SessionOptions(); - - var tags = new string[] { "serve" }; - var buffer = new TF_Buffer(); - - var sess = c_api.TF_LoadSessionFromSavedModel(opt, - IntPtr.Zero, - path, - tags, - tags.Length, - graph, - ref buffer, - status); - - // load graph bytes - // var data = new byte[buffer.length]; - // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); - // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ - status.Check(true); - - return new Session(sess, g: new Graph(graph).as_default()); + lock (Locks.ProcessWide) + { + var graph = c_api.TF_NewGraph(); + var status = new Status(); + var opt = new SessionOptions(); + + var tags = new string[] {"serve"}; + var buffer = new TF_Buffer(); + + IntPtr sess; + try + { + sess = c_api.TF_LoadSessionFromSavedModel(opt, + IntPtr.Zero, + path, + tags, + tags.Length, + graph, + ref buffer, + status); + status.Check(true); + } catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) + { + status = new Status(); + sess = c_api.TF_LoadSessionFromSavedModel(opt, + IntPtr.Zero, + Path.GetFullPath(path), + tags, + tags.Length, + graph, + ref buffer, + status); + status.Check(true); + } + + // load graph bytes + // var data = new byte[buffer.length]; + // Marshal.Copy(buffer.data, data, 
0, (int)buffer.length); + // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ + + return new Session(sess, g: new Graph(graph)).as_default(); + } } public static implicit operator IntPtr(Session session) => session._handle; diff --git a/src/TensorFlowNET.Core/Status/Status.cs b/src/TensorFlowNET.Core/Status/Status.cs index ce561f75..21ff6f6e 100644 --- a/src/TensorFlowNET.Core/Status/Status.cs +++ b/src/TensorFlowNET.Core/Status/Status.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System; +using System.Diagnostics; using System.Runtime.CompilerServices; using static Tensorflow.c_api; @@ -52,6 +53,7 @@ namespace Tensorflow /// /// When the returned check is not TF_Code.TF_OK [MethodImpl(MethodImplOptions.AggressiveInlining)] + [DebuggerHidden] public void Check(bool throwException = false) { if (Code != TF_Code.TF_OK) diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index bd8c0a29..12a4c5f3 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 1.14.0 - 0.11.1 + 0.11.5 Haiping Chen, Meinrad Recheis, Eli Belash SciSharp STACK true @@ -17,16 +17,21 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow full binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.11.1.0 + 0.11.5.0 Changes since v0.10.0: -1. Upgrade NumSharp to v0.20. +1. Upgrade NumSharp to v0.20.3. 2. Add DisposableObject class to manage object lifetime. 3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables. 4. Change tensorflow to non-static class in order to execute some initialization process. 5. Overload session.run(), make syntax simpler. -6. Add Local Response Normalization. +6. Add Local Response Normalization. +7. Add tf.image related APIs. +8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor. +9. MultiThread is safe. +10. Support n-dim indexing for tensor. +11. Add RegisterNoGradient 7.3 - 0.11.1.0 + 0.11.5.0 LICENSE true true @@ -58,11 +63,12 @@ Docs: https://tensorflownet.readthedocs.io - + + diff --git a/src/TensorFlowNET.Core/Tensors/AllocationType.cs b/src/TensorFlowNET.Core/Tensors/AllocationType.cs new file mode 100644 index 00000000..9f5c8bad --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/AllocationType.cs @@ -0,0 +1,27 @@ +namespace Tensorflow +{ + /// + /// Used internally to + /// + public enum AllocationType + { + None = 0, + /// + /// Allocation was done by passing in a pointer, might be also holding reference to a C# object. + /// + FromPointer = 1, + /// + /// Allocation was done by calling c_api.TF_AllocateTensor or TF decided it has to copy data during c_api.TF_NewTensor.
+ /// Deallocation is handled solely by Tensorflow. + ///
+ Tensorflow = 2, + /// + /// Allocation was done by Marshal.AllocateHGlobal + /// + Marshal = 3, + /// + /// Allocation was done by GCHandle.Alloc + /// + GCHandle = 4, + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/Dimension.cs b/src/TensorFlowNET.Core/Tensors/Dimension.cs new file mode 100644 index 00000000..58520270 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Dimension.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class Dimension + { + int _value; + public int value => _value; + + public Dimension(int value) + { + _value = value; + } + + public Dimension merge_with(Dimension other) + { + if (_value == -1) + return new Dimension(other.value); + else + return new Dimension(_value); + } + + public override string ToString() => $"Dimension({_value})"; + } +} diff --git a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs index 97496df1..c916b321 100644 --- a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs +++ b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs @@ -35,5 +35,6 @@ DtFloatRef = 101, // DT_FLOAT_REF DtDoubleRef = 102, // DT_DOUBLE_REF DtInt32Ref = 103, // DT_INT32_REF + DtInt64Ref = 109 // DT_INT64_REF } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 625b424a..34edcb4f 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -28,42 +28,37 @@ using static Tensorflow.c_api; namespace Tensorflow { + [SuppressMessage("ReSharper", "InvokeAsExtensionMethod")] public partial class Tensor { /// - /// true if unmanaged buffer has been freed. + /// When Tensor was created from an object that is managed by C#'s GC - this will hold reference to prevent it from being collected. /// - private bool _deallocator_called => _deallocatorArgs.deallocator_called; + protected object AllocationReferenceHolder; /// - /// true if the Tensor was created from a managed array + /// The handle that was used to allocate this tensor, dependent on . /// - private bool _isPinnedArray => _deallocatorArgs.gc_handle != IntPtr.Zero; + protected object AllocationHandle; /// - /// True only if the Tensor object was created in a way that the Tensor object itself allocated memory or pinned a managed object. - /// False if the Tensor was created from a pointer + /// True if this Tensor holds data allocated by C#. /// - public bool IsMemoryOwner { get; private set; } + public bool IsMemoryOwner => AllocationType >= AllocationType.Marshal; /// - /// This holds values that are used by the unmanaged deallocator callback + /// The allocation method used to create this Tensor. /// - private DeallocatorArgs _deallocatorArgs = new DeallocatorArgs() { gc_handle = IntPtr.Zero }; - - // note: they must be assigned to a static variable in order to work as unmanaged callbacks - private static readonly Deallocator _hGlobalDeallocator = FreeHGlobalMemory; - private static readonly Deallocator _gcHandleDeallocator = FreeGCHandle; - private static readonly Deallocator _nothingDeallocator = FreeNothing; + public AllocationType AllocationType { get; protected set; } /// - /// Create a Tensor object from an existing TF handle + /// Create a Tensor object from an existing TF handle /// - /// + /// Handle to a object. 
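The new `Dimension.merge_with` introduced in this hunk follows the convention that an unknown dimension (`-1`) defers to the other operand, while a known dimension is kept as-is. A quick illustration using only the members added in this diff (the values are arbitrary):

```cs
using System;
using Tensorflow;

class DimensionMergeSketch
{
    static void Main()
    {
        var unknown = new Dimension(-1);
        var three = new Dimension(3);

        // Unknown (-1) adopts the other dimension's value.
        Console.WriteLine(unknown.merge_with(three)); // Dimension(3)

        // A known value is kept even when merged with an unknown one.
        Console.WriteLine(three.merge_with(unknown)); // Dimension(3)
    }
}
```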
public Tensor(IntPtr handle) { _handle = handle; - IsMemoryOwner = false; + //no need to set AllocationType = AllocationType.None; } /// @@ -71,430 +66,412 @@ namespace Tensorflow /// Note: the caller is responsible for freeing the memory. Calling Dispose on this object will dispose the TensorFlow tensor /// but not the memory itself! /// - /// Pointer to unmanaged, fixed or pinned memory which the caller owns + /// Pointer to unmanaged, fixed or pinned memory which the caller owns /// Tensor shape /// TF data type /// Size of the tensor in memory - public Tensor(IntPtr ptr, long[] shape, TF_DataType dType, int num_bytes) + public Tensor(IntPtr data_ptr, long[] shape, TF_DataType dType, int num_bytes) { - _handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); - IsMemoryOwner = false; + unsafe + { + _handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: data_ptr, len: (UIntPtr) num_bytes); + AllocationType = TF_TensorData(_handle) == data_ptr ? AllocationType.FromPointer : AllocationType.Tensorflow; + } + } + + /// + /// Create a new Tensor from the given unmanaged memory pointer (which must be allocated, fixed or pinned by the caller) + /// Note: the caller is responsible for freeing the memory. Calling Dispose on this object will dispose the TensorFlow tensor + /// but not the memory itself! + /// + /// Pointer to unmanaged, fixed or pinned memory which the caller owns + /// Tensor shape + /// TF data type + /// Size of the tensor in memory + public unsafe Tensor(void* data_ptr, long[] shape, TF_DataType dType, int num_bytes) + { + _handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: data_ptr, len: (UIntPtr) num_bytes); + AllocationType = TF_TensorData(_handle).ToPointer() == data_ptr ? AllocationType.FromPointer : AllocationType.Tensorflow; } #if _REGEN - %types=["sbyte", "bool", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] + %types = ["sbyte", "bool", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] %foreach types% /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(#1[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(#1)), new long[]{data.Length}, data, Marshal.SizeOf<#1>()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(#1)), new long[] {data.Length}, data, #(#1=="Complex"|"Marshal.SizeOf()"|"sizeof(#(str(#1)))")); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(#1[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(#1)), shape, data, Marshal.SizeOf<#1>()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(#1)), shape, data, #(#1=="Complex"|"Marshal.SizeOf()"|"sizeof(#(str(#1)))")); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(#1 value, TF_DataType? dType = null) { - var v = (#1*)Marshal.AllocHGlobal(sizeof(#1)); - *v = value; - _handle = TF_NewTensor(dType ?? 
dtypes.as_dtype(typeof(#1)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(#1), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(#1)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(#1)); + *(#1*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } % #else - - /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(sbyte[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(sbyte)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(sbyte)), new long[] {data.Length}, data, sizeof(sbyte)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(sbyte[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(sbyte)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(sbyte)), shape, data, sizeof(sbyte)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(sbyte value, TF_DataType? dType = null) { - var v = (sbyte*)Marshal.AllocHGlobal(sizeof(sbyte)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(sbyte)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(sbyte), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(sbyte)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(sbyte)); + *(sbyte*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(bool[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(bool)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(bool)), new long[] {data.Length}, data, sizeof(bool)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(bool[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(bool)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(bool)), shape, data, sizeof(bool)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(bool value, TF_DataType? dType = null) { - var v = (bool*)Marshal.AllocHGlobal(sizeof(bool)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(bool)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(bool), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? 
dtypes.as_dtype(typeof(bool)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(bool)); + *(bool*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(byte[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(byte)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(byte)), new long[] {data.Length}, data, sizeof(byte)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(byte[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(byte)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(byte)), shape, data, sizeof(byte)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(byte value, TF_DataType? dType = null) { - var v = (byte*)Marshal.AllocHGlobal(sizeof(byte)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(byte)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(byte), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(byte)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(byte)); + *(byte*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(short[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(short)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(short)), new long[] {data.Length}, data, sizeof(short)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(short[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(short)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(short)), shape, data, sizeof(short)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(short value, TF_DataType? dType = null) { - var v = (short*)Marshal.AllocHGlobal(sizeof(short)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(short)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(short), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(short)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(short)); + *(short*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(ushort[] data, TF_DataType? 
dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ushort)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(ushort)), new long[] {data.Length}, data, sizeof(ushort)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(ushort[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ushort)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(ushort)), shape, data, sizeof(ushort)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(ushort value, TF_DataType? dType = null) { - var v = (ushort*)Marshal.AllocHGlobal(sizeof(ushort)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(ushort)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(ushort), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(ushort)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(ushort)); + *(ushort*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(int[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(int)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(int)), new long[] {data.Length}, data, sizeof(int)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(int[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(int)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(int)), shape, data, sizeof(int)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(int value, TF_DataType? dType = null) { - var v = (int*)Marshal.AllocHGlobal(sizeof(int)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(int)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(int), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(int)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(int)); + *(int*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(uint[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(uint)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? 
dtypes.as_dtype(typeof(uint)), new long[] {data.Length}, data, sizeof(uint)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(uint[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(uint)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(uint)), shape, data, sizeof(uint)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(uint value, TF_DataType? dType = null) { - var v = (uint*)Marshal.AllocHGlobal(sizeof(uint)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(uint)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(uint), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(uint)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(uint)); + *(uint*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(long[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(long)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(long)), new long[] {data.Length}, data, sizeof(long)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(long[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(long)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(long)), shape, data, sizeof(long)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(long value, TF_DataType? dType = null) { - var v = (long*)Marshal.AllocHGlobal(sizeof(long)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(long)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(long), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(long)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(long)); + *(long*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(ulong[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ulong)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(ulong)), new long[] {data.Length}, data, sizeof(ulong)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(ulong[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? 
dtypes.as_dtype(typeof(ulong)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(ulong)), shape, data, sizeof(ulong)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(ulong value, TF_DataType? dType = null) { - var v = (ulong*)Marshal.AllocHGlobal(sizeof(ulong)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(ulong)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(ulong), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(ulong)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(ulong)); + *(ulong*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(float[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(float)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(float)), new long[] {data.Length}, data, sizeof(float)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(float[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(float)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(float)), shape, data, sizeof(float)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(float value, TF_DataType? dType = null) { - var v = (float*)Marshal.AllocHGlobal(sizeof(float)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(float)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(float), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(float)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(float)); + *(float*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(double[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(double)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(double)), new long[] {data.Length}, data, sizeof(double)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(double[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(double)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(double)), shape, data, sizeof(double)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(double value, TF_DataType? 
dType = null) { - var v = (double*)Marshal.AllocHGlobal(sizeof(double)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(double)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(double)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(double)); + *(double*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(Complex[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(Complex)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(Complex)), new long[] {data.Length}, data, Marshal.SizeOf()); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(Complex[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(Complex)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(Complex)), shape, data, Marshal.SizeOf()); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(Complex value, TF_DataType? dType = null) { - var v = (Complex*)Marshal.AllocHGlobal(sizeof(Complex)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(Complex)); + *(Complex*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } #endif /// - /// Create a string Tensor from the given string + /// Create a string Tensor from the given string /// public unsafe Tensor(string str) { var status = new Status(); var buffer = Encoding.UTF8.GetBytes(str); - var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); - var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); + var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); + var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); + AllocationType = AllocationType.Tensorflow; IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); fixed (byte* src = buffer) - c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); + c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status); _handle = handle; status.Check(true); } public unsafe Tensor(NDArray nd, TF_DataType? 
tensorDType = null) { + if (tensorDType == null) + tensorDType = nd.dtype.as_dtype(); + // todo: handle nd of type "String" here too if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte) { if (nd.Unsafe.Storage.Shape.IsContiguous) { - var bytesLength = (UIntPtr)nd.size; + var bytesLength = (UIntPtr) nd.size; var size = c_api.TF_StringEncodedSize(bytesLength); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); + AllocationType = AllocationType.Tensorflow; IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); @@ -504,13 +481,12 @@ namespace Tensorflow status.Check(true); _handle = handle; - IsMemoryOwner = false; - } - else + } else { var buffer = nd.ToArray(); var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); + AllocationType = AllocationType.Tensorflow; IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); @@ -521,7 +497,6 @@ namespace Tensorflow status.Check(true); _handle = handle; - IsMemoryOwner = false; } return; @@ -532,27 +507,27 @@ namespace Tensorflow private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) { - if (nd.dtype.Name == "String") + if (nd.typecode == NPTypeCode.String) throw new NotImplementedException("Support for NDArray of type string not implemented yet"); - IArraySlice arraySlice; - if (nd.Unsafe.Storage.Shape.IsContiguous == false) - { - // the memory is NOT contiguous, so we have to copy the view into a contiguous memory block. - arraySlice = nd.CloneData(); - } - else + + var arraySlice = nd.Unsafe.Storage.Shape.IsContiguous ? nd.GetData() : nd.CloneData(); + + var handle = TF_NewTensor( + given_dtype ?? nd.dtype.as_dtype(), + dims: nd.shape.Select(i => (long) i).ToArray(), + num_dims: nd.ndim, + data: arraySlice.Address, + len: (UIntPtr) (nd.size * nd.dtypesize)); + + //if TF decided not to perform copy, hold reference for given NDArray. + if (TF_TensorData(handle).ToPointer() == arraySlice.Address) { - // the memory is contiguous - arraySlice = nd.GetData(); - } - this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it - var ptr = new IntPtr(arraySlice.Address); - int num_bytes = (nd.size * nd.dtypesize); - var dtype = given_dtype ?? 
nd.dtype.as_dtype(); - var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); - IsMemoryOwner = false; - return handle; + AllocationType = AllocationType.FromPointer; + AllocationReferenceHolder = arraySlice; + } else + AllocationType = AllocationType.Tensorflow; + return handle; } public unsafe Tensor(byte[][] buffer, long[] shape) @@ -560,11 +535,13 @@ namespace Tensorflow int size = 0; foreach (var b in buffer) { - size += (int)TF_StringEncodedSize((UIntPtr)b.Length); + size += (int) TF_StringEncodedSize((UIntPtr) b.Length); } + int totalSize = size + buffer.Length * 8; ulong offset = 0; - IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize); + IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr) totalSize); + AllocationType = AllocationType.Tensorflow; // Clear offset table IntPtr pOffset = TF_TensorData(handle); @@ -572,15 +549,15 @@ namespace Tensorflow IntPtr dstLimit = pOffset + totalSize; for (int i = 0; i < buffer.Length; i++) { - Marshal.WriteInt64(pOffset, (long)offset); + Marshal.WriteInt64(pOffset, (long) offset); using (var status = new Status()) { fixed (byte* src = &buffer[i][0]) { - var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status); + var written = TF_StringEncode(src, (UIntPtr) buffer[i].Length, (sbyte*) dst, (UIntPtr) (dstLimit.ToInt64() - dst.ToInt64()), status); status.Check(true); pOffset += 8; - dst += (int)written; + dst += (int) written; offset += written; } } @@ -612,24 +589,26 @@ namespace Tensorflow /// [MethodImpl(MethodImplOptions.AggressiveInlining)] [SuppressMessage("ReSharper", "LocalVariableHidesMember")] - protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int element_size) + protected unsafe IntPtr CreateTensorFromArray(TF_DataType dt, long[] shape, Array data, int element_size) { if (dt == TF_DataType.TF_STRING && data is byte[] buffer) { - var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); - var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); + var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); + var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); + AllocationType = AllocationType.Tensorflow; IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); var status = new Status(); fixed (byte* src = buffer) - c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); + c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status); status.Check(true); return handle; } - return CreateTensorWithoutCopying(dt, shape, data, 0, data.Length, element_size); + + return CreateTensorFromArray(dt, shape, data, 0, data.Length, element_size); } /// @@ -647,67 +626,34 @@ namespace Tensorflow /// specified dimensions. 
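The `CreateTensorFromNDArray` rewrite above, and the `CreateTensorFromArray` overload that follows, both rely on the same idea: pin the managed buffer, hand its address to `TF_NewTensor`, then compare `TF_TensorData(handle)` against that address to learn whether TensorFlow borrowed the buffer or copied it. Below is a standalone sketch of that pin-and-maybe-free pattern in plain .NET; `FakeNativeCreate` is a made-up stand-in for the native call, not a TF.NET API.

```cs
using System;
using System.Runtime.InteropServices;

class PinnedBufferSketch
{
    // Stand-in for c_api.TF_NewTensor: pretends the native side borrows the caller's buffer.
    // This is NOT a real TF.NET call.
    static IntPtr FakeNativeCreate(IntPtr data, out IntPtr dataUsedByNative)
    {
        dataUsedByNative = data;
        return data; // stand-in for the opaque tensor handle
    }

    static void Main()
    {
        var data = new float[] { 1f, 2f, 3f };

        // Pin the managed array so the GC cannot move it while native code holds its address.
        var gcHandle = GCHandle.Alloc(data, GCHandleType.Pinned);
        var pinnedAddr = gcHandle.AddrOfPinnedObject();

        var handle = FakeNativeCreate(pinnedAddr, out var nativeData);
        Console.WriteLine($"tensor handle: 0x{handle.ToInt64():X}");

        if (nativeData == pinnedAddr)
        {
            // Native side borrows our memory (AllocationType.GCHandle in the real code):
            // the handle must stay alive until the tensor is disposed.
            Console.WriteLine("buffer borrowed; keeping GCHandle until dispose");
        }
        else
        {
            // Native side made its own copy (AllocationType.Tensorflow): unpin immediately.
            gcHandle.Free();
            Console.WriteLine("buffer copied; GCHandle released");
        }

        // Cleanup for this sketch only; the real code does this in the tensor's dispose path.
        if (gcHandle.IsAllocated)
            gcHandle.Free();
    }
}
```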
/// [MethodImpl(MethodImplOptions.AggressiveInlining)] - protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int start, int count, int element_size) + protected IntPtr CreateTensorFromArray(TF_DataType dt, long[] shape, Array data, int start, int count, int element_size) { if (start < 0 || start > data.Length - count) throw new ArgumentException($"Array length {data.Length} does not match the given shape {new Shape(shape.Cast().ToArray())}"); // get a handle to the pinned array which we will pass on to the tensor computation engine to use var gcHandle = GCHandle.Alloc(data, GCHandleType.Pinned); - _deallocatorArgs = new DeallocatorArgs() { gc_handle = GCHandle.ToIntPtr(gcHandle) }; - if (shape == null || shape.Length == 0) - return TF_NewTensor(dt, new long[0], 0, gcHandle.AddrOfPinnedObject() + start * element_size, (UIntPtr)(count * element_size), _gcHandleDeallocator, ref _deallocatorArgs); - else - return TF_NewTensor(dt, shape, shape.Length, gcHandle.AddrOfPinnedObject() + start * element_size, (UIntPtr)(count * element_size), _gcHandleDeallocator, ref _deallocatorArgs); - } - - [MonoPInvokeCallback(typeof(Deallocator))] - internal static void FreeHGlobalMemory(IntPtr dataPtr, IntPtr len, ref DeallocatorArgs args) - { - if (args.deallocator_called) - return; + var pinnedAddr = gcHandle.AddrOfPinnedObject(); - // NumSharp will dispose - Marshal.FreeHGlobal(dataPtr); - args.deallocator_called = true; - } + //call NewTensor + IntPtr handle; + if (shape == null || shape.Length == 0) + handle = TF_NewTensor(dt, new long[0], 0, pinnedAddr + start * element_size, (UIntPtr) (count * element_size)); + else + handle = TF_NewTensor(dt, shape, shape.Length, pinnedAddr + start * element_size, (UIntPtr) (count * element_size)); - [MonoPInvokeCallback(typeof(Deallocator))] - internal static void FreeGCHandle(IntPtr dataPtr, IntPtr len, ref DeallocatorArgs args) - { - if (args.deallocator_called || args.gc_handle == IntPtr.Zero) - return; - // note: since the ptr given to tensorflow is just the addr of the pinned object we can not directly free it! we need to free the gcHandle instead - //args.gchandle.Free(); - args.deallocator_called = true; - } + //Figure if TF decided to clone or not. + if (c_api.TF_TensorData(handle) == pinnedAddr) + { + AllocationType = AllocationType.GCHandle; + AllocationHandle = gcHandle; + } else + { + AllocationType = AllocationType.Tensorflow; + gcHandle.Free(); + } - [MonoPInvokeCallback(typeof(Deallocator))] - internal static void FreeNothing(IntPtr dataPtr, IntPtr len, ref DeallocatorArgs args) - { - args.deallocator_called = true; + return handle; } - } - - /// - /// This attribute can be applied to callback functions that will be invoked - /// from unmanaged code to managed code. - /// - /// - /// - /// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))] - /// internal static void MyFreeFunc (IntPtr data, IntPtr length){..} - /// - /// - public sealed class MonoPInvokeCallbackAttribute : Attribute - { - /// - /// Use this constructor to annotate the type of the callback function that - /// will be invoked from unmanaged code. - /// - /// T. 
- public MonoPInvokeCallbackAttribute(Type t) { } - } - -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs new file mode 100644 index 00000000..d916f624 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs @@ -0,0 +1,224 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using NumSharp; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class Tensor + { + public Tensor this[int idx] => slice(idx); + + public Tensor this[params Slice[] slices] + { + get + { + var begin = new List(); + var end = new List(); + var strides = new List(); + + var index = 0; + var (new_axis_mask, shrink_axis_mask) = (0, 0); + var (begin_mask, end_mask) = (0, 0); + var ellipsis_mask = 0; + + foreach (var s in slices) + { + if (s.IsNewAxis) + { + begin.Add(0); + end.Add(0); + strides.Add(1); + new_axis_mask |= (1 << index); + } + else if (s.IsEllipsis) + { + begin.Add(0); + end.Add(0); + strides.Add(1); + ellipsis_mask |= (1 << index); + } + else + { + if (s.Start.HasValue) + { + begin.Add(s.Start.Value); + } + else + { + begin.Add(0); + begin_mask |= (1 << index); + } + + if (s.Stop.HasValue) + { + end.Add(s.Stop.Value); + } + else + { + end.Add(0); + end_mask |= (1 << index); + } + + strides.Add(s.Step); + } + + index += 1; + } + + return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + { + string name = scope; + if (begin != null) + { + var (packed_begin, packed_end, packed_strides) = + (array_ops.stack(begin.ToArray()), + array_ops.stack(end.ToArray()), + array_ops.stack(strides.ToArray())); + + return gen_array_ops.strided_slice( + this, + packed_begin, + packed_end, + packed_strides, + begin_mask: begin_mask, + end_mask: end_mask, + shrink_axis_mask: shrink_axis_mask, + new_axis_mask: new_axis_mask, + ellipsis_mask: ellipsis_mask, + name: name); + } + + throw new NotImplementedException(""); + }); + } + } + + public Tensor this[params string[] slices] + => this[slices.Select(x => new Slice(x)).ToArray()]; + + + public Tensor slice(Slice slice) + { + var slice_spec = new int[] { slice.Start.Value }; + var begin = new List(); + var end = new List(); + var strides = new List(); + + var index = 0; + var (new_axis_mask, shrink_axis_mask) = (0, 0); + var (begin_mask, end_mask) = (0, 0); + var ellipsis_mask = 0; + + foreach (var s in slice_spec) + { + begin.Add(s); + if (slice.Stop.HasValue) + { + end.Add(slice.Stop.Value); + } + else + { + end.Add(0); + end_mask |= (1 << index); + } + + strides.Add(slice.Step); + + index += 1; + } + + return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + { + string name = 
scope; + if (begin != null) + { + var (packed_begin, packed_end, packed_strides) = + (array_ops.stack(begin.ToArray()), + array_ops.stack(end.ToArray()), + array_ops.stack(strides.ToArray())); + + return gen_array_ops.strided_slice( + this, + packed_begin, + packed_end, + packed_strides, + begin_mask: begin_mask, + end_mask: end_mask, + shrink_axis_mask: shrink_axis_mask, + new_axis_mask: new_axis_mask, + ellipsis_mask: ellipsis_mask, + name: name); + } + + throw new NotImplementedException(""); + }); + } + + public Tensor slice(int start) + { + var slice_spec = new int[] { start }; + var begin = new List(); + var end = new List(); + var strides = new List(); + + var index = 0; + var (new_axis_mask, shrink_axis_mask) = (0, 0); + var (begin_mask, end_mask) = (0, 0); + var ellipsis_mask = 0; + + foreach (var s in slice_spec) + { + begin.Add(s); + end.Add(s + 1); + strides.Add(1); + shrink_axis_mask |= (1 << index); + index += 1; + } + + return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + { + string name = scope; + if (begin != null) + { + var (packed_begin, packed_end, packed_strides) = + (array_ops.stack(begin.ToArray()), + array_ops.stack(end.ToArray()), + array_ops.stack(strides.ToArray())); + + return gen_array_ops.strided_slice( + this, + packed_begin, + packed_end, + packed_strides, + begin_mask: begin_mask, + end_mask: end_mask, + shrink_axis_mask: shrink_axis_mask, + new_axis_mask: new_axis_mask, + ellipsis_mask: ellipsis_mask, + name: name); + } + + throw new NotImplementedException(""); + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index 4b15864f..ae14958f 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -14,110 +14,286 @@ limitations under the License. 
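
The new `Tensor.Index.cs` file above routes C# indexers through `gen_array_ops.strided_slice`: integer indexing shrinks the indexed axis, while string and `Slice` arguments populate the begin/end/stride masks. A minimal usage sketch follows; it assumes graph mode and that `tf.constant` accepts a NumSharp `NDArray` and `sess.run` returns one, as in the examples repository.

```cs
// Sketch only: exercising the indexers added in Tensor.Index.cs.
// Assumptions: graph mode, tf.constant(NDArray) and tf.Session()/sess.run(...) as used in TensorFlow.NET-Examples.
using System;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;

var x = tf.constant(np.arange(12).reshape(3, 4));   // 3x4 integer tensor

var row   = x[1];            // this[int] -> slice(1): shrink_axis_mask removes axis 0
var block = x["0:2", ":"];   // this[params string[]] -> Slice("0:2"), Slice(":") -> strided_slice

using (var sess = tf.Session())
    Console.WriteLine(sess.run(block));              // first two rows of x
```
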
******************************************************************************/ +using NumSharp; using System; using System.Linq; +using System.Numerics; using static Tensorflow.Binding; namespace Tensorflow { public partial class Tensor { - public static Tensor operator +(double x, Tensor y) => BinaryOpWrapper("add", x, y); - public static Tensor operator +(float x, Tensor y) => BinaryOpWrapper("add", x, y); - public static Tensor operator +(int x, Tensor y) => BinaryOpWrapper("add", x, y); - public static Tensor operator +(Tensor x, Tensor y) => BinaryOpWrapper("add", x, y); - public static Tensor operator +(Tensor x, int y) => BinaryOpWrapper("add", x, y); - public static Tensor operator +(Tensor x, float y) => BinaryOpWrapper("add", x, y); - public static Tensor operator +(Tensor x, double y) => BinaryOpWrapper("add", x, y); +#if _REGEN + #region Compute + %operators = ["add", "sub", "mul", "div", "mod"] + %operators_sign = ["+", "-", "*", "/", "%"] + %operators_comparers = [">", "<", ">=", "<="] + %operators_comparers_names = ["greater", "less", "greater_equal", "less_equal"] - public static Tensor operator -(Tensor t1) => gen_math_ops.neg(t1); + %possabilities = ["NDArray", "sbyte", "byte", "short", "ushort", "int", "uint", "ulong", "long", "float", "double", "Complex"] + + %foreach operators, operators_sign% + public static Tensor operator #2(Tensor lhs, Tensor rhs) => BinaryOpWrapper("#1", lhs, rhs); + %foreach possabilities% + public static Tensor operator #2(Tensor lhs, #101 rhs) => BinaryOpWrapper("#1", lhs, rhs); + public static Tensor operator #2(#101 lhs, Tensor rhs) => BinaryOpWrapper("#1", lhs, rhs); + % + % - public static Tensor operator -(double x, Tensor y) => BinaryOpWrapper("sub", x, y); - public static Tensor operator -(float x, Tensor y) => BinaryOpWrapper("sub", x, y); - public static Tensor operator -(int x, Tensor y) => BinaryOpWrapper("sub", x, y); - public static Tensor operator -(Tensor x, Tensor y) => BinaryOpWrapper("sub", x, y); - public static Tensor operator -(Tensor x, int y) => BinaryOpWrapper("sub", x, y); - public static Tensor operator -(Tensor x, float y) => BinaryOpWrapper("sub", x, y); - public static Tensor operator -(Tensor x, double y) => BinaryOpWrapper("sub", x, y); + %foreach operators_comparers_names, operators_comparers % + public static Tensor operator #2(Tensor lhs, Tensor rhs) => gen_math_ops.#1(lhs, rhs); + %foreach possabilities% + public static Tensor operator #2(Tensor lhs, #101 rhs) => gen_math_ops.#1(lhs, rhs); + public static Tensor operator #2(#101 lhs, Tensor rhs) => gen_math_ops.#1(lhs, rhs); + % + % + public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); + #endregion +#else + #region Compute - public static Tensor operator *(float x, Tensor y) => BinaryOpWrapper("mul", x, y); - public static Tensor operator *(double x, Tensor y) => BinaryOpWrapper("mul", x, y); - public static Tensor operator *(Tensor x, Tensor y) => BinaryOpWrapper("mul", x, y); - public static Tensor operator *(Tensor x, int y) => BinaryOpWrapper("mul", x, y); - public static Tensor operator *(Tensor tensor, bool constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, sbyte constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, byte constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, ushort constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, short 
constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, uint constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, long constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, ulong constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, float constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(Tensor tensor, double constant) => BinaryOpWrapper("mul", tensor, constant); - public static Tensor operator *(bool constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator *(sbyte constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator *(byte constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator *(ushort constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator *(short constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator *(int constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator *(uint constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator *(long constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator *(ulong constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); + + public static Tensor operator +(Tensor lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, NDArray rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(NDArray lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, sbyte rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(sbyte lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, byte rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(byte lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, short rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(short lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, ushort rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(ushort lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, int rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(int lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, uint rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(uint lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, ulong rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(ulong lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, long rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(long lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, float rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(float lhs, Tensor rhs) => BinaryOpWrapper("add", 
lhs, rhs); + public static Tensor operator +(Tensor lhs, double rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(double lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Tensor lhs, Complex rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator +(Complex lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); + public static Tensor operator -(Tensor lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, NDArray rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(NDArray lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, sbyte rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(sbyte lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, byte rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(byte lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, short rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(short lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, ushort rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(ushort lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, int rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(int lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, uint rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(uint lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, ulong rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(ulong lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, long rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(long lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, float rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(float lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, double rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(double lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Tensor lhs, Complex rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator -(Complex lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs); + public static Tensor operator *(Tensor lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, NDArray rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(NDArray lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, sbyte rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(sbyte lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, byte rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(byte lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, short rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(short lhs, Tensor 
rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, ushort rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(ushort lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, int rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(int lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, uint rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(uint lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, ulong rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(ulong lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, long rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(long lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, float rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(float lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, double rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(double lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Tensor lhs, Complex rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator *(Complex lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); + public static Tensor operator /(Tensor lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, NDArray rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(NDArray lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, sbyte rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(sbyte lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, byte rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(byte lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, short rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(short lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, ushort rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(ushort lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, int rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(int lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, uint rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(uint lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, ulong rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(ulong lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, long rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(long lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, float rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(float lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator 
/(Tensor lhs, double rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(double lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Tensor lhs, Complex rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator /(Complex lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); + public static Tensor operator %(Tensor lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, NDArray rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(NDArray lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, sbyte rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(sbyte lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, byte rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(byte lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, short rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(short lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, ushort rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(ushort lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, int rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(int lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, uint rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(uint lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, ulong rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(ulong lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, long rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(long lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, float rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(float lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, double rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(double lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Tensor lhs, Complex rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator %(Complex lhs, Tensor rhs) => BinaryOpWrapper("mod", lhs, rhs); + public static Tensor operator >(Tensor lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, NDArray rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(NDArray lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, sbyte rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(sbyte lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, byte rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(byte lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, short rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(short lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static 
Tensor operator >(Tensor lhs, ushort rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(ushort lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, int rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(int lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, uint rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(uint lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, ulong rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(ulong lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, long rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(long lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, float rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(float lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, double rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(double lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Tensor lhs, Complex rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator >(Complex lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); + public static Tensor operator <(Tensor lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, NDArray rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(NDArray lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, sbyte rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(sbyte lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, byte rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(byte lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, short rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(short lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, ushort rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(ushort lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, int rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(int lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, uint rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(uint lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, ulong rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(ulong lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, long rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(long lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, float rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(float lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, double rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(double lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Tensor lhs, 
Complex rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator <(Complex lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); + public static Tensor operator >=(Tensor lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, NDArray rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(NDArray lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, sbyte rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(sbyte lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, byte rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(byte lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, short rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(short lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, ushort rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(ushort lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, int rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(int lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, uint rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(uint lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, ulong rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(ulong lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, long rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(long lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, float rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(float lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, double rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(double lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Tensor lhs, Complex rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator >=(Complex lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, NDArray rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(NDArray lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, sbyte rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(sbyte lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, byte rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(byte lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, short rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(short lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, ushort rhs) => 
gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(ushort lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, int rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(int lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, uint rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(uint lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, ulong rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(ulong lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, long rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(long lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, float rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(float lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, double rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(double lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Tensor lhs, Complex rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator <=(Complex lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); + public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); + #endregion +#endif + private static readonly TF_DataType[] _intTfDataTypes = { TF_DataType.TF_INT8, TF_DataType.TF_INT16, TF_DataType.TF_INT32, TF_DataType.TF_INT64, TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 }; - public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y); - public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y); - public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y); - public static Tensor operator /(Tensor x, Tensor y) => - _intTfDataTypes.Contains(x.dtype) - ? 
BinaryOpWrapper("floordiv", x, y) - : BinaryOpWrapper("truediv", x, y); - public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y); - public static Tensor operator /(Tensor x, float y) => BinaryOpWrapper("truediv", x, y); - public static Tensor operator /(Tensor x, double y) => BinaryOpWrapper("truediv", x, y); - - public static Tensor operator %(Tensor x, Tensor y) => BinaryOpWrapper("mod", x, y); - - public static Tensor operator >(double x, Tensor y) => gen_math_ops.greater(x, y); - public static Tensor operator >(float x, Tensor y) => gen_math_ops.greater(x, y); - public static Tensor operator >(int x, Tensor y) => gen_math_ops.greater(x, y); - public static Tensor operator >(Tensor x, Tensor y) => gen_math_ops.greater(x, y); - public static Tensor operator >(Tensor x, int y) => gen_math_ops.greater(x, y); - public static Tensor operator >(Tensor x, float y) => gen_math_ops.greater(x, y); - public static Tensor operator >(Tensor x, double y) => gen_math_ops.greater(x, y); - - public static Tensor operator <(double x, Tensor y) => gen_math_ops.less(x, y); - public static Tensor operator <(float x, Tensor y) => gen_math_ops.less(x, y); - public static Tensor operator <(int x, Tensor y) => gen_math_ops.less(x, y); - public static Tensor operator <(Tensor x, Tensor y) => gen_math_ops.less(x, y); - public static Tensor operator <(Tensor x, int y) => gen_math_ops.less(x, y); - public static Tensor operator <(Tensor x, float y) => gen_math_ops.less(x, y); - public static Tensor operator <(Tensor x, double y) => gen_math_ops.less(x, y); - - public static Tensor operator >=(double x, Tensor y) => gen_math_ops.greater_equal(x, y); - public static Tensor operator >=(float x, Tensor y) => gen_math_ops.greater_equal(x, y); - public static Tensor operator >=(int x, Tensor y) => gen_math_ops.greater_equal(x, y); - public static Tensor operator >=(Tensor x, Tensor y) => gen_math_ops.greater_equal(x, y); - public static Tensor operator >=(Tensor x, int y) => gen_math_ops.greater_equal(x, y); - public static Tensor operator >=(Tensor x, float y) => gen_math_ops.greater_equal(x, y); - public static Tensor operator >=(Tensor x, double y) => gen_math_ops.greater_equal(x, y); - - public static Tensor operator <=(int x, Tensor y) => gen_math_ops.less_equal(x, y); - public static Tensor operator <=(float x, Tensor y) => gen_math_ops.less_equal(x, y); - public static Tensor operator <=(double x, Tensor y) => gen_math_ops.less_equal(x, y); - public static Tensor operator <=(Tensor x, Tensor y) => gen_math_ops.less_equal(x, y); - public static Tensor operator <=(Tensor x, int y) => gen_math_ops.less_equal(x, y); - public static Tensor operator <=(Tensor x, float y) => gen_math_ops.less_equal(x, y); - public static Tensor operator <=(Tensor x, double y) => gen_math_ops.less_equal(x, y); - private static Tensor BinaryOpWrapper(string name, Tx x, Ty y) { TF_DataType dtype = TF_DataType.DtInvalid; + if (x is Tensor tl) dtype = tl.dtype.as_base_dtype(); if (y is Tensor tr) @@ -125,15 +301,20 @@ namespace Tensorflow return tf_with(ops.name_scope(null, name, new { x, y }), scope => { - Tensor result = null; + Tensor result; var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y"); - switch (name.ToLower()) + switch (name.ToLowerInvariant()) { case "add": result = gen_math_ops.add(x1, y1, name: scope); break; + case "div": + result = _intTfDataTypes.Contains(x1.dtype) || _intTfDataTypes.Contains(y1.dtype) + ? 
gen_math_ops.floor_div(x1, y1, name: scope) + : gen_math_ops.real_div(x1, y1, name: scope); + break; case "floordiv": result = gen_math_ops.floor_div(x1, y1, name: scope); break; @@ -150,7 +331,7 @@ namespace Tensorflow result = gen_math_ops.floor_mod(x1, y1, name: scope); break; default: - throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty)}"); + throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}"); } return result; diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 75cba69e..fb8e2457 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -28,7 +28,6 @@ using NumSharp.Backends; using NumSharp.Backends.Unmanaged; using NumSharp.Utilities; using Tensorflow.Framework; -using static Tensorflow.Binding; namespace Tensorflow { @@ -106,10 +105,13 @@ namespace Tensorflow if (_handle == IntPtr.Zero) { - var status = new Status(); - c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); - status.Check(); - } else + using (var status = new Status()) + { + c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); + status.Check(); + } + } + else { for (int i = 0; i < rank; i++) dims[i] = c_api.TF_Dim(_handle, i); @@ -120,37 +122,31 @@ namespace Tensorflow set { - var status = new Status(); + using (var status = new Status()) + { + if (value == null) + c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); + else + c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); - if (value == null) - c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); - else - c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); + status.Check(true); + } } } public int[] _shape_tuple() { - return (int[]) shape.Clone(); + return rank < 0 ? null : shape; } - public TensorShape TensorShape => tensor_util.to_shape(shape); + public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape); /// /// Updates the shape of this tensor. /// public void set_shape(TensorShape shape) { - this.shape = (int[]) shape.dims.Clone(); - } - - /// - /// Updates the shape of this tensor. - /// - [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] - public void SetShape(TensorShape shape) - { - this.shape = (int[]) shape.dims.Clone(); + this.shape = shape.rank > 0 ? shape.dims : null; } /// @@ -164,6 +160,7 @@ namespace Tensorflow /// /// number of dimensions

+ /// -1 Unknown
/// 0 Scalar (magnitude only)
/// 1 Vector (magnitude and direction)
/// 2 Matrix (table of numbers)
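
The `-1 Unknown` entry added above is not just documentation: with the accompanying changes, `rank` reports `-1` for a graph tensor whose shape the runtime cannot infer, `TensorShape` falls back to the new parameterless constructor, and `ToString()` prints an empty shape instead of throwing. A quick check, assuming a placeholder created without a shape in graph mode:

```cs
// Sketch: observing the new unknown-rank (-1) handling on a shapeless placeholder.
using System;
using Tensorflow;
using static Tensorflow.Binding;

var p = tf.placeholder(TF_DataType.TF_FLOAT);   // no shape supplied -> rank is undefined

Console.WriteLine(p.rank);                      // -1
Console.WriteLine(p.TensorShape.ndim);          // -1, via the new parameterless TensorShape()
Console.WriteLine(p);                           // e.g. tf.Tensor 'Placeholder:0' shape= dtype=TF_FLOAT
```
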

@@ -177,11 +174,13 @@ namespace Tensorflow { if (_handle == IntPtr.Zero) { - var status = new Status(); - var output = _as_tf_output(); - int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status); - status.Check(); - return ndim; + using (var status = new Status()) + { + var output = _as_tf_output(); + int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status); + status.Check(); + return ndim; + } } return c_api.TF_NumDims(_handle); @@ -260,31 +259,31 @@ namespace Tensorflow switch (dtype.as_numpy_dtype().GetTypeCode()) { %foreach supported_dtypes,supported_dtypes_lowercase% - case NPTypeCode.#1: return new T[] {Converts.ChangeType(*(#2*) buffer, NPTypeCode.#1)}; + case NPTypeCode.#1: return new T[] {Converts.ChangeType(*(#2*) buffer)}; % - case NPTypeCode.String: return new T[] {Converts.ChangeType((string)this, NPTypeCode.String)}; + case NPTypeCode.String: return new T[] {Converts.ChangeType((string)this)}; default: throw new NotSupportedException(); } #endregion #else #region Compute - switch (dtype.as_numpy_dtype()?.GetTypeCode()) + switch (dtype.as_numpy_dtype().GetTypeCode()) { - case NPTypeCode.Boolean: return new T[] {Converts.ChangeType(*(bool*) buffer, NPTypeCode.Boolean)}; - case NPTypeCode.Byte: return new T[] {Converts.ChangeType(*(byte*) buffer, NPTypeCode.Byte)}; - case NPTypeCode.Int16: return new T[] {Converts.ChangeType(*(short*) buffer, NPTypeCode.Int16)}; - case NPTypeCode.UInt16: return new T[] {Converts.ChangeType(*(ushort*) buffer, NPTypeCode.UInt16)}; - case NPTypeCode.Int32: return new T[] {Converts.ChangeType(*(int*) buffer, NPTypeCode.Int32)}; - case NPTypeCode.UInt32: return new T[] {Converts.ChangeType(*(uint*) buffer, NPTypeCode.UInt32)}; - case NPTypeCode.Int64: return new T[] {Converts.ChangeType(*(long*) buffer, NPTypeCode.Int64)}; - case NPTypeCode.UInt64: return new T[] {Converts.ChangeType(*(ulong*) buffer, NPTypeCode.UInt64)}; - case NPTypeCode.Char: return new T[] {Converts.ChangeType(*(char*) buffer, NPTypeCode.Char)}; - case NPTypeCode.Double: return new T[] {Converts.ChangeType(*(double*) buffer, NPTypeCode.Double)}; - case NPTypeCode.Single: return new T[] {Converts.ChangeType(*(float*) buffer, NPTypeCode.Single)}; - case NPTypeCode.String: return new T[] {Converts.ChangeType((string)this, NPTypeCode.String)}; + case NPTypeCode.Boolean: return new T[] {Converts.ChangeType(*(bool*) buffer)}; + case NPTypeCode.Byte: return new T[] {Converts.ChangeType(*(byte*) buffer)}; + case NPTypeCode.Int16: return new T[] {Converts.ChangeType(*(short*) buffer)}; + case NPTypeCode.UInt16: return new T[] {Converts.ChangeType(*(ushort*) buffer)}; + case NPTypeCode.Int32: return new T[] {Converts.ChangeType(*(int*) buffer)}; + case NPTypeCode.UInt32: return new T[] {Converts.ChangeType(*(uint*) buffer)}; + case NPTypeCode.Int64: return new T[] {Converts.ChangeType(*(long*) buffer)}; + case NPTypeCode.UInt64: return new T[] {Converts.ChangeType(*(ulong*) buffer)}; + case NPTypeCode.Char: return new T[] {Converts.ChangeType(*(char*) buffer)}; + case NPTypeCode.Double: return new T[] {Converts.ChangeType(*(double*) buffer)}; + case NPTypeCode.Single: return new T[] {Converts.ChangeType(*(float*) buffer)}; + case NPTypeCode.String: return new T[] {Converts.ChangeType((string)this)}; default: - throw new NotSupportedException(); + throw new NotSupportedException(); } #endregion #endif @@ -436,128 +435,49 @@ namespace Tensorflow return ops._eval_using_default_session(this, feed_dict, graph, session); } - public Tensor slice(Slice slice) - { - var 
slice_spec = new int[] {slice.Start.Value}; - var begin = new List(); - var end = new List(); - var strides = new List(); - - var index = 0; - var (new_axis_mask, shrink_axis_mask) = (0, 0); - var (begin_mask, end_mask) = (0, 0); - var ellipsis_mask = 0; - - foreach (var s in slice_spec) - { - begin.Add(s); - if (slice.Stop.HasValue) - { - end.Add(slice.Stop.Value); - } else - { - end.Add(0); - end_mask |= (1 << index); - } - - strides.Add(slice.Step); - - index += 1; - } - - return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => - { - string name = scope; - if (begin != null) - { - var (packed_begin, packed_end, packed_strides) = - (array_ops.stack(begin.ToArray()), - array_ops.stack(end.ToArray()), - array_ops.stack(strides.ToArray())); - - return gen_array_ops.strided_slice( - this, - packed_begin, - packed_end, - packed_strides, - begin_mask: begin_mask, - end_mask: end_mask, - shrink_axis_mask: shrink_axis_mask, - new_axis_mask: new_axis_mask, - ellipsis_mask: ellipsis_mask, - name: name); - } - - throw new NotImplementedException(""); - }); - } - - public Tensor slice(int start) + public override string ToString() { - var slice_spec = new int[] {start}; - var begin = new List(); - var end = new List(); - var strides = new List(); - - var index = 0; - var (new_axis_mask, shrink_axis_mask) = (0, 0); - var (begin_mask, end_mask) = (0, 0); - var ellipsis_mask = 0; - - foreach (var s in slice_spec) + // this can throw IndexOutOfRangeException + switch (rank) { - begin.Add(s); - end.Add(s + 1); - strides.Add(1); - shrink_axis_mask |= (1 << index); - index += 1; + case -1: + return $"tf.Tensor '{name}' shape= dtype={dtype}"; + case 0: + return $"tf.Tensor '{name}' shape=() dtype={dtype}"; + default: + return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; } - - return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => - { - string name = scope; - if (begin != null) - { - var (packed_begin, packed_end, packed_strides) = - (array_ops.stack(begin.ToArray()), - array_ops.stack(end.ToArray()), - array_ops.stack(strides.ToArray())); - - return gen_array_ops.strided_slice( - this, - packed_begin, - packed_end, - packed_strides, - begin_mask: begin_mask, - end_mask: end_mask, - shrink_axis_mask: shrink_axis_mask, - new_axis_mask: new_axis_mask, - ellipsis_mask: ellipsis_mask, - name: name); - } - - throw new NotImplementedException(""); - }); } - public override string ToString() + /// + /// Dispose any managed resources. 
+ /// + /// Equivalent to what you would perform inside + protected override void DisposeManagedResources() { - // this can throw IndexOutOfRangeException - //if(NDims == 0) - //{ - // switch (dtype) - // { - // case TF_DataType.TF_INT32: - // return Data()[0].ToString(); - // } - //} - - return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; + AllocationReferenceHolder = null; } + [SuppressMessage("ReSharper", "ConvertIfStatementToSwitchStatement")] protected override void DisposeUnmanagedResources(IntPtr handle) { c_api.TF_DeleteTensor(handle); + + if (AllocationHandle == null) + return; + + if (AllocationType == AllocationType.GCHandle) + { + ((GCHandle) AllocationHandle).Free(); + AllocationHandle = null; + AllocationType = AllocationType.None; + } else if (AllocationType == AllocationType.Marshal) + { + Marshal.FreeHGlobal((IntPtr) AllocationHandle); + AllocationHandle = null; + AllocationType = AllocationType.None; + } else + throw new InvalidOperationException($"Tensor.AllocationHandle is not null ({AllocationHandle}) but AllocationType is not matched to a C# allocation type ({AllocationType})."); } public bool IsDisposed => _disposed; diff --git a/src/TensorFlowNET.Core/Tensors/TensorConverter.cs b/src/TensorFlowNET.Core/Tensors/TensorConverter.cs new file mode 100644 index 00000000..dad051c6 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/TensorConverter.cs @@ -0,0 +1,285 @@ +using System; +using System.Threading.Tasks; +using NumSharp; +using NumSharp.Backends; +using NumSharp.Utilities; + +namespace Tensorflow +{ + /// + /// Provides various methods to conversion between types and . + /// + public static class TensorConverter + { + /// + /// Convert given to . + /// + /// The ndarray to convert, can be regular, jagged or multi-dim array. + /// Convert to given before inserting it into a . + /// + public static Tensor ToTensor(NDArray nd, TF_DataType? astype = null) + { + return new Tensor(astype == null ? nd : nd.astype(astype.Value.as_numpy_typecode(), false)); + } + + /// + /// Convert given to . + /// + /// The ndarray to convert. + /// Convert to given before inserting it into a . + /// + public static Tensor ToTensor(NDArray nd, NPTypeCode? astype = null) + { + return new Tensor(astype == null ? nd : nd.astype(astype.Value, false)); + } + + /// + /// Convert given to . + /// + /// The array to convert, can be regular, jagged or multi-dim array. + /// Convert to given before inserting it into a . + /// + public static Tensor ToTensor(Array array, TF_DataType? astype = null) + { + if (array == null) throw new ArgumentNullException(nameof(array)); + var arrtype = array.ResolveElementType(); + + var astype_type = astype?.as_numpy_dtype() ?? arrtype; + if (astype_type == arrtype) + { + //no conversion required + if (astype == TF_DataType.TF_STRING) + { + throw new NotSupportedException(); //TODO! when string is fully implemented. + } + + if (astype == TF_DataType.TF_INT8) + { + if (array.Rank != 1 || array.GetType().GetElementType()?.IsArray == true) //is multidim or jagged + array = Arrays.Flatten(array); + + return new Tensor((sbyte[]) array); + } + + //is multidim or jagged, if so - use NDArrays constructor as it records shape. 
+ if (array.Rank != 1 || array.GetType().GetElementType().IsArray) + return new Tensor(new NDArray(array)); + +#if _REGEN + #region Compute + switch (arrtype) + { + %foreach supported_dtypes,supported_dtypes_lowercase% + case NPTypeCode.#1: return new Tensor((#2[])arr); + % + default: + throw new NotSupportedException(); + } + #endregion +#else + + #region Compute + + switch (arrtype.GetTypeCode()) + { + case NPTypeCode.Boolean: return new Tensor((bool[]) array); + case NPTypeCode.Byte: return new Tensor((byte[]) array); + case NPTypeCode.Int16: return new Tensor((short[]) array); + case NPTypeCode.UInt16: return new Tensor((ushort[]) array); + case NPTypeCode.Int32: return new Tensor((int[]) array); + case NPTypeCode.UInt32: return new Tensor((uint[]) array); + case NPTypeCode.Int64: return new Tensor((long[]) array); + case NPTypeCode.UInt64: return new Tensor((ulong[]) array); + case NPTypeCode.Char: return new Tensor((char[]) array); + case NPTypeCode.Double: return new Tensor((double[]) array); + case NPTypeCode.Single: return new Tensor((float[]) array); + default: + throw new NotSupportedException(); + } + + #endregion + +#endif + } else + { + //conversion is required. + //by this point astype is not null. + + //flatten if required + if (array.Rank != 1 || array.GetType().GetElementType()?.IsArray == true) //is multidim or jagged + array = Arrays.Flatten(array); + + try + { + return ToTensor( + ArrayConvert.To(array, astype.Value.as_numpy_typecode()), + null + ); + } catch (NotSupportedException) + { + //handle dtypes not supported by ArrayConvert + var ret = Array.CreateInstance(astype_type, array.LongLength); + Parallel.For(0, ret.LongLength, i => ret.SetValue(Convert.ChangeType(array.GetValue(i), astype_type), i)); + return ToTensor(ret, null); + } + } + } + + /// + /// Convert given to . + /// + /// The constant scalar to convert + /// Convert to given before inserting it into a . + /// + public static Tensor ToTensor(T constant, TF_DataType? astype = null) where T : unmanaged + { + //was conversion requested? 
+ if (astype == null) + { + //No conversion required + var constantType = typeof(T).as_dtype(); + if (constantType == TF_DataType.TF_INT8) + return new Tensor((sbyte) (object) constant); + + if (constantType == TF_DataType.TF_STRING) + return new Tensor((string) (object) constant); + +#if _REGEN + #region Compute + switch (InfoOf.NPTypeCode) + { + %foreach supported_dtypes,supported_dtypes_lowercase% + case NPTypeCode.#1: return new Tensor((#2)(object)constant); + % + default: + throw new NotSupportedException(); + } + #endregion +#else + + #region Compute + + switch (InfoOf.NPTypeCode) + { + case NPTypeCode.Boolean: return new Tensor((bool) (object) constant); + case NPTypeCode.Byte: return new Tensor((byte) (object) constant); + case NPTypeCode.Int16: return new Tensor((short) (object) constant); + case NPTypeCode.UInt16: return new Tensor((ushort) (object) constant); + case NPTypeCode.Int32: return new Tensor((int) (object) constant); + case NPTypeCode.UInt32: return new Tensor((uint) (object) constant); + case NPTypeCode.Int64: return new Tensor((long) (object) constant); + case NPTypeCode.UInt64: return new Tensor((ulong) (object) constant); + case NPTypeCode.Char: return new Tensor(Converts.ToByte(constant)); + case NPTypeCode.Double: return new Tensor((double) (object) constant); + case NPTypeCode.Single: return new Tensor((float) (object) constant); + default: + throw new NotSupportedException(); + } + + #endregion +#endif + } + + //conversion required + + if (astype == TF_DataType.TF_INT8) + return new Tensor(Converts.ToSByte(constant)); + + if (astype == TF_DataType.TF_STRING) + return new Tensor(Converts.ToString(constant)); + + var astype_np = astype?.as_numpy_typecode(); + +#if _REGEN + #region Compute + switch (astype_np) + { + %foreach supported_dtypes,supported_dtypes_lowercase% + case NPTypeCode.#1: return new Tensor(Converts.To#1(constant)); + % + default: + throw new NotSupportedException(); + } + #endregion +#else + + #region Compute + switch (astype_np) + { + case NPTypeCode.Boolean: return new Tensor(Converts.ToBoolean(constant)); + case NPTypeCode.Byte: return new Tensor(Converts.ToByte(constant)); + case NPTypeCode.Int16: return new Tensor(Converts.ToInt16(constant)); + case NPTypeCode.UInt16: return new Tensor(Converts.ToUInt16(constant)); + case NPTypeCode.Int32: return new Tensor(Converts.ToInt32(constant)); + case NPTypeCode.UInt32: return new Tensor(Converts.ToUInt32(constant)); + case NPTypeCode.Int64: return new Tensor(Converts.ToInt64(constant)); + case NPTypeCode.UInt64: return new Tensor(Converts.ToUInt64(constant)); + case NPTypeCode.Char: return new Tensor(Converts.ToByte(constant)); + case NPTypeCode.Double: return new Tensor(Converts.ToDouble(constant)); + case NPTypeCode.Single: return new Tensor(Converts.ToSingle(constant)); + default: + throw new NotSupportedException(); + } + #endregion +#endif + } + + /// + /// Convert given to . + /// + /// The constant scalar to convert + /// Convert to given before inserting it into a . + /// + public static Tensor ToTensor(string constant, TF_DataType? astype = null) + { + switch (astype) + { + //was conversion requested? 
+ case null: + case TF_DataType.TF_STRING: + return new Tensor(constant); + //conversion required + case TF_DataType.TF_INT8: + return new Tensor(Converts.ToSByte(constant)); + default: + { + var astype_np = astype?.as_numpy_typecode(); + +#if _REGEN + #region Compute + switch (astype_np) + { + %foreach supported_dtypes,supported_dtypes_lowercase% + case NPTypeCode.#1: return new Tensor(Converts.To#1(constant)); + % + default: + throw new NotSupportedException(); + } + #endregion +#else + + #region Compute + switch (astype_np) + { + case NPTypeCode.Boolean: return new Tensor(Converts.ToBoolean(constant)); + case NPTypeCode.Byte: return new Tensor(Converts.ToByte(constant)); + case NPTypeCode.Int16: return new Tensor(Converts.ToInt16(constant)); + case NPTypeCode.UInt16: return new Tensor(Converts.ToUInt16(constant)); + case NPTypeCode.Int32: return new Tensor(Converts.ToInt32(constant)); + case NPTypeCode.UInt32: return new Tensor(Converts.ToUInt32(constant)); + case NPTypeCode.Int64: return new Tensor(Converts.ToInt64(constant)); + case NPTypeCode.UInt64: return new Tensor(Converts.ToUInt64(constant)); + case NPTypeCode.Char: return new Tensor(Converts.ToByte(constant)); + case NPTypeCode.Double: return new Tensor(Converts.ToDouble(constant)); + case NPTypeCode.Single: return new Tensor(Converts.ToSingle(constant)); + default: + throw new NotSupportedException(); + } + #endregion +#endif + } + } + } + + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 8c9e571e..1e239d50 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -1,8 +1,10 @@ using NumSharp; using System; +using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; +using static Tensorflow.Binding; namespace Tensorflow { @@ -22,17 +24,40 @@ namespace Tensorflow /// /// Returns the rank of this shape. /// - public int ndim => shape.NDim; + public int ndim => rank; + private int _rank; /// /// Returns the rank of this shape. /// - public int rank => shape.NDim; + public int rank => _rank > -1 ? shape.NDim : -1; /// /// Returns the size this shape represents. 
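
The `TensorConverter` added above funnels `NDArray`, `System.Array`, unmanaged scalars and strings into `Tensor`, with an optional target dtype that triggers a conversion (flattening jagged or multi-dimensional arrays first). A short usage sketch; the overloads and helper names are the ones declared in `TensorConverter.cs`, and the values are illustrative only:

```cs
// Sketch: the ToTensor entry points added in TensorConverter.cs.
using NumSharp;
using Tensorflow;

var t1 = TensorConverter.ToTensor(np.arange(6).reshape(2, 3), TF_DataType.TF_FLOAT); // NDArray, cast to float
var t2 = TensorConverter.ToTensor(new[] { 1, 2, 3 }, TF_DataType.TF_DOUBLE);         // int[] converted to double
var t3 = TensorConverter.ToTensor(3.14f);                                            // unmanaged scalar, dtype kept
var t4 = TensorConverter.ToTensor("42", TF_DataType.TF_INT32);                       // string parsed via Converts.ToInt32
```
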
/// - public int size => shape.Size; + public int size + { + get + { + var dims = shape.Dimensions; + var computed = 1; + for (int i = 0; i < dims.Length; i++) + { + var val = dims[i]; + if (val <= 0) + continue; + computed *= val; + } + + return computed; + } + } + + public TensorShape() + { + _rank = -1; + shape = new Shape(); + } public TensorShape(TensorShapeProto proto) { @@ -59,12 +84,30 @@ namespace Tensorflow switch (dims.Length) { case 0: shape = new Shape(new int[0]); break; - case 1: shape = Shape.Vector((int) dims[0]); break; + case 1: shape = Shape.Vector(dims[0]); break; case 2: shape = Shape.Matrix(dims[0], dims[1]); break; default: shape = new Shape(dims); break; } } + public TensorShape(int[][] dims) + { + if(dims.Length == 1) + { + switch (dims[0].Length) + { + case 0: shape = new Shape(new int[0]); break; + case 1: shape = Shape.Vector((int)dims[0][0]); break; + case 2: shape = Shape.Matrix(dims[0][0], dims[1][2]); break; + default: shape = new Shape(dims[0]); break; + } + } + else + { + throw new NotImplementedException("TensorShape int[][] dims"); + } + } + /// /// /// @@ -91,7 +134,7 @@ namespace Tensorflow /// public bool is_fully_defined() { - return dims != null && dims.Count(x => x < 1) == 0; + return rank > -1 && dims != null && dims.Count(x => x < 1) == 0; } public bool is_compatible_with(TensorShape shape2) @@ -108,6 +151,24 @@ namespace Tensorflow return this; } + public TensorShape with_rank(int rank) + { + return merge_with(unknown_shape(rank: rank)); + } + + /// + /// Returns an unknown TensorShape, optionally with a known rank. + /// + /// + /// + public TensorShape unknown_shape(int rank = -1) + { + if (rank == -1) + return new TensorShape(-1); + else + return new TensorShape(Enumerable.Repeat(-1, rank).ToArray()); + } + /// /// Returns the concatenation of the dimension in `self` and `other`. /// @@ -143,6 +204,37 @@ namespace Tensorflow } } + /// + /// Returns a `TensorShape` combining the information in `self` and `other`. + /// + /// + /// + public TensorShape merge_with(TensorShape other) + { + if (dims == null) + return other; + + var new_dims = new List(); + + foreach (var i in range(ndim)) + { + var dim = new Dimension(dims[i]); + var merged = dim.merge_with(new Dimension(other.dims[i])); + new_dims.Add(merged.value); + } + + return new TensorShape(new_dims.ToArray()); + } + + /// + /// Returns a cloned array from . + /// + public int[] as_list() { + if (shape.IsEmpty) + throw new ValueError("as_list() is not defined on an unknown TensorShape."); + return (int[]) dims.Clone(); + } + public override string ToString() { return shape.ToString(); @@ -155,7 +247,7 @@ namespace Tensorflow public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); public static explicit operator int(TensorShape shape) => shape.size; - public static explicit operator TensorShape(int dim) => new TensorShape(dim); + public static implicit operator TensorShape(int dim) => new TensorShape(dim); public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); @@ -171,6 +263,11 @@ namespace Tensorflow public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? 
(shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0); public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6); - + + public static explicit operator (int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 7 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6]) : (0, 0, 0, 0, 0, 0, 0); + public static implicit operator TensorShape((int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7); + + public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0); + public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8); } } diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs index 6b20b34f..be5f3932 100644 --- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; namespace Tensorflow @@ -77,6 +78,51 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len, Deallocator deallocator, ref DeallocatorArgs deallocator_arg); + /// + /// Return a new tensor that holds the bytes data[0,len-1] + /// + /// + /// + /// + /// + /// num_bytes, ex: 6 * sizeof(float) + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len, Deallocator deallocator, IntPtr deallocator_arg); + + /// + /// Return a new tensor that holds the bytes data[0,len-1] + /// + /// + /// + /// + /// + /// num_bytes, ex: 6 * sizeof(float) + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len) + { + return TF_NewTensor(dataType, dims, num_dims, data, len, EmptyDeallocator, DeallocatorArgs.Empty); + } + /// + /// Return a new tensor that holds the bytes data[0,len-1] + /// + /// + /// + /// + /// + /// num_bytes, ex: 6 * sizeof(float) + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, void* data, UIntPtr len) + { + return TF_NewTensor(dataType, dims, num_dims, new IntPtr(data), len); + } + /// /// Return the number of dimensions that the tensor has. 
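
The deallocator-free `TF_NewTensor` overloads above are what `CreateTensorFromArray` (at the top of this diff) builds on: pin the managed array, hand its address to the native side, then compare `TF_TensorData` with the pinned address to learn whether TensorFlow kept a view of the managed buffer or made its own copy (it may copy, e.g. when its alignment requirements are not met). A condensed sketch of that ownership check, using only the APIs declared in this diff:

```cs
// Sketch of the pin-then-check ownership pattern used by CreateTensorFromArray (error handling omitted).
using System;
using System.Runtime.InteropServices;
using Tensorflow;

var data = new float[] { 1f, 2f, 3f, 4f, 5f, 6f };
var dims = new long[] { 2, 3 };

var gcHandle = GCHandle.Alloc(data, GCHandleType.Pinned);
var tensor = c_api.TF_NewTensor(TF_DataType.TF_FLOAT, dims, dims.Length,
                                gcHandle.AddrOfPinnedObject(),
                                (UIntPtr)(data.Length * sizeof(float)));

// If TF copied the buffer it owns its own memory, so the pin can be released immediately;
// otherwise the GCHandle must stay alive until TF_DeleteTensor runs (Tensor.Dispose handles this).
if (c_api.TF_TensorData(tensor) != gcHandle.AddrOfPinnedObject())
    gcHandle.Free();
```
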
/// @@ -159,5 +205,32 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern unsafe UIntPtr TF_StringDecode(byte* src, UIntPtr src_len, byte** dst, UIntPtr* dst_len, IntPtr status); + + + public static c_api.Deallocator EmptyDeallocator = FreeNothingDeallocator; + + [MonoPInvokeCallback(typeof(c_api.Deallocator))] + private static void FreeNothingDeallocator(IntPtr dataPtr, IntPtr len, ref c_api.DeallocatorArgs args) + { } + + /// + /// This attribute can be applied to callback functions that will be invoked + /// from unmanaged code to managed code. + /// + /// + /// + /// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))] + /// internal static void MyFreeFunc (IntPtr data, IntPtr length){..} + /// + /// + public sealed class MonoPInvokeCallbackAttribute : Attribute + { + /// + /// Use this constructor to annotate the type of the callback function that + /// will be invoked from unmanaged code. + /// + /// T. + public MonoPInvokeCallbackAttribute(Type t) { } + } } } diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 37f1ca61..3827229d 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -16,12 +16,14 @@ using System; using System.Numerics; +using NumSharp; using NumSharp.Backends; namespace Tensorflow { public static class dtypes { + public static TF_DataType @bool = TF_DataType.TF_BOOL; public static TF_DataType int8 = TF_DataType.TF_INT8; public static TF_DataType int32 = TF_DataType.TF_INT32; public static TF_DataType int64 = TF_DataType.TF_INT64; @@ -31,6 +33,7 @@ namespace Tensorflow public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? public static TF_DataType float16 = TF_DataType.TF_HALF; public static TF_DataType float64 = TF_DataType.TF_DOUBLE; + public static TF_DataType resource = TF_DataType.TF_RESOURCE; /// /// @@ -45,6 +48,8 @@ namespace Tensorflow return typeof(bool); case TF_DataType.TF_UINT8: return typeof(byte); + case TF_DataType.TF_INT8: + return typeof(sbyte); case TF_DataType.TF_INT64: return typeof(long); case TF_DataType.TF_UINT64: @@ -243,7 +248,8 @@ namespace Tensorflow public static bool is_integer(this TF_DataType type) { return type == TF_DataType.TF_INT8 || type == TF_DataType.TF_INT16 || type == TF_DataType.TF_INT32 || type == TF_DataType.TF_INT64 || - type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 || type == TF_DataType.TF_UINT64; + type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 || type == TF_DataType.TF_UINT64 || + type == TF_DataType.DtInt32Ref || type == TF_DataType.DtInt64Ref; } public static bool is_floating(this TF_DataType type) diff --git a/src/TensorFlowNET.Core/Tensors/shape_utils.cs b/src/TensorFlowNET.Core/Tensors/shape_utils.cs new file mode 100644 index 00000000..0974dc5b --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/shape_utils.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class shape_utils + { + public static Tensor static_or_dynamic_map_fn(Func fn, Tensor elems, TF_DataType[] dtypes = null, + int parallel_iterations = 32, bool back_prop = true) + { + var outputs = tf.unstack(elems).Select(arg => fn(arg)).ToArray(); + + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs 
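// Illustrative sketch of the deallocator-free TF_NewTensor overload above. Because
// EmptyDeallocator frees nothing, the unmanaged buffer allocated here remains owned
// by the caller and must be released once the tensor is no longer needed. The helper
// name is an assumption; requires using System.Runtime.InteropServices.
static IntPtr NewFloatTensorSketch(float[] values, long[] dims)
{
    int numBytes = values.Length * sizeof(float);
    IntPtr buffer = Marshal.AllocHGlobal(numBytes);   // caller-owned, TF will not free it
    Marshal.Copy(values, 0, buffer, values.Length);
    return c_api.TF_NewTensor(TF_DataType.TF_FLOAT, dims, dims.Length,
        buffer, new UIntPtr((uint)numBytes));
}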
b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 59c107fc..142afe06 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -118,110 +118,10 @@ namespace Tensorflow if (values == null) throw new ValueError("None values not supported."); - if(np_dt == null) - { - switch (values) - { - case bool boolVal: - nparray = boolVal; - break; - case int intVal: - nparray = intVal; - break; - case int[] intVals: - nparray = np.array(intVals); - break; - case int[,] intVals: - nparray = np.array(intVals); - break; - case long intVal: - nparray = intVal; - break; - case long[] intVals: - nparray = np.array(intVals); - break; - case long[,] intVals: - nparray = np.array(intVals); - break; - case float floatVal: - nparray = floatVal; - break; - case float[] floatVals: - nparray = floatVals; - break; - case float[,] floatVals: - nparray = np.array(floatVals); - break; - case double doubleVal: - nparray = doubleVal; - break; - case double[] doubleVals: - nparray = np.array(doubleVals); - break; - case double[,] doubleVals: - nparray = np.array(doubleVals); - break; - case string strVal: - nparray = strVal; - break; - case string[] strVals: - nparray = strVals; - break; - case byte[] byteValues: - nparray = byteValues; - break; - case byte[,] byteValues: - nparray = np.array(byteValues); - break; - default: - throw new NotImplementedException($"make_tensor_proto: Support for type {values.GetType()} Not Implemented"); - } - } - else - { - // convert data type - switch (np_dt.Name) - { - case "Int32": - if (values.GetType().IsArray) - nparray = np.array((int[])values, np_dt); - else - nparray = Converts.ToInt32(values); - break; - case "Int64": - if (values.GetType().IsArray) - nparray = np.array((int[])values, np_dt); - else - nparray = Converts.ToInt64(values); - break; - case "Single": - if (values.GetType().IsArray) - nparray = np.array((float[])values, np_dt); - else - nparray = Converts.ToSingle(values); - break; - case "Double": - if (values.GetType().IsArray) - nparray = np.array((double[])values, np_dt); - else - nparray = Converts.ToDouble(values); - break; - case "String": - if (values.GetType().IsArray) - nparray = np.array((string[])values, np_dt); - else - nparray = NDArray.FromString(Converts.ToString(values)); - break; - case "Boolean": - if (values.GetType().IsArray) - nparray = np.array((bool[])values, np_dt); - else - nparray = Converts.ToBoolean(values); - break; - default: - throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented"); - } - } + nparray = convert_to_numpy_ndarray(values); + + if (np_dt != null && np_dt != typeof(string)) + nparray = nparray.astype(np_dt); } var numpy_dtype = nparray.dtype.as_dtype(dtype: dtype); @@ -316,23 +216,59 @@ namespace Tensorflow case NDArray val: nd = val; break; - case int val: - nd = np.asarray(val); + case bool boolVal: + nd = boolVal; + break; + case int intVal: + nd = intVal; + break; + case int[] intVals: + nd = np.array(intVals); + break; + case int[,] intVals: + nd = np.array(intVals); + break; + case long intVal: + nd = intVal; + break; + case long[] intVals: + nd = np.array(intVals); + break; + case long[,] intVals: + nd = np.array(intVals); + break; + case float floatVal: + nd = floatVal; + break; + case float[] floatVals: + nd = floatVals; + break; + case float[,] floatVals: + nd = np.array(floatVals); + break; + case double doubleVal: + nd = doubleVal; + break; + case double[] doubleVals: + nd = np.array(doubleVals); + break; + 
case double[,] doubleVals: + nd = np.array(doubleVals); break; - case int[] val: - nd = np.array(val); + case string strVal: + nd = NDArray.FromString(strVal); break; - case float val: - nd = np.asarray(val); + case string[] strVals: + nd = strVals; break; - case double val: - nd = np.asarray(val); + case byte[] byteValues: + nd = byteValues; break; - case string val: - nd = np.asarray(val); + case byte[,] byteValues: + nd = np.array(byteValues); break; default: - throw new Exception("Not Implemented"); + throw new NotImplementedException($"convert_to_numpy_ndarray: Support for type {values.GetType()} Not Implemented"); } return nd; diff --git a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs index d182a5cd..faf6fec2 100644 --- a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs +++ b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs @@ -41,6 +41,14 @@ namespace Tensorflow.Train _epsilon = epsilon; } + public AdamOptimizer(Tensor learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam") + : base(learning_rate, use_locking, name) + { + _beta1 = beta1; + _beta2 = beta2; + _epsilon = epsilon; + } + public override Operation _apply_sparse(IndexedSlices grad, RefVariable var) { return _apply_sparse_shared(grad.values, var, grad.indices, (x, i, v) => @@ -146,10 +154,10 @@ namespace Tensorflow.Train var beta2 = _call_if_callable(_beta2); var epsilon = _call_if_callable(_epsilon); - _lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); - _beta1_t = ops.convert_to_tensor(beta1, name: "beta1"); - _beta2_t = ops.convert_to_tensor(beta2, name: "beta2"); - _epsilon_t = ops.convert_to_tensor(epsilon, name: "epsilon"); + _lr_t = _lr_t ?? ops.convert_to_tensor(lr, name: "learning_rate"); + _beta1_t = _beta1_t ?? ops.convert_to_tensor(beta1, name: "beta1"); + _beta2_t = _beta2_t ?? ops.convert_to_tensor(beta2, name: "beta2"); + _epsilon_t = _epsilon_t ?? ops.convert_to_tensor(epsilon, name: "epsilon"); } } } diff --git a/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs b/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs index e129edce..2d4effca 100644 --- a/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs +++ b/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs @@ -13,7 +13,7 @@ namespace Tensorflow.Train bool _zero_debias; string _name; public string name => _name; - List _averages; + Dictionary _averages; public ExponentialMovingAverage(float decay, int? 
num_updates = null, bool zero_debias = false, string name = "ExponentialMovingAverage") @@ -22,7 +22,7 @@ namespace Tensorflow.Train _num_updates = num_updates; _zero_debias = zero_debias; _name = name; - _averages = new List(); + _averages = new Dictionary(); } /// @@ -37,16 +37,43 @@ namespace Tensorflow.Train foreach(var var in var_list) { - if (!_averages.Contains(var)) + if (!_averages.ContainsKey(var)) { ops.init_scope(); - var slot = new SlotCreator(); - var.initialized_value(); - // var avg = slot.create_zeros_slot + var slot_creator = new SlotCreator(); + var value = var.initialized_value(); + var avg = slot_creator.create_slot(var, + value, + name, + colocate_with_primary: true); + ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var); + _averages[var] = avg; + } + else + { + // avg = slot_creator.create_zeros_slot( + throw new NotImplementedException(""); } } - throw new NotImplementedException(""); + return tf_with(ops.name_scope(name), scope => + { + var decay = ops.convert_to_tensor(_decay, name: "decay"); + if (_num_updates.HasValue) + { + throw new NotImplementedException("ExponentialMovingAverage.apply"); + } + + var updates = new List(); + foreach (var var in var_list) + { + var zero_debias = false;// _averages[var] in zero_debias_true + var ama = moving_averages.assign_moving_average(_averages[var], var, decay, zero_debias: zero_debias); + updates.Add(ama); + } + + return control_flow_ops.group(updates.ToArray(), name: scope); + }); } } } diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs index bb8fcd7a..e0040ecf 100644 --- a/src/TensorFlowNET.Core/Train/Optimizer.cs +++ b/src/TensorFlowNET.Core/Train/Optimizer.cs @@ -62,6 +62,20 @@ namespace Tensorflow _deferred_slot_restorations = new Dictionary(); } + public Optimizer(Tensor learning_rate, bool use_locking, string name = null) + { + if (String.IsNullOrEmpty(name)) + throw new NotImplementedException("Must specify the optimizer name"); + + _name = name; + _use_locking = use_locking; + _lr_t = learning_rate; + // Dictionary of slots. 
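// Rough usage sketch for the apply(...) implementation above. Assumption: apply
// accepts the variables to track (an array here) and returns the grouped update op;
// run that op after each optimizer step so the shadow copies created via the slot
// creator trail the raw weights.
static void MaintainAveragesSketch(RefVariable[] var_list)
{
    var ema = new ExponentialMovingAverage(decay: 0.99f);
    var maintain_averages_op = ema.apply(var_list);
    // sess.run(train_op); sess.run(maintain_averages_op);   // typical loop order
}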
+ _slots = new Dictionary>(); + _non_slot_dict = new Dictionary(); + _deferred_slot_restorations = new Dictionary(); + } + /// /// Add operations to minimize `loss` by updating `var_list` /// @@ -198,7 +212,7 @@ namespace Tensorflow if (!tf.context.executing_eagerly()) { - var train_op = ops.get_collection_ref(tf.GraphKeys.TRAIN_OP) as List; + var train_op = ops.get_collection_ref(tf.GraphKeys.TRAIN_OP); if (train_op != null && train_op.Contains(apply_updates)) train_op.Add(apply_updates); } @@ -235,7 +249,9 @@ namespace Tensorflow { _maybe_initialize_trackable(); v = variable_scope.default_variable_creator( - initial_value, name: name, trainable: false, + initial_value, + name: name, + trainable: false, use_resource: resource_variable_ops.is_resource_variable( colocate_with)); @@ -357,17 +373,19 @@ namespace Tensorflow loss = _scale_loss(loss); int num_towers = 1; - - var tmp = variables.trainable_variables(); - var vars = ops.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); - switch (tmp) + if(var_list == null) { - case List values: - var_list = values.Concat(vars).ToList(); - break; - case List values: - var_list = values.Select(x => x as RefVariable).Concat(vars).ToList(); - break; + var vars = ops.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); + var tmp = variables.trainable_variables(); + switch (tmp) + { + case List values: + var_list = values.Concat(vars).ToList(); + break; + case List values: + var_list = values.Select(x => x as RefVariable).Concat(vars).ToList(); + break; + } } var_list = var_list.Concat(ops.get_collection(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); diff --git a/src/TensorFlowNET.Core/Train/QueueRunner.cs b/src/TensorFlowNET.Core/Train/QueueRunner.cs new file mode 100644 index 00000000..0a0d9c2e --- /dev/null +++ b/src/TensorFlowNET.Core/Train/QueueRunner.cs @@ -0,0 +1,36 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Queues; + +namespace Tensorflow.Train +{ + /// + /// Holds a list of enqueue operations for a queue, each to be run in a thread. 
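// Sketch of the Tensor-valued learning-rate constructor wired in above (namespace
// Tensorflow.Train; the loss tensor is assumed to come from the caller). Leaving
// var_list null lets compute_gradients fall back to the trainable and trainable
// resource variables, as implemented above.
static Operation BuildTrainOpSketch(Tensor loss)
{
    var lr = tf.constant(0.001f);                     // any Tensor works, e.g. a decayed rate
    var optimizer = new AdamOptimizer(lr, name: "Adam");
    return optimizer.minimize(loss);
}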
+ /// + public class QueueRunner + { + public QueueRunner(QueueBase queue, Operation[] enqueue_ops) + { + + } + + + } +} diff --git a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs index c9b60d32..7fe1a891 100644 --- a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs +++ b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs @@ -133,7 +133,7 @@ namespace Tensorflow var check_collection_list = graph.get_all_collection_keys(); foreach (var collection_type in check_collection_list) { - var cols = graph.get_collection(collection_type); + /*var cols = graph.get_collection(collection_type); switch (cols) { case List values: @@ -165,7 +165,7 @@ namespace Tensorflow break; default: throw new NotImplementedException("_build_internal.check_collection_list"); - } + }*/ } diff --git a/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs b/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs index 4ff538db..47f64b91 100644 --- a/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs +++ b/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using static Tensorflow.SaverDef.Types; +using static Tensorflow.Binding; namespace Tensorflow { @@ -144,5 +145,70 @@ namespace Tensorflow return prefix + ".index"; return prefix; } + + /// + /// Finds the filename of latest saved checkpoint file. + /// + /// + /// + /// + public static string latest_checkpoint(string checkpoint_dir, string latest_filename = null) + { + // Pick the latest checkpoint based on checkpoint state. + var ckpt = get_checkpoint_state(checkpoint_dir, latest_filename); + if(ckpt != null && !string.IsNullOrEmpty(ckpt.ModelCheckpointPath)) + { + // Look for either a V2 path or a V1 path, with priority for V2. + var v2_path = _prefix_to_checkpoint_path(ckpt.ModelCheckpointPath, CheckpointFormatVersion.V2); + var v1_path = _prefix_to_checkpoint_path(ckpt.ModelCheckpointPath, CheckpointFormatVersion.V1); + if (File.Exists(v2_path) || File.Exists(v1_path)) + return ckpt.ModelCheckpointPath; + else + throw new ValueError($"Couldn't match files for checkpoint {ckpt.ModelCheckpointPath}"); + } + return null; + } + + public static CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null) + { + var coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, latest_filename); + if (File.Exists(coord_checkpoint_filename)) + { + var file_content = File.ReadAllLines(coord_checkpoint_filename); + // https://github.com/protocolbuffers/protobuf/issues/6654 + // var ckpt = CheckpointState.Parser.ParseFrom(file_content); + var ckpt = new CheckpointState(); + var field = CheckpointState.Descriptor.FindFieldByName("model_checkpoint_path"); + ckpt.ModelCheckpointPath = file_content.FirstOrDefault(x => x.StartsWith(field.Name + ":")).Substring(field.Name.Length + 2); + // remove first and last quote. 
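// For reference, the hand-rolled parser above reads the plain-text `checkpoint`
// file that TensorFlow writes next to the checkpoint data, e.g.:
//
//   model_checkpoint_path: "model.ckpt-2000"
//   all_model_checkpoint_paths: "model.ckpt-1000"
//   all_model_checkpoint_paths: "model.ckpt-2000"
//
// Quotes are stripped and relative paths are rebased onto checkpoint_dir below.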
+ ckpt.ModelCheckpointPath = ckpt.ModelCheckpointPath.Substring(1, ckpt.ModelCheckpointPath.Length - 2); + + field = CheckpointState.Descriptor.FindFieldByName("all_model_checkpoint_paths"); + file_content.Where(x => x.StartsWith(field.Name + ":")) + .ToList() + .ForEach(x => + { + string value = x.Substring(field.Name.Length + 2); + ckpt.AllModelCheckpointPaths.Add(value.Substring(1, value.Length - 2)); + }); + + if (string.IsNullOrEmpty(ckpt.ModelCheckpointPath)) + throw new ValueError($"Invalid checkpoint state loaded from {checkpoint_dir}"); + // For relative model_checkpoint_path and all_model_checkpoint_paths, + // prepend checkpoint_dir. + if (!Path.IsPathRooted(ckpt.ModelCheckpointPath)) + ckpt.ModelCheckpointPath = Path.Combine(checkpoint_dir, ckpt.ModelCheckpointPath); + foreach(var i in range(len(ckpt.AllModelCheckpointPaths))) + { + var p = ckpt.AllModelCheckpointPaths[i]; + if (!Path.IsPathRooted(p)) + ckpt.AllModelCheckpointPaths[i] = Path.Combine(checkpoint_dir, p); + } + + return ckpt; + } + + return null; + } } } diff --git a/src/TensorFlowNET.Core/Train/SessionRunArgs.cs b/src/TensorFlowNET.Core/Train/SessionRunArgs.cs new file mode 100644 index 00000000..00d473e1 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/SessionRunArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Train +{ + public class SessionRunArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Train/SessionRunContext.cs b/src/TensorFlowNET.Core/Train/SessionRunContext.cs new file mode 100644 index 00000000..cf8bdc05 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/SessionRunContext.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Train +{ + public class SessionRunContext + { + SessionRunArgs _original_args; + public SessionRunArgs original_args => _original_args; + + Session _session; + public Session session => _session; + + bool _stop_requested; + public bool stop_requested => _stop_requested; + + public SessionRunContext(SessionRunArgs original_args, Session session) + { + _original_args = original_args; + _session = session; + _stop_requested = false; + } + + public void request_stop() + { + _stop_requested = true; + } + } +} diff --git a/src/TensorFlowNET.Core/Train/SessionRunValues.cs b/src/TensorFlowNET.Core/Train/SessionRunValues.cs new file mode 100644 index 00000000..655c3310 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/SessionRunValues.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Train +{ + public class SessionRunValues + { + } +} diff --git a/src/TensorFlowNET.Core/Train/SlotCreator.cs b/src/TensorFlowNET.Core/Train/SlotCreator.cs index 29e073c7..1334b4bd 100644 --- a/src/TensorFlowNET.Core/Train/SlotCreator.cs +++ b/src/TensorFlowNET.Core/Train/SlotCreator.cs @@ -22,6 +22,24 @@ namespace Tensorflow.Train { public class SlotCreator { + /// + /// Create a slot initialized to the given value. + /// + /// + /// + /// + /// + /// + public RefVariable create_slot(RefVariable primary, Tensor val, string name, bool colocate_with_primary = true) + { + var validate_shape = val.TensorShape.is_fully_defined(); + var prefix = primary.op.name; + return tf_with(tf.variable_scope(name: null, prefix + "/" + name), delegate + { + return _create_slot_var(primary, val, "", validate_shape, null, TF_DataType.DtInvalid); + }); + } + /// /// Create a slot initialized to 0 with same shape as the primary object. 
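// Usage sketch for latest_checkpoint above. Assumptions: the enclosing static class
// is named checkpoint_management after its file, and the existing Saver.restore(sess,
// path) API is used to load the weights.
static void RestoreLatestSketch(Saver saver, Session sess)
{
    var ckpt = checkpoint_management.latest_checkpoint("./checkpoints");
    if (ckpt != null)                 // null means no checkpoint has been written yet
        saver.restore(sess, ckpt);
}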
/// @@ -73,7 +91,7 @@ namespace Tensorflow.Train /// /// /// - private RefVariable _create_slot_var(VariableV1 primary, IInitializer val, string scope, bool validate_shape, + private RefVariable _create_slot_var(VariableV1 primary, object val, string scope, bool validate_shape, TensorShape shape, TF_DataType dtype) { bool use_resource = primary is ResourceVariable; diff --git a/src/TensorFlowNET.Core/Train/TrainingUtil.cs b/src/TensorFlowNET.Core/Train/TrainingUtil.cs new file mode 100644 index 00000000..63227733 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/TrainingUtil.cs @@ -0,0 +1,89 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Train +{ + public class TrainingUtil + { + public static RefVariable create_global_step(Graph graph = null) + { + graph = graph ?? ops.get_default_graph(); + if (get_global_step(graph) != null) + throw new ValueError("global_step already exists."); + + // Create in proper graph and base name_scope. + var g = graph.as_default(); + g.name_scope(null); + var v = tf.get_variable(tf.GraphKeys.GLOBAL_STEP, new TensorShape(), dtype: dtypes.int64, + initializer: tf.zeros_initializer, + trainable: false, + aggregation: VariableAggregation.OnlyFirstReplica, + collections: new List { tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP }); + return v; + } + + public static RefVariable get_global_step(Graph graph = null) + { + graph = graph ?? ops.get_default_graph(); + RefVariable global_step_tensor = null; + var global_step_tensors = graph.get_collection(tf.GraphKeys.GLOBAL_STEP); + if (global_step_tensors.Count == 1) + { + global_step_tensor = global_step_tensors[0]; + } + else + { + try + { + global_step_tensor = graph.get_tensor_by_name("global_step:0"); + } + catch (KeyError) + { + return null; + } + } + + return global_step_tensor; + } + + public static Tensor _get_or_create_global_step_read(Graph graph = null) + { + graph = graph ?? ops.get_default_graph(); + var global_step_read_tensor = _get_global_step_read(graph); + if (global_step_read_tensor != null) + return global_step_read_tensor; + + var global_step_tensor = get_global_step(graph); + + if (global_step_tensor == null) + return null; + + var g = graph.as_default(); + g.name_scope(null); + g.name_scope(global_step_tensor.op.name + "/"); + // using initialized_value to ensure that global_step is initialized before + // this run. This is needed for example Estimator makes all model_fn build + // under global_step_read_tensor dependency. + var global_step_value = global_step_tensor.initialized_value(); + ops.add_to_collection(tf.GraphKeys.GLOBAL_STEP_READ_KEY, global_step_value + 0); + + return _get_global_step_read(graph); + } + + private static Tensor _get_global_step_read(Graph graph = null) + { + graph = graph ?? ops.get_default_graph(); + var global_step_read_tensors = graph.get_collection(tf.GraphKeys.GLOBAL_STEP_READ_KEY); + if (global_step_read_tensors.Count > 1) + throw new RuntimeError($"There are multiple items in collection {tf.GraphKeys.GLOBAL_STEP_READ_KEY}. 
" + + "There should be only one."); + + if (global_step_read_tensors.Count == 1) + return global_step_read_tensors[0]; + + return null; + } + } +} diff --git a/src/TensorFlowNET.Core/Train/_MonitoredSession.cs b/src/TensorFlowNET.Core/Train/_MonitoredSession.cs new file mode 100644 index 00000000..e89b1b89 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/_MonitoredSession.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Train +{ + internal class _MonitoredSession + { + } +} diff --git a/src/TensorFlowNET.Core/Train/moving_averages.cs b/src/TensorFlowNET.Core/Train/moving_averages.cs new file mode 100644 index 00000000..de4e7f2e --- /dev/null +++ b/src/TensorFlowNET.Core/Train/moving_averages.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Train +{ + public class moving_averages + { + /// + /// Compute the moving average of a variable. + /// + /// + /// + /// + /// + /// + /// + public static Tensor assign_moving_average(RefVariable variable, RefVariable value, Tensor decay, + bool zero_debias = true, string name = null) + { + return tf_with(ops.name_scope(name, "AssignMovingAvg", new { variable, value, decay }), scope => + { + decay = ops.convert_to_tensor(1.0f - decay, name: "decay"); + if (decay.dtype != variable.dtype.as_base_dtype()) + decay = math_ops.cast(decay, variable.dtype.as_base_dtype()); + + return state_ops.assign_sub(variable, (variable - value) * decay, name: scope); + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Util/Locks.cs b/src/TensorFlowNET.Core/Util/Locks.cs new file mode 100644 index 00000000..3b54ee2c --- /dev/null +++ b/src/TensorFlowNET.Core/Util/Locks.cs @@ -0,0 +1,21 @@ +using System.Threading; + +namespace Tensorflow.Util +{ + /// + /// Provides a set of locks on different shared levels. + /// + public static class Locks + { + private static readonly ThreadLocal _lockpool = new ThreadLocal(() => new object()); + + /// + /// A seperate lock for every requesting thread. + /// + /// This property is thread-safe. + public static object ThreadWide => _lockpool.Value; + + + public static readonly object ProcessWide = new object(); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs index b4c77226..79d7dd5f 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.Operators.cs @@ -14,6 +14,7 @@ limitations under the License. 
******************************************************************************/ +using System; using static Tensorflow.Binding; namespace Tensorflow @@ -29,12 +30,29 @@ namespace Tensorflow public static Tensor operator -(RefVariable x, double y) => op_helper("sub", x, y); public static Tensor operator -(RefVariable x, Tensor y) => op_helper("sub", x, y); + public static Tensor operator <(RefVariable x, Tensor y) => gen_math_ops.less(x.value(), y); + + public static Tensor operator >(RefVariable x, Tensor y) => gen_math_ops.greater(x.value(), y); + private static Tensor op_helper(string default_name, RefVariable x, T y) { - var tensor1 = x.value(); - return tf_with(ops.name_scope(null, default_name, new { tensor1, y }), scope => { - var tensor2 = ops.convert_to_tensor(y, tensor1.dtype.as_base_dtype(), "y"); - return gen_math_ops.add(tensor1, tensor2, scope); + var xVal = x.value(); + return tf_with(ops.name_scope(null, default_name, new { xVal, y }), scope => { + string name = scope; + var yTensor = ops.convert_to_tensor(y, xVal.dtype.as_base_dtype(), "y"); + Tensor result = null; + switch (default_name) + { + case "add": + result = gen_math_ops.add(xVal, yTensor, name); + break; + case "sub": + result = gen_math_ops.sub(xVal, yTensor, name); + break; + default: + throw new NotImplementedException(""); + } + return result; }); } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index e0e3e0f7..97e1d0f4 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -17,6 +17,7 @@ using Google.Protobuf; using System; using System.Collections.Generic; +using System.Linq; using static Tensorflow.Binding; namespace Tensorflow @@ -176,7 +177,7 @@ namespace Tensorflow // If 'initial_value' makes use of other variables, make sure we don't // have an issue if these other variables aren't initialized first by // using their initialized_value() method. - var _initial_value2 = _try_guard_against_uninitialized_dependencies(_initial_value); + var _initial_value2 = _try_guard_against_uninitialized_dependencies(name, _initial_value); _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op; @@ -215,9 +216,9 @@ namespace Tensorflow /// Attempt to guard against dependencies on uninitialized variables. /// /// - private Tensor _try_guard_against_uninitialized_dependencies(Tensor initial_value) + private Tensor _try_guard_against_uninitialized_dependencies(string name, Tensor initial_value) { - return _safe_initial_value_from_tensor(initial_value, new Dictionary()); + return _safe_initial_value_from_tensor(name, initial_value, op_cache: new Dictionary()); } /// @@ -226,19 +227,19 @@ namespace Tensorflow /// A `Tensor`. The tensor to replace. /// A dict mapping operation names to `Operation`s. /// A `Tensor` compatible with `tensor`. - private Tensor _safe_initial_value_from_tensor(Tensor tensor, Dictionary op_cache) + private Tensor _safe_initial_value_from_tensor(string name, Tensor tensor, Dictionary op_cache) { var op = tensor.op; var new_op = op_cache.ContainsKey(op.name) ? 
op_cache[op.name] : null; if(new_op == null) { - new_op = _safe_initial_value_from_op(op, op_cache); + new_op = _safe_initial_value_from_op(name, op, op_cache); op_cache[op.name] = new_op; } return new_op.outputs[tensor.value_index]; } - private Operation _safe_initial_value_from_op(Operation op, Dictionary op_cache) + private Operation _safe_initial_value_from_op(string name, Operation op, Dictionary op_cache) { var op_type = op.node_def.Op; switch (op_type) @@ -250,13 +251,54 @@ namespace Tensorflow case "Variable": case "VariableV2": case "VarHandleOp": - break; + var initialized_value = _find_initialized_value_for_variable(op); + return initialized_value == null ? op : initialized_value.op; } // Recursively build initializer expressions for inputs. + var modified = false; + var new_op_inputs = new List(); + foreach (var op_input in op.inputs) + { + var new_op_input = _safe_initial_value_from_tensor(name, op_input as Tensor, op_cache); + new_op_inputs.Add(new_op_input); + modified = modified || new_op_input != op_input; + } + + // If at least one input was modified, replace the op. + if (modified) + { + var new_op_type = op_type; + if (new_op_type == "RefSwitch") + new_op_type = "Switch"; + var new_op_name = op.node_def.Name + "_" + name; + new_op_name = new_op_name.Replace(":", "_"); + + // Convert attr values to AttrValue protos. + var attr_protos = new Dictionary(); + foreach (var attr_def in op.node_def.Attr) + attr_protos[attr_def.Key] = attr_def.Value; + + return op.graph.create_op(new_op_type, new_op_inputs.ToArray(), op._output_types, + name: new_op_name, attrs: attr_protos); + } return op; } + private Operation _find_initialized_value_for_variable(Operation variable_op) + { + var var_names = new[] { variable_op.node_def.Name, variable_op.node_def.Name + ":0" }; + foreach(var collection_name in new[]{tf.GraphKeys.GLOBAL_VARIABLES, + tf.GraphKeys.LOCAL_VARIABLES }) + { + foreach (var var in variable_op.graph.get_collection(collection_name)) + if (var_names.Contains(var.name)) + return var.initialized_value(); + } + + return null; + } + /// /// Assigns a new value to the variable. /// @@ -318,6 +360,15 @@ namespace Tensorflow return array_ops.identity(_variable, name: "read"); } + /// + /// Returns the Tensor used as the initial value for the variable. + /// + /// + private ITensorOrOperation initial_value() + { + return _initial_value; + } + public Tensor is_variable_initialized(RefVariable variable) { return state_ops.is_variable_initialized(variable); @@ -326,10 +377,9 @@ namespace Tensorflow public Tensor initialized_value() { ops.init_scope(); - throw new NotImplementedException(""); - /*return control_flow_ops.cond(is_variable_initialized(this), + return control_flow_ops.cond(is_variable_initialized(this), read_value, - () => initial_value);*/ + initial_value); } } } diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs index afc221c8..ad7750a1 100644 --- a/src/TensorFlowNET.Core/Variables/VariableScope.cs +++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using System.Collections.Generic; using static Tensorflow.Binding; namespace Tensorflow @@ -50,6 +51,7 @@ namespace Tensorflow TF_DataType dtype = TF_DataType.DtInvalid, object initializer = null, // IInitializer or Tensor bool? trainable = null, + List collections = null, bool? 
use_resource = null, bool validate_shape = true, VariableSynchronization synchronization = VariableSynchronization.Auto, @@ -67,6 +69,7 @@ namespace Tensorflow initializer: initializer, reuse: resue, trainable: trainable, + collections: collections, synchronization: synchronization, aggregation: aggregation); }); diff --git a/src/TensorFlowNET.Core/Variables/_VariableStore.cs b/src/TensorFlowNET.Core/Variables/_VariableStore.cs index 8957568e..d0fbf161 100644 --- a/src/TensorFlowNET.Core/Variables/_VariableStore.cs +++ b/src/TensorFlowNET.Core/Variables/_VariableStore.cs @@ -42,6 +42,7 @@ namespace Tensorflow object initializer = null, // IInitializer or Tensor bool? reuse = null, bool? trainable = null, + List collections = null, bool validate_shape = true, VariableSynchronization synchronization = VariableSynchronization.Auto, VariableAggregation aggregation = VariableAggregation.None) @@ -54,6 +55,7 @@ namespace Tensorflow dtype: dtype, initializer: initializer, trainable: trainable, + collections: collections, validate_shape: validate_shape, synchronization: synchronization, aggregation: aggregation); @@ -64,6 +66,7 @@ namespace Tensorflow TF_DataType dtype = TF_DataType.TF_FLOAT, object initializer = null, bool? trainable = null, + List collections = null, bool validate_shape = true, VariableSynchronization synchronization = VariableSynchronization.Auto, VariableAggregation aggregation = VariableAggregation.None) @@ -77,6 +80,7 @@ namespace Tensorflow dtype: dtype, initializer: init, trainable: trainable, + collections: collections, validate_shape: validate_shape, synchronization: synchronization, aggregation: aggregation); @@ -112,6 +116,7 @@ namespace Tensorflow IInitializer initializer = null, bool reuse = false, bool? trainable = null, + List collections = null, bool validate_shape = false, bool? use_resource = null, VariableSynchronization synchronization = VariableSynchronization.Auto, @@ -157,6 +162,7 @@ namespace Tensorflow v = variable_scope.default_variable_creator(init_val, name: name, trainable: trainable, + collections: collections, dtype: variable_dtype, validate_shape: validate_shape, synchronization: synchronization, diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index 5c8744b6..9c006170 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -126,7 +126,7 @@ namespace Tensorflow // name: A name for the operation(optional). // Returns: // A mutable `Tensor`. Has the same type as `ref`. - public static Tensor assign_add(RefVariable @ref, Tensor value, bool use_locking = false, string name = null) + public static Tensor assign_add(RefVariable @ref, T value, bool use_locking = false, string name = null) { var _op = _op_def_lib._apply_op_helper("AssignAdd", name: name, args: new { @ref, value, use_locking }); return _op.outputs[0]; @@ -149,7 +149,8 @@ namespace Tensorflow public static Tensor is_variable_initialized(RefVariable @ref, string name = null) { - throw new NotImplementedException(""); + var _op = _op_def_lib._apply_op_helper("IsVariableInitialized", name: name, args: new { @ref }); + return _op.output; } } } diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs index 8f478f2d..cd8d4f3f 100644 --- a/src/TensorFlowNET.Core/Variables/state_ops.cs +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -94,10 +94,15 @@ namespace Tensorflow // Returns: // Same as "ref". 
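// Sketch of the `collections` parameter now threaded through get_variable (mirrors
// the TrainingUtil.create_global_step call above). The extra collection name
// "my_counters" is purely illustrative; requires System.Collections.Generic and
// static Tensorflow.Binding.
static RefVariable CreateCounterSketch()
{
    return tf.get_variable("counter",
        new TensorShape(),
        dtype: dtypes.int64,
        initializer: tf.zeros_initializer,
        trainable: false,
        collections: new List<string> { tf.GraphKeys.GLOBAL_VARIABLES, "my_counters" });
}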
Returned as a convenience for operations that want // to use the new value after the variable has been updated. - public static Tensor assign_add(RefVariable @ref, - Tensor value, + public static Tensor assign_add(RefVariable @ref, + T value, bool use_locking = false, - string name = null) => gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name); + string name = null) + { + if (@ref.dtype.is_ref_dtype()) + return gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name); + throw new NotImplementedException("assign_add"); + } public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null) { diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index 6bc83052..4f357b12 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -175,6 +175,7 @@ namespace Tensorflow public static RefVariable default_variable_creator(object initial_value, string name = null, bool? trainable = null, + List collections = null, TF_DataType dtype = TF_DataType.DtInvalid, bool validate_shape = false, bool ? use_resource = null, @@ -199,6 +200,7 @@ namespace Tensorflow return new RefVariable(initial_value, trainable: trainable.Value, validate_shape: validate_shape, + collections: collections, name: name, dtype: dtype); } diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs index c5a06433..4e7235bc 100644 --- a/src/TensorFlowNET.Core/ops.GraphKeys.cs +++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs @@ -52,6 +52,8 @@ namespace Tensorflow /// public const string LOSSES_ = "losses"; + public const string MOVING_AVERAGE_VARIABLES = "moving_average_variables"; + /// /// Key to collect Variable objects that are global (shared across machines). /// Default collection for all variables, except local ones. @@ -100,6 +102,12 @@ namespace Tensorflow /// public string _STREAMING_MODEL_PORTS => _STREAMING_MODEL_PORTS_; + /// + /// Key to collect local variables that are local to the machine and are not + /// saved/restored. 
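// Sketch of the generic assign_add path above: with a ref-dtype variable the raw
// value is handed to gen_state_ops, where the op helper wraps it, so a plain int
// literal is enough. `global_step` stands for any int64 RefVariable, e.g. the one
// created by TrainingUtil.create_global_step.
static Tensor IncrementStepSketch(RefVariable global_step)
{
    return state_ops.assign_add(global_step, 1, name: "increment_global_step");
}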
+ /// + public string LOCAL_VARIABLES = "local_variables"; + /// /// Key to collect losses /// @@ -114,6 +122,7 @@ namespace Tensorflow public string TRAIN_OP => TRAIN_OP_; public string GLOBAL_STEP => GLOBAL_STEP_; + public string GLOBAL_STEP_READ_KEY = "global_step_read_op_cache"; public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_; /// diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 1dc8eb56..846de1ea 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -19,13 +19,19 @@ using System.Collections.Generic; using System.Runtime.InteropServices; using Google.Protobuf; using System.Linq; +using System.Threading; using NumSharp; +using Tensorflow.Util; using static Tensorflow.Binding; namespace Tensorflow { public partial class ops { + private static readonly ThreadLocal _defaultGraphFactory = new ThreadLocal(() => new DefaultGraphStack()); + + public static DefaultGraphStack default_graph_stack => _defaultGraphFactory.Value; + public static int tensor_id(Tensor tensor) { return tensor.Id; @@ -67,13 +73,11 @@ namespace Tensorflow return get_default_graph().get_collection(key, scope); } - public static object get_collection_ref(string key) + public static List get_collection_ref(string key) { - return get_default_graph().get_collection_ref(key); + return get_default_graph().get_collection_ref(key); } - public static DefaultGraphStack default_graph_stack = new DefaultGraphStack(); - /// /// Returns the default graph for the current thread. /// @@ -93,6 +97,7 @@ namespace Tensorflow //return _default_graph_stack.get_default() return default_graph_stack.get_controller(); } + public static Graph set_default_graph(Graph graph) { //TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! @@ -201,49 +206,52 @@ namespace Tensorflow /// /// A list of `Operation`s to set as control dependencies. /// A wrapped TF_Operation*. 
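// With default_graph_stack now being ThreadLocal, each thread resolves its own
// default graph, so concurrent graph construction does not race on shared state.
// A minimal sketch (the method name is illustrative):
static void PerThreadGraphSketch()
{
    System.Threading.Tasks.Parallel.For(0, 2, _ =>
    {
        var g = ops.get_default_graph();   // per-thread graph stack
        // ... build ops against g ...
    });
}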
- public static (IntPtr, IntPtr) _create_c_op(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) + public static IntPtr _create_c_op(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) { - var op_desc = graph.NewOperation(node_def.Op, node_def.Name); - - //TODO: Implement TF_SetDevice - //if node_def.device: - // c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device)) - // Add inputs - foreach (var op_input in inputs) + lock (Locks.ProcessWide) { - if (op_input is Tensor[] op_inputs) - c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); - else if (op_input is Tensor op_input1) + var op_desc = graph.NewOperation(node_def.Op, node_def.Name); + + //TODO: Implement TF_SetDevice + //if node_def.device: + // c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device)) + // Add inputs + foreach (var op_input in inputs) { - c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); + if (op_input is Tensor[] op_inputs) + c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); + else if (op_input is Tensor op_input1) + { + c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); + } else + throw new NotImplementedException("_create_c_op"); } - else - throw new NotImplementedException("_create_c_op"); - } - - var status = new Status(); - // Add control inputs - foreach (var control_input in control_inputs) - c_api.TF_AddControlInput(op_desc, control_input); - - // Add attrs - foreach (var attr in node_def.Attr) - { - var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. - var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak - Marshal.Copy(bytes, 0, proto, bytes.Length); - uint len = (uint)bytes.Length; - c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); - - status.Check(true); + using (var status = new Status()) + { + // Add control inputs + foreach (var control_input in control_inputs) + c_api.TF_AddControlInput(op_desc, control_input); + + // Add attrs + foreach (var attr in node_def.Attr) + { + var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. + var protoHandle = Marshal.AllocHGlobal(bytes.Length); + Marshal.Copy(bytes, 0, protoHandle, bytes.Length); + uint len = (uint)bytes.Length; + c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status); + status.Check(true); + Marshal.FreeHGlobal(protoHandle); + } + + var c_op = c_api.TF_FinishOperation(op_desc, status); + + status.Check(true); + + return c_op; + } } - - var c_op = c_api.TF_FinishOperation(op_desc, status); - - status.Check(true); - - return (c_op, op_desc); } public static OpDef _get_op_def(Graph graph, string type) @@ -311,7 +319,7 @@ namespace Tensorflow /// public static int uid() { - return uid_number++; + return Interlocked.Increment(ref uid_number); } public static void colocate_with(bool ignore_existing = false) @@ -386,8 +394,6 @@ namespace Tensorflow /// The default `Session` being used in the current thread. 
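// Behavioural notes: _create_c_op above wraps the native op-builder calls in
// lock (Locks.ProcessWide) so only one thread talks to the TF_OperationDescription
// at a time, and the marshalled attribute buffer is freed right after
// TF_SetAttrValueProto; uid() hands out ids via Interlocked.Increment.
// get_default_session() below returns the per-thread tf.defaultSession and no longer
// constructs a session implicitly, so callers create one via tf.Session(), which also
// registers itself through as_default().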
public static Session get_default_session() { - if (tf.defaultSession == null) - tf.defaultSession = tf.Session(); return tf.defaultSession; } @@ -500,6 +506,8 @@ namespace Tensorflow return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref); case ResourceVariable varVal: return null; + case TensorShape ts: + return constant_op.constant(ts.dims, dtype: dtype, name: name); case object[] objects: return array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name); default: diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index ca903844..cf973864 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -14,12 +14,15 @@ limitations under the License. ******************************************************************************/ +using System.Threading; using Tensorflow.Eager; namespace Tensorflow { public partial class tensorflow : IObjectLife { + protected internal readonly ThreadLocal _defaultSessionFactory; + public TF_DataType @byte = TF_DataType.TF_UINT8; public TF_DataType @sbyte = TF_DataType.TF_INT8; public TF_DataType int16 = TF_DataType.TF_INT16; @@ -34,7 +37,13 @@ namespace Tensorflow public Context context = new Context(new ContextOptions(), new Status()); - public Session defaultSession; + + public tensorflow() + { + _defaultSessionFactory = new ThreadLocal(() => new Session()); + } + + public Session defaultSession => _defaultSessionFactory.Value; public RefVariable Variable(T data, bool trainable = true, @@ -64,17 +73,17 @@ namespace Tensorflow public Session Session() { - return new Session(); + return new Session().as_default(); } public Session Session(Graph graph, SessionOptions opts = null) { - return new Session(graph, opts: opts); + return new Session(graph, opts: opts).as_default(); } public Session Session(SessionOptions opts) { - return new Session(null, opts); + return new Session(null, opts).as_default(); } public void __init__() diff --git a/src/TensorFlowDatasets/DatasetBuilder.cs b/src/TensorFlowNET.Datasets/DatasetBuilder.cs similarity index 100% rename from src/TensorFlowDatasets/DatasetBuilder.cs rename to src/TensorFlowNET.Datasets/DatasetBuilder.cs diff --git a/src/TensorFlowDatasets/DownloadConfig.cs b/src/TensorFlowNET.Datasets/DownloadConfig.cs similarity index 100% rename from src/TensorFlowDatasets/DownloadConfig.cs rename to src/TensorFlowNET.Datasets/DownloadConfig.cs diff --git a/src/TensorFlowDatasets/TensorFlowDatasets.csproj b/src/TensorFlowNET.Datasets/TensorFlowNET.Datasets.csproj similarity index 88% rename from src/TensorFlowDatasets/TensorFlowDatasets.csproj rename to src/TensorFlowNET.Datasets/TensorFlowNET.Datasets.csproj index 1b839c1f..198c3e12 100644 --- a/src/TensorFlowDatasets/TensorFlowDatasets.csproj +++ b/src/TensorFlowNET.Datasets/TensorFlowNET.Datasets.csproj @@ -14,6 +14,8 @@ git SciSharp, Dataset, TensorFlow Apache 2.0 + TensorFlow.Datasets + TensorFlow.Datasets diff --git a/src/TensorFlowHub/DataSetBase.cs b/src/TensorFlowNET.Hub/DataSetBase.cs similarity index 100% rename from src/TensorFlowHub/DataSetBase.cs rename to src/TensorFlowNET.Hub/DataSetBase.cs diff --git a/src/TensorFlowHub/Datasets.cs b/src/TensorFlowNET.Hub/Datasets.cs similarity index 100% rename from src/TensorFlowHub/Datasets.cs rename to src/TensorFlowNET.Hub/Datasets.cs diff --git a/src/TensorFlowHub/IDataSet.cs b/src/TensorFlowNET.Hub/IDataSet.cs similarity index 100% rename from src/TensorFlowHub/IDataSet.cs rename to 
src/TensorFlowNET.Hub/IDataSet.cs diff --git a/src/TensorFlowHub/IModelLoader.cs b/src/TensorFlowNET.Hub/IModelLoader.cs similarity index 100% rename from src/TensorFlowHub/IModelLoader.cs rename to src/TensorFlowNET.Hub/IModelLoader.cs diff --git a/src/TensorFlowHub/MnistDataSet.cs b/src/TensorFlowNET.Hub/MnistDataSet.cs similarity index 100% rename from src/TensorFlowHub/MnistDataSet.cs rename to src/TensorFlowNET.Hub/MnistDataSet.cs diff --git a/src/TensorFlowHub/MnistModelLoader.cs b/src/TensorFlowNET.Hub/MnistModelLoader.cs similarity index 100% rename from src/TensorFlowHub/MnistModelLoader.cs rename to src/TensorFlowNET.Hub/MnistModelLoader.cs diff --git a/src/TensorFlowHub/ModelLoadSetting.cs b/src/TensorFlowNET.Hub/ModelLoadSetting.cs similarity index 100% rename from src/TensorFlowHub/ModelLoadSetting.cs rename to src/TensorFlowNET.Hub/ModelLoadSetting.cs diff --git a/src/TensorFlowHub/README.md b/src/TensorFlowNET.Hub/README.md similarity index 100% rename from src/TensorFlowHub/README.md rename to src/TensorFlowNET.Hub/README.md diff --git a/src/TensorFlowHub/TensorFlowHub.csproj b/src/TensorFlowNET.Hub/TensorFlowNET.Hub.csproj similarity index 86% rename from src/TensorFlowHub/TensorFlowHub.csproj rename to src/TensorFlowNET.Hub/TensorFlowNET.Hub.csproj index 16e22183..27b5128b 100644 --- a/src/TensorFlowHub/TensorFlowHub.csproj +++ b/src/TensorFlowNET.Hub/TensorFlowNET.Hub.csproj @@ -2,7 +2,7 @@ Tensorflow.Hub netstandard2.0 - 0.0.2 + 0.0.3 Kerry Jiang SciSharp STACK Apache 2.0 @@ -15,8 +15,9 @@ true https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 + TensorFlow.Hub - + \ No newline at end of file diff --git a/src/TensorFlowHub/Utils.cs b/src/TensorFlowNET.Hub/Utils.cs similarity index 100% rename from src/TensorFlowHub/Utils.cs rename to src/TensorFlowNET.Hub/Utils.cs diff --git a/src/TensorFlowNET.Models/ObjectDetection/AnchorGenerators/GridAnchorGenerator.cs b/src/TensorFlowNET.Models/ObjectDetection/AnchorGenerators/GridAnchorGenerator.cs new file mode 100644 index 00000000..c5093dfa --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/AnchorGenerators/GridAnchorGenerator.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Models.ObjectDetection +{ + public class GridAnchorGenerator : Core.AnchorGenerator + { + public GridAnchorGenerator(float[] scales = null) + { + if (scales == null) + scales = new[] { 0.5f, 1.0f, 2.0f }; + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/AnchorGeneratorBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/AnchorGeneratorBuilder.cs new file mode 100644 index 00000000..f220bccd --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/AnchorGeneratorBuilder.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Models.ObjectDetection.Protos; +using static Tensorflow.Models.ObjectDetection.Protos.AnchorGenerator; + +namespace Tensorflow.Models.ObjectDetection +{ + public class AnchorGeneratorBuilder + { + public AnchorGeneratorBuilder() + { + + } + + public GridAnchorGenerator build(AnchorGenerator anchor_generator_config) + { + if(anchor_generator_config.AnchorGeneratorOneofCase == AnchorGeneratorOneofOneofCase.GridAnchorGenerator) + { + var grid_anchor_generator_config = anchor_generator_config.GridAnchorGenerator; + return new GridAnchorGenerator(scales: grid_anchor_generator_config.Scales.Select(x => float.Parse(x.ToString())).ToArray()); + 
} + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/BoxPredictorBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/BoxPredictorBuilder.cs new file mode 100644 index 00000000..7ff2be25 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/BoxPredictorBuilder.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Models.ObjectDetection +{ + public class BoxPredictorBuilder + { + ConvolutionalBoxPredictor _first_stage_box_predictor; + public ConvolutionalBoxPredictor build_convolutional_box_predictor() + { + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/DatasetBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/DatasetBuilder.cs new file mode 100644 index 00000000..3c47bf51 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/DatasetBuilder.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Data; +using Tensorflow.Models.ObjectDetection.Protos; + +namespace Tensorflow.Models.ObjectDetection +{ + public class DatasetBuilder + { + public DatasetV1Adapter build(InputReader input_reader_config, + int batch_size = 0, + Action transform_input_data_fn = null) + { + Func, (Dictionary, Dictionary)> transform_and_pad_input_data_fn = (tensor_dict) => + { + return (null, null); + }; + + var config = input_reader_config.TfRecordInputReader; + + throw new NotImplementedException(""); + } + + public Dictionary process_fn(Tensor value) + { + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs new file mode 100644 index 00000000..5c52bd87 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/ImageResizerBuilder.cs @@ -0,0 +1,65 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; +using Tensorflow.Models.ObjectDetection.Core; +using Tensorflow.Models.ObjectDetection.Protos; +using static Tensorflow.Models.ObjectDetection.Protos.ImageResizer; + +namespace Tensorflow.Models.ObjectDetection +{ + public class ImageResizerBuilder + { + public ImageResizerBuilder() + { + + } + + /// + /// Builds callable for image resizing operations. + /// + /// + /// + public Func build(ImageResizer image_resizer_config) + { + var image_resizer_oneof = image_resizer_config.ImageResizerOneofCase; + if (image_resizer_oneof == ImageResizerOneofOneofCase.KeepAspectRatioResizer) + { + var keep_aspect_ratio_config = image_resizer_config.KeepAspectRatioResizer; + if (keep_aspect_ratio_config.MinDimension > keep_aspect_ratio_config.MaxDimension) + throw new ValueError("min_dimension > max_dimension"); + var method = _tf_resize_method(keep_aspect_ratio_config.ResizeMethod); + var per_channel_pad_value = new[] { 0, 0, 0 }; + if (keep_aspect_ratio_config.PerChannelPadValue.Count > 0) + throw new NotImplementedException(""); + // per_channel_pad_value = new[] { keep_aspect_ratio_config.PerChannelPadValue. 
}; + + var args = new ResizeToRangeArgs + { + min_dimension = keep_aspect_ratio_config.MinDimension, + max_dimension = keep_aspect_ratio_config.MaxDimension, + method = method, + pad_to_max_dimension = keep_aspect_ratio_config.PadToMaxDimension, + per_channel_pad_value = per_channel_pad_value + }; + + Func func = (input) => + { + args.image = input.image; + return Preprocessor.resize_to_range(args); + }; + + return func; + } + else + { + throw new NotImplementedException(""); + } + } + + private ResizeMethod _tf_resize_method(ResizeType resize_method) + { + return (ResizeMethod)(int)resize_method; + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs b/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs new file mode 100644 index 00000000..596a7532 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Builders/ModelBuilder.cs @@ -0,0 +1,89 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Models.ObjectDetection.Protos; +using static Tensorflow.Models.ObjectDetection.Protos.DetectionModel; + +namespace Tensorflow.Models.ObjectDetection +{ + public class ModelBuilder + { + ImageResizerBuilder _image_resizer_builder; + FasterRCNNFeatureExtractor _feature_extractor; + AnchorGeneratorBuilder _anchor_generator_builder; + + public ModelBuilder() + { + _image_resizer_builder = new ImageResizerBuilder(); + _anchor_generator_builder = new AnchorGeneratorBuilder(); + } + + /// + /// Builds a DetectionModel based on the model config. + /// + /// A model.proto object containing the config for the desired DetectionModel. + /// True if this model is being built for training purposes. + /// Whether to add tensorflow summaries in the model graph. + /// DetectionModel based on the config. + public FasterRCNNMetaArch build(DetectionModel model_config, bool is_training, bool add_summaries = true) + { + var meta_architecture = model_config.ModelCase; + if (meta_architecture == ModelOneofCase.Ssd) + throw new NotImplementedException(""); + else if (meta_architecture == ModelOneofCase.FasterRcnn) + return _build_faster_rcnn_model(model_config.FasterRcnn, is_training, add_summaries); + + throw new ValueError($"Unknown meta architecture: {meta_architecture}"); + } + + /// + /// Builds a Faster R-CNN or R-FCN detection model based on the model config. + /// + /// + /// + /// + /// FasterRCNNMetaArch based on the config. 
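// Usage sketch for the builder above: model_config is the parsed DetectionModel
// pipeline proto (how it is loaded is out of scope here), and only the Faster R-CNN
// branch is currently implemented. The method name is illustrative.
static FasterRCNNMetaArch BuildDetectionModelSketch(DetectionModel model_config)
{
    var builder = new ModelBuilder();
    return builder.build(model_config, is_training: true, add_summaries: true);
}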
+ private FasterRCNNMetaArch _build_faster_rcnn_model(FasterRcnn frcnn_config, bool is_training, bool add_summaries) + { + var num_classes = frcnn_config.NumClasses; + var image_resizer_fn = _image_resizer_builder.build(frcnn_config.ImageResizer); + + var feature_extractor = _build_faster_rcnn_feature_extractor(frcnn_config.FeatureExtractor, is_training, + inplace_batchnorm_update: frcnn_config.InplaceBatchnormUpdate); + + var number_of_stages = frcnn_config.NumberOfStages; + var first_stage_anchor_generator = _anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator); + var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate; + + return new FasterRCNNMetaArch(new FasterRCNNInitArgs + { + is_training = is_training, + num_classes = num_classes, + image_resizer_fn = image_resizer_fn, + feature_extractor = _feature_extractor, + number_of_stage = number_of_stages, + first_stage_anchor_generator = null, + first_stage_atrous_rate = first_stage_atrous_rate + }); + } + + public Action preprocess() + { + throw new NotImplementedException(""); + } + + private FasterRCNNFeatureExtractor _build_faster_rcnn_feature_extractor(FasterRcnnFeatureExtractor feature_extractor_config, + bool is_training, bool reuse_weights = false, bool inplace_batchnorm_update = false) + { + if (inplace_batchnorm_update) + throw new ValueError("inplace batchnorm updates not supported."); + var feature_type = feature_extractor_config.Type; + var first_stage_features_stride = feature_extractor_config.FirstStageFeaturesStride; + var batch_norm_trainable = feature_extractor_config.BatchNormTrainable; + + return new FasterRCNNResnet101FeatureExtractor(is_training, first_stage_features_stride, + batch_norm_trainable: batch_norm_trainable, + reuse_weights: reuse_weights); + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Core/AnchorGenerator.cs b/src/TensorFlowNET.Models/ObjectDetection/Core/AnchorGenerator.cs new file mode 100644 index 00000000..af44ee3f --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Core/AnchorGenerator.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Models.ObjectDetection.Core +{ + public class AnchorGenerator + { + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Core/DetectionModel.cs b/src/TensorFlowNET.Models/ObjectDetection/Core/DetectionModel.cs new file mode 100644 index 00000000..24578a5b --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Core/DetectionModel.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Models.ObjectDetection.Core +{ + public abstract class DetectionModel + { + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Core/Preprocessor.cs b/src/TensorFlowNET.Models/ObjectDetection/Core/Preprocessor.cs new file mode 100644 index 00000000..ac3cc805 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Core/Preprocessor.cs @@ -0,0 +1,69 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Models.ObjectDetection.Core +{ + public class Preprocessor + { + public static Tensor[] resize_to_range(ResizeToRangeArgs args) + { + var image = args.image; + var min_dimension = args.min_dimension; + var max_dimension = args.max_dimension; + var method = args.method; + var align_corners = args.align_corners; + + if (image.NDims != 3) + throw new ValueError("Image should be 3D tensor"); + + + Func _resize_landscape_image 
= (image1) => + { + return tf.image.resize_images(image1, + tf.stack(new[] { min_dimension, max_dimension }), + method: method, + align_corners: align_corners, + preserve_aspect_ratio: true); + }; + Func _resize_portrait_image = (image1) => + { + return tf.image.resize_images(image1, + tf.stack(new[] { min_dimension, max_dimension }), + method: method, + align_corners: align_corners, + preserve_aspect_ratio: true); + }; + + return tf_with(tf.name_scope("ResizeToRange", values: new { image, min_dimension }), delegate + { + Tensor new_image, new_size; + + if (image.TensorShape.is_fully_defined()) + throw new NotImplementedException(""); + else + { + new_image = tf.cond( + tf.less(tf.shape(image)[0], tf.shape(image)[1]), + () => _resize_landscape_image(image), + () => _resize_portrait_image(image)); + new_size = tf.shape(new_image); + } + + if (args.pad_to_max_dimension) + { + throw new NotImplementedException(""); + } + + var result = new List { new_image }; + if (args.masks != null) + throw new NotImplementedException(""); + + result.Add(new_size); + + return result.ToArray(); + }); + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Core/ResizeToRangeArgs.cs b/src/TensorFlowNET.Models/ObjectDetection/Core/ResizeToRangeArgs.cs new file mode 100644 index 00000000..1a3c8eb5 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Core/ResizeToRangeArgs.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Models.ObjectDetection.Core +{ + public class ResizeToRangeArgs + { + public Tensor image { get; set; } + public int[] masks { get; set; } + public int min_dimension { get; set; } + public int max_dimension { get; set; } + public ResizeMethod method {get;set;} + public bool align_corners { get; set; } + public bool pad_to_max_dimension { get; set; } + public int[] per_channel_pad_value { get; set; } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Entities/TrainAndEvalDict.cs b/src/TensorFlowNET.Models/ObjectDetection/Entities/TrainAndEvalDict.cs new file mode 100644 index 00000000..aa6bb502 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Entities/TrainAndEvalDict.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Data; +using Tensorflow.Estimators; + +namespace Tensorflow.Models.ObjectDetection +{ + public class TrainAndEvalDict + { + public Estimator estimator { get; set; } + public Func train_input_fn { get; set; } + public Action[] eval_input_fns { get; set; } + public string[] eval_input_names { get; set; } + public Action eval_on_train_input_fn { get; set; } + public Action predict_input_fn { get; set; } + public int train_steps { get; set; } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Inputs.cs b/src/TensorFlowNET.Models/ObjectDetection/Inputs.cs new file mode 100644 index 00000000..34845786 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Inputs.cs @@ -0,0 +1,60 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Data; +using Tensorflow.Models.ObjectDetection.Protos; + +namespace Tensorflow.Models.ObjectDetection +{ + public class Inputs + { + ModelBuilder modelBuilder; + DatasetBuilder datasetBuilder; + + public Inputs() + { + modelBuilder = new ModelBuilder(); + datasetBuilder = new DatasetBuilder(); + } + + public Func create_train_input_fn(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config) + { + Func _train_input_fn = () 
=> + train_input(train_config, train_input_config, model_config); + + return _train_input_fn; + } + + /// + /// Returns `features` and `labels` tensor dictionaries for training. + /// + /// + /// + /// + /// + public DatasetV1Adapter train_input(TrainConfig train_config, InputReader train_input_config, DetectionModel model_config) + { + var arch = modelBuilder.build(model_config, true, true); + Func model_preprocess_fn = arch.preprocess; + + Func, (Dictionary, Dictionary) > transform_and_pad_input_data_fn = (tensor_dict) => + { + return (_get_features_dict(tensor_dict), _get_labels_dict(tensor_dict)); + }; + + var dataset = datasetBuilder.build(train_input_config); + + return dataset; + } + + private Dictionary _get_features_dict(Dictionary input_dict) + { + throw new NotImplementedException("_get_features_dict"); + } + + private Dictionary _get_labels_dict(Dictionary input_dict) + { + throw new NotImplementedException("_get_labels_dict"); + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNFeatureExtractor.cs b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNFeatureExtractor.cs new file mode 100644 index 00000000..bdcfae76 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNFeatureExtractor.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Models.ObjectDetection +{ + /// + /// Faster R-CNN Feature Extractor definition. + /// + public class FasterRCNNFeatureExtractor + { + bool _is_training; + int _first_stage_features_stride; + bool _reuse_weights = false; + float _weight_decay = 0.0f; + bool _train_batch_norm; + + public FasterRCNNFeatureExtractor(bool is_training, + int first_stage_features_stride, + bool batch_norm_trainable = false, + bool reuse_weights = false, + float weight_decay = 0.0f) + { + _is_training = is_training; + _first_stage_features_stride = first_stage_features_stride; + _train_batch_norm = (batch_norm_trainable && is_training); + _reuse_weights = reuse_weights; + _weight_decay = weight_decay; + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNInitArgs.cs b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNInitArgs.cs new file mode 100644 index 00000000..e5e92161 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNInitArgs.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Models.ObjectDetection.Core; + +namespace Tensorflow.Models.ObjectDetection +{ + public class FasterRCNNInitArgs + { + public bool is_training { get; set; } + public int num_classes { get; set; } + public Func image_resizer_fn { get; set; } + public FasterRCNNFeatureExtractor feature_extractor { get; set; } + public int number_of_stage { get; set; } + public object first_stage_anchor_generator { get; set; } + public object first_stage_target_assigner { get; set; } + public int first_stage_atrous_rate { get; set; } + public int parallel_iterations { get; set; } = 16; + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs new file mode 100644 index 00000000..956960b0 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/MetaArchitectures/FasterRCNNMetaArch.cs @@ -0,0 +1,42 @@ +using System; +using System.Collections.Generic; +using System.Text; +using 
static Tensorflow.Binding; + +namespace Tensorflow.Models.ObjectDetection +{ + public class FasterRCNNMetaArch : Core.DetectionModel + { + FasterRCNNInitArgs _args; + + public FasterRCNNMetaArch(FasterRCNNInitArgs args) + { + _args = args; + } + + /// + /// Feature-extractor specific preprocessing. + /// + /// + /// + public (Tensor, Tensor) preprocess(Tensor inputs) + { + tf_with(tf.name_scope("Preprocessor"), delegate + { + var outputs = shape_utils.static_or_dynamic_map_fn( + (inputs1) => + { + return _args.image_resizer_fn(new Core.ResizeToRangeArgs + { + image = inputs1 + })[0]; + }, + elems: inputs, + dtypes: new[] { tf.float32, tf.int32 }, + parallel_iterations: _args.parallel_iterations); + }); + + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs b/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs new file mode 100644 index 00000000..5611356b --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs @@ -0,0 +1,77 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; +using Tensorflow.Estimators; +using System.Linq; +using Tensorflow.Contrib.Train; +using Tensorflow.Models.ObjectDetection.Utils; +using Tensorflow.Data; + +namespace Tensorflow.Models.ObjectDetection +{ + public class ModelLib + { + Inputs inputs = new Inputs(); + + public TrainAndEvalDict create_estimator_and_inputs(RunConfig run_config, + HParams hparams = null, + string pipeline_config_path = null, + int train_steps = 0, + int sample_1_of_n_eval_examples = 0, + int sample_1_of_n_eval_on_train_examples = 1) + { + var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path); + + // Create the input functions for TRAIN/EVAL/PREDICT. + Func train_input_fn = inputs.create_train_input_fn(config.TrainConfig, config.TrainInputReader, config.Model); + + var eval_input_configs = config.EvalInputReader; + + var eval_input_fns = new Action[eval_input_configs.Count]; + var eval_input_names = eval_input_configs.Select(eval_input_config => eval_input_config.Name).ToArray(); + Action eval_on_train_input_fn = () => { }; + Action predict_input_fn = () => { }; + Action model_fn = () => { }; + var estimator = tf.estimator.Estimator(model_fn: model_fn, config: run_config); + + return new TrainAndEvalDict + { + estimator = estimator, + train_input_fn = train_input_fn, + eval_input_fns = eval_input_fns, + eval_input_names = eval_input_names, + eval_on_train_input_fn = eval_on_train_input_fn, + predict_input_fn = predict_input_fn, + train_steps = train_steps + }; + } + + public (TrainSpec, EvalSpec[]) create_train_and_eval_specs(Func train_input_fn, Action[] eval_input_fns, Action eval_on_train_input_fn, + Action predict_input_fn, int train_steps, bool eval_on_train_data = false, + string final_exporter_name = "Servo", string[] eval_spec_names = null) + { + var train_spec = tf.estimator.TrainSpec(input_fn: train_input_fn, max_steps: train_steps); + + if (eval_spec_names == null) + eval_spec_names = range(len(eval_input_fns)) + .Select(x => x.ToString()) + .ToArray(); + + var eval_specs = new List(); + foreach (var (index, (eval_spec_name, eval_input_fn)) in enumerate(zip(eval_spec_names, eval_input_fns).ToList())) + { + var exporter_name = index == 0 ? 
final_exporter_name : $"{final_exporter_name}_{eval_spec_name}"; + var exporter = tf.estimator.FinalExporter(name: exporter_name, serving_input_receiver_fn: predict_input_fn); + eval_specs.Add(tf.estimator.EvalSpec(name: eval_spec_name, + input_fn: eval_input_fn, + exporters: exporter)); + } + + if (eval_on_train_data) + throw new NotImplementedException(""); + + return (train_spec, eval_specs.ToArray()); + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Models/FasterRCNNResnet101FeatureExtractor.cs b/src/TensorFlowNET.Models/ObjectDetection/Models/FasterRCNNResnet101FeatureExtractor.cs new file mode 100644 index 00000000..75e21ade --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Models/FasterRCNNResnet101FeatureExtractor.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; +using Tensorflow.Operations.Activation; +using Tensorflow.Models.Slim.Nets; + +namespace Tensorflow.Models.ObjectDetection +{ + /// + /// Faster R-CNN Resnet 101 feature extractor implementation. + /// + public class FasterRCNNResnet101FeatureExtractor : FasterRCNNResnetV1FeatureExtractor + { + public FasterRCNNResnet101FeatureExtractor(bool is_training, + int first_stage_features_stride, + bool batch_norm_trainable = false, + bool reuse_weights = false, + float weight_decay = 0.0f, + IActivation activation_fn = null) : base("resnet_v1_101", + ResNetV1.resnet_v1_101, + is_training, + first_stage_features_stride, + batch_norm_trainable: batch_norm_trainable, + reuse_weights: reuse_weights, + weight_decay: weight_decay, + activation_fn: activation_fn) + { + + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Models/FasterRCNNResnetV1FeatureExtractor.cs b/src/TensorFlowNET.Models/ObjectDetection/Models/FasterRCNNResnetV1FeatureExtractor.cs new file mode 100644 index 00000000..e4a8351b --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Models/FasterRCNNResnetV1FeatureExtractor.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; +using Tensorflow.Operations.Activation; + +namespace Tensorflow.Models.ObjectDetection +{ + public class FasterRCNNResnetV1FeatureExtractor : FasterRCNNFeatureExtractor + { + public FasterRCNNResnetV1FeatureExtractor(string architecture, + Action resnet_model, + bool is_training, + int first_stage_features_stride, + bool batch_norm_trainable = false, + bool reuse_weights = false, + float weight_decay = 0.0f, + IActivation activation_fn = null) : base(is_training, + first_stage_features_stride, + batch_norm_trainable: batch_norm_trainable, + reuse_weights: reuse_weights, + weight_decay: weight_decay) + { + if (activation_fn == null) + activation_fn = tf.nn.relu(); + } + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Models/faster_rcnn_resnet101_voc07.config b/src/TensorFlowNET.Models/ObjectDetection/Models/faster_rcnn_resnet101_voc07.config new file mode 100644 index 00000000..7458f4a5 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Models/faster_rcnn_resnet101_voc07.config @@ -0,0 +1,133 @@ +# Faster R-CNN with Resnet-101 (v1), configured for Pascal VOC Dataset. +# Users should configure the fine_tune_checkpoint field in the train config as +# well as the label_map_path and input_path fields in the train_input_reader and +# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that +# should be configured. 
+ +model { + faster_rcnn { + num_classes: 20 + image_resizer { + keep_aspect_ratio_resizer { + min_dimension: 600 + max_dimension: 1024 + } + } + feature_extractor { + type: 'faster_rcnn_resnet101' + first_stage_features_stride: 16 + } + first_stage_anchor_generator { + grid_anchor_generator { + scales: [0.25, 0.5, 1.0, 2.0] + aspect_ratios: [0.5, 1.0, 2.0] + height_stride: 16 + width_stride: 16 + } + } + first_stage_box_predictor_conv_hyperparams { + op: CONV + regularizer { + l2_regularizer { + weight: 0.0 + } + } + initializer { + truncated_normal_initializer { + stddev: 0.01 + } + } + } + first_stage_nms_score_threshold: 0.0 + first_stage_nms_iou_threshold: 0.7 + first_stage_max_proposals: 300 + first_stage_localization_loss_weight: 2.0 + first_stage_objectness_loss_weight: 1.0 + initial_crop_size: 14 + maxpool_kernel_size: 2 + maxpool_stride: 2 + second_stage_box_predictor { + mask_rcnn_box_predictor { + use_dropout: false + dropout_keep_probability: 1.0 + fc_hyperparams { + op: FC + regularizer { + l2_regularizer { + weight: 0.0 + } + } + initializer { + variance_scaling_initializer { + factor: 1.0 + uniform: true + mode: FAN_AVG + } + } + } + } + } + second_stage_post_processing { + batch_non_max_suppression { + score_threshold: 0.0 + iou_threshold: 0.6 + max_detections_per_class: 100 + max_total_detections: 300 + } + score_converter: SOFTMAX + } + second_stage_localization_loss_weight: 2.0 + second_stage_classification_loss_weight: 1.0 + } +} + +train_config: { + batch_size: 1 + optimizer { + momentum_optimizer: { + learning_rate: { + manual_step_learning_rate { + initial_learning_rate: 0.0001 + schedule { + step: 500000 + learning_rate: .00001 + } + schedule { + step: 700000 + learning_rate: .000001 + } + } + } + momentum_optimizer_value: 0.9 + } + use_moving_average: false + } + gradient_clipping_by_norm: 10.0 + fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt" + from_detection_checkpoint: true + num_steps: 800000 + data_augmentation_options { + random_horizontal_flip { + } + } +} + +train_input_reader: { + tf_record_input_reader { + input_path: "PATH_TO_BE_CONFIGURED/pascal_train.record" + } + label_map_path: "PATH_TO_BE_CONFIGURED/pascal_label_map.pbtxt" +} + +eval_config: { + num_examples: 4952 +} + +eval_input_reader: { + tf_record_input_reader { + input_path: "PATH_TO_BE_CONFIGURED/pascal_val.record" + } + label_map_path: "PATH_TO_BE_CONFIGURED/pascal_label_map.pbtxt" + shuffle: false + num_readers: 1 +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Predictors/ConvolutionalBoxPredictor.cs b/src/TensorFlowNET.Models/ObjectDetection/Predictors/ConvolutionalBoxPredictor.cs new file mode 100644 index 00000000..bd2f4114 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Predictors/ConvolutionalBoxPredictor.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Models.ObjectDetection +{ + public class ConvolutionalBoxPredictor + { + } +} diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/AnchorGenerator.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/AnchorGenerator.cs new file mode 100644 index 00000000..8a0e255c --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/AnchorGenerator.cs @@ -0,0 +1,343 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: object_detection/protos/anchor_generator.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/anchor_generator.proto + public static partial class AnchorGeneratorReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/anchor_generator.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static AnchorGeneratorReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Ci5vYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9hbmNob3JfZ2VuZXJhdG9yLnBy", + "b3RvEhdvYmplY3RfZGV0ZWN0aW9uLnByb3Rvcxo8b2JqZWN0X2RldGVjdGlv", + "bi9wcm90b3MvZmxleGlibGVfZ3JpZF9hbmNob3JfZ2VuZXJhdG9yLnByb3Rv", + "GjNvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9ncmlkX2FuY2hvcl9nZW5lcmF0", + "b3IucHJvdG8aOW9iamVjdF9kZXRlY3Rpb24vcHJvdG9zL211bHRpc2NhbGVf", + "YW5jaG9yX2dlbmVyYXRvci5wcm90bxoyb2JqZWN0X2RldGVjdGlvbi9wcm90", + "b3Mvc3NkX2FuY2hvcl9nZW5lcmF0b3IucHJvdG8iggMKD0FuY2hvckdlbmVy", + "YXRvchJNChVncmlkX2FuY2hvcl9nZW5lcmF0b3IYASABKAsyLC5vYmplY3Rf", + "ZGV0ZWN0aW9uLnByb3Rvcy5HcmlkQW5jaG9yR2VuZXJhdG9ySAASSwoUc3Nk", + "X2FuY2hvcl9nZW5lcmF0b3IYAiABKAsyKy5vYmplY3RfZGV0ZWN0aW9uLnBy", + "b3Rvcy5Tc2RBbmNob3JHZW5lcmF0b3JIABJZChttdWx0aXNjYWxlX2FuY2hv", + "cl9nZW5lcmF0b3IYAyABKAsyMi5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5N", + "dWx0aXNjYWxlQW5jaG9yR2VuZXJhdG9ySAASXgoeZmxleGlibGVfZ3JpZF9h", + "bmNob3JfZ2VuZXJhdG9yGAQgASgLMjQub2JqZWN0X2RldGVjdGlvbi5wcm90", + "b3MuRmxleGlibGVHcmlkQW5jaG9yR2VuZXJhdG9ySABCGAoWYW5jaG9yX2dl", + "bmVyYXRvcl9vbmVvZmIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.Models.ObjectDetection.Protos.FlexibleGridAnchorGeneratorReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.GridAnchorGeneratorReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.MultiscaleAnchorGeneratorReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.SsdAnchorGeneratorReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.AnchorGenerator), global::Tensorflow.Models.ObjectDetection.Protos.AnchorGenerator.Parser, new[]{ "GridAnchorGenerator", "SsdAnchorGenerator", "MultiscaleAnchorGenerator", "FlexibleGridAnchorGenerator" }, new[]{ "AnchorGeneratorOneof" }, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration proto for the anchor generator to use in the object detection + /// pipeline. See core/anchor_generator.py for details. 
+ /// + public sealed partial class AnchorGenerator : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AnchorGenerator()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.AnchorGeneratorReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AnchorGenerator() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AnchorGenerator(AnchorGenerator other) : this() { + switch (other.AnchorGeneratorOneofCase) { + case AnchorGeneratorOneofOneofCase.GridAnchorGenerator: + GridAnchorGenerator = other.GridAnchorGenerator.Clone(); + break; + case AnchorGeneratorOneofOneofCase.SsdAnchorGenerator: + SsdAnchorGenerator = other.SsdAnchorGenerator.Clone(); + break; + case AnchorGeneratorOneofOneofCase.MultiscaleAnchorGenerator: + MultiscaleAnchorGenerator = other.MultiscaleAnchorGenerator.Clone(); + break; + case AnchorGeneratorOneofOneofCase.FlexibleGridAnchorGenerator: + FlexibleGridAnchorGenerator = other.FlexibleGridAnchorGenerator.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AnchorGenerator Clone() { + return new AnchorGenerator(this); + } + + /// Field number for the "grid_anchor_generator" field. + public const int GridAnchorGeneratorFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.GridAnchorGenerator GridAnchorGenerator { + get { return anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.GridAnchorGenerator ? (global::Tensorflow.Models.ObjectDetection.Protos.GridAnchorGenerator) anchorGeneratorOneof_ : null; } + set { + anchorGeneratorOneof_ = value; + anchorGeneratorOneofCase_ = value == null ? AnchorGeneratorOneofOneofCase.None : AnchorGeneratorOneofOneofCase.GridAnchorGenerator; + } + } + + /// Field number for the "ssd_anchor_generator" field. + public const int SsdAnchorGeneratorFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.SsdAnchorGenerator SsdAnchorGenerator { + get { return anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.SsdAnchorGenerator ? (global::Tensorflow.Models.ObjectDetection.Protos.SsdAnchorGenerator) anchorGeneratorOneof_ : null; } + set { + anchorGeneratorOneof_ = value; + anchorGeneratorOneofCase_ = value == null ? AnchorGeneratorOneofOneofCase.None : AnchorGeneratorOneofOneofCase.SsdAnchorGenerator; + } + } + + /// Field number for the "multiscale_anchor_generator" field. + public const int MultiscaleAnchorGeneratorFieldNumber = 3; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.MultiscaleAnchorGenerator MultiscaleAnchorGenerator { + get { return anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.MultiscaleAnchorGenerator ? 
(global::Tensorflow.Models.ObjectDetection.Protos.MultiscaleAnchorGenerator) anchorGeneratorOneof_ : null; } + set { + anchorGeneratorOneof_ = value; + anchorGeneratorOneofCase_ = value == null ? AnchorGeneratorOneofOneofCase.None : AnchorGeneratorOneofOneofCase.MultiscaleAnchorGenerator; + } + } + + /// Field number for the "flexible_grid_anchor_generator" field. + public const int FlexibleGridAnchorGeneratorFieldNumber = 4; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.FlexibleGridAnchorGenerator FlexibleGridAnchorGenerator { + get { return anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.FlexibleGridAnchorGenerator ? (global::Tensorflow.Models.ObjectDetection.Protos.FlexibleGridAnchorGenerator) anchorGeneratorOneof_ : null; } + set { + anchorGeneratorOneof_ = value; + anchorGeneratorOneofCase_ = value == null ? AnchorGeneratorOneofOneofCase.None : AnchorGeneratorOneofOneofCase.FlexibleGridAnchorGenerator; + } + } + + private object anchorGeneratorOneof_; + /// Enum of possible cases for the "anchor_generator_oneof" oneof. + public enum AnchorGeneratorOneofOneofCase { + None = 0, + GridAnchorGenerator = 1, + SsdAnchorGenerator = 2, + MultiscaleAnchorGenerator = 3, + FlexibleGridAnchorGenerator = 4, + } + private AnchorGeneratorOneofOneofCase anchorGeneratorOneofCase_ = AnchorGeneratorOneofOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AnchorGeneratorOneofOneofCase AnchorGeneratorOneofCase { + get { return anchorGeneratorOneofCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearAnchorGeneratorOneof() { + anchorGeneratorOneofCase_ = AnchorGeneratorOneofOneofCase.None; + anchorGeneratorOneof_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as AnchorGenerator); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(AnchorGenerator other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(GridAnchorGenerator, other.GridAnchorGenerator)) return false; + if (!object.Equals(SsdAnchorGenerator, other.SsdAnchorGenerator)) return false; + if (!object.Equals(MultiscaleAnchorGenerator, other.MultiscaleAnchorGenerator)) return false; + if (!object.Equals(FlexibleGridAnchorGenerator, other.FlexibleGridAnchorGenerator)) return false; + if (AnchorGeneratorOneofCase != other.AnchorGeneratorOneofCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.GridAnchorGenerator) hash ^= GridAnchorGenerator.GetHashCode(); + if (anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.SsdAnchorGenerator) hash ^= SsdAnchorGenerator.GetHashCode(); + if (anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.MultiscaleAnchorGenerator) hash ^= MultiscaleAnchorGenerator.GetHashCode(); + if (anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.FlexibleGridAnchorGenerator) hash ^= FlexibleGridAnchorGenerator.GetHashCode(); + hash ^= (int) anchorGeneratorOneofCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string 
ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.GridAnchorGenerator) { + output.WriteRawTag(10); + output.WriteMessage(GridAnchorGenerator); + } + if (anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.SsdAnchorGenerator) { + output.WriteRawTag(18); + output.WriteMessage(SsdAnchorGenerator); + } + if (anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.MultiscaleAnchorGenerator) { + output.WriteRawTag(26); + output.WriteMessage(MultiscaleAnchorGenerator); + } + if (anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.FlexibleGridAnchorGenerator) { + output.WriteRawTag(34); + output.WriteMessage(FlexibleGridAnchorGenerator); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.GridAnchorGenerator) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(GridAnchorGenerator); + } + if (anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.SsdAnchorGenerator) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SsdAnchorGenerator); + } + if (anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.MultiscaleAnchorGenerator) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(MultiscaleAnchorGenerator); + } + if (anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.FlexibleGridAnchorGenerator) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(FlexibleGridAnchorGenerator); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(AnchorGenerator other) { + if (other == null) { + return; + } + switch (other.AnchorGeneratorOneofCase) { + case AnchorGeneratorOneofOneofCase.GridAnchorGenerator: + if (GridAnchorGenerator == null) { + GridAnchorGenerator = new global::Tensorflow.Models.ObjectDetection.Protos.GridAnchorGenerator(); + } + GridAnchorGenerator.MergeFrom(other.GridAnchorGenerator); + break; + case AnchorGeneratorOneofOneofCase.SsdAnchorGenerator: + if (SsdAnchorGenerator == null) { + SsdAnchorGenerator = new global::Tensorflow.Models.ObjectDetection.Protos.SsdAnchorGenerator(); + } + SsdAnchorGenerator.MergeFrom(other.SsdAnchorGenerator); + break; + case AnchorGeneratorOneofOneofCase.MultiscaleAnchorGenerator: + if (MultiscaleAnchorGenerator == null) { + MultiscaleAnchorGenerator = new global::Tensorflow.Models.ObjectDetection.Protos.MultiscaleAnchorGenerator(); + } + MultiscaleAnchorGenerator.MergeFrom(other.MultiscaleAnchorGenerator); + break; + case AnchorGeneratorOneofOneofCase.FlexibleGridAnchorGenerator: + if (FlexibleGridAnchorGenerator == null) { + FlexibleGridAnchorGenerator = new global::Tensorflow.Models.ObjectDetection.Protos.FlexibleGridAnchorGenerator(); + } + FlexibleGridAnchorGenerator.MergeFrom(other.FlexibleGridAnchorGenerator); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, 
input); + break; + case 10: { + global::Tensorflow.Models.ObjectDetection.Protos.GridAnchorGenerator subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.GridAnchorGenerator(); + if (anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.GridAnchorGenerator) { + subBuilder.MergeFrom(GridAnchorGenerator); + } + input.ReadMessage(subBuilder); + GridAnchorGenerator = subBuilder; + break; + } + case 18: { + global::Tensorflow.Models.ObjectDetection.Protos.SsdAnchorGenerator subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.SsdAnchorGenerator(); + if (anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.SsdAnchorGenerator) { + subBuilder.MergeFrom(SsdAnchorGenerator); + } + input.ReadMessage(subBuilder); + SsdAnchorGenerator = subBuilder; + break; + } + case 26: { + global::Tensorflow.Models.ObjectDetection.Protos.MultiscaleAnchorGenerator subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.MultiscaleAnchorGenerator(); + if (anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.MultiscaleAnchorGenerator) { + subBuilder.MergeFrom(MultiscaleAnchorGenerator); + } + input.ReadMessage(subBuilder); + MultiscaleAnchorGenerator = subBuilder; + break; + } + case 34: { + global::Tensorflow.Models.ObjectDetection.Protos.FlexibleGridAnchorGenerator subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.FlexibleGridAnchorGenerator(); + if (anchorGeneratorOneofCase_ == AnchorGeneratorOneofOneofCase.FlexibleGridAnchorGenerator) { + subBuilder.MergeFrom(FlexibleGridAnchorGenerator); + } + input.ReadMessage(subBuilder); + FlexibleGridAnchorGenerator = subBuilder; + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/ArgmaxMatcher.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/ArgmaxMatcher.cs new file mode 100644 index 00000000..c967c034 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/ArgmaxMatcher.cs @@ -0,0 +1,343 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: object_detection/protos/argmax_matcher.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/argmax_matcher.proto + public static partial class ArgmaxMatcherReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/argmax_matcher.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static ArgmaxMatcherReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CixvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9hcmdtYXhfbWF0Y2hlci5wcm90", + "bxIXb2JqZWN0X2RldGVjdGlvbi5wcm90b3MixwEKDUFyZ01heE1hdGNoZXIS", + "GQoRbWF0Y2hlZF90aHJlc2hvbGQYASABKAISGwoTdW5tYXRjaGVkX3RocmVz", + "aG9sZBgCIAEoAhIZChFpZ25vcmVfdGhyZXNob2xkcxgDIAEoCBImCh5uZWdh", + "dGl2ZXNfbG93ZXJfdGhhbl91bm1hdGNoZWQYBCABKAgSIAoYZm9yY2VfbWF0", + "Y2hfZm9yX2VhY2hfcm93GAUgASgIEhkKEXVzZV9tYXRtdWxfZ2F0aGVyGAYg", + "ASgIYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.ArgMaxMatcher), global::Tensorflow.Models.ObjectDetection.Protos.ArgMaxMatcher.Parser, new[]{ "MatchedThreshold", "UnmatchedThreshold", "IgnoreThresholds", "NegativesLowerThanUnmatched", "ForceMatchForEachRow", "UseMatmulGather" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration proto for ArgMaxMatcher. See + /// matchers/argmax_matcher.py for details. + /// + public sealed partial class ArgMaxMatcher : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ArgMaxMatcher()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.ArgmaxMatcherReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ArgMaxMatcher() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ArgMaxMatcher(ArgMaxMatcher other) : this() { + matchedThreshold_ = other.matchedThreshold_; + unmatchedThreshold_ = other.unmatchedThreshold_; + ignoreThresholds_ = other.ignoreThresholds_; + negativesLowerThanUnmatched_ = other.negativesLowerThanUnmatched_; + forceMatchForEachRow_ = other.forceMatchForEachRow_; + useMatmulGather_ = other.useMatmulGather_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ArgMaxMatcher Clone() { + return new ArgMaxMatcher(this); + } + + /// Field number for the "matched_threshold" field. 
+ public const int MatchedThresholdFieldNumber = 1; + private float matchedThreshold_; + /// + /// Threshold for positive matches. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MatchedThreshold { + get { return matchedThreshold_; } + set { + matchedThreshold_ = value; + } + } + + /// Field number for the "unmatched_threshold" field. + public const int UnmatchedThresholdFieldNumber = 2; + private float unmatchedThreshold_; + /// + /// Threshold for negative matches. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float UnmatchedThreshold { + get { return unmatchedThreshold_; } + set { + unmatchedThreshold_ = value; + } + } + + /// Field number for the "ignore_thresholds" field. + public const int IgnoreThresholdsFieldNumber = 3; + private bool ignoreThresholds_; + /// + /// Whether to construct ArgMaxMatcher without thresholds. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool IgnoreThresholds { + get { return ignoreThresholds_; } + set { + ignoreThresholds_ = value; + } + } + + /// Field number for the "negatives_lower_than_unmatched" field. + public const int NegativesLowerThanUnmatchedFieldNumber = 4; + private bool negativesLowerThanUnmatched_; + /// + /// If True then negative matches are the ones below the unmatched_threshold, + /// whereas ignored matches are in between the matched and umatched + /// threshold. If False, then negative matches are in between the matched + /// and unmatched threshold, and everything lower than unmatched is ignored. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool NegativesLowerThanUnmatched { + get { return negativesLowerThanUnmatched_; } + set { + negativesLowerThanUnmatched_ = value; + } + } + + /// Field number for the "force_match_for_each_row" field. + public const int ForceMatchForEachRowFieldNumber = 5; + private bool forceMatchForEachRow_; + /// + /// Whether to ensure each row is matched to at least one column. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ForceMatchForEachRow { + get { return forceMatchForEachRow_; } + set { + forceMatchForEachRow_ = value; + } + } + + /// Field number for the "use_matmul_gather" field. 
+ public const int UseMatmulGatherFieldNumber = 6; + private bool useMatmulGather_; + /// + /// Force constructed match objects to use matrix multiplication based gather + /// instead of standard tf.gather + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseMatmulGather { + get { return useMatmulGather_; } + set { + useMatmulGather_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ArgMaxMatcher); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ArgMaxMatcher other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MatchedThreshold, other.MatchedThreshold)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(UnmatchedThreshold, other.UnmatchedThreshold)) return false; + if (IgnoreThresholds != other.IgnoreThresholds) return false; + if (NegativesLowerThanUnmatched != other.NegativesLowerThanUnmatched) return false; + if (ForceMatchForEachRow != other.ForceMatchForEachRow) return false; + if (UseMatmulGather != other.UseMatmulGather) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MatchedThreshold != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MatchedThreshold); + if (UnmatchedThreshold != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(UnmatchedThreshold); + if (IgnoreThresholds != false) hash ^= IgnoreThresholds.GetHashCode(); + if (NegativesLowerThanUnmatched != false) hash ^= NegativesLowerThanUnmatched.GetHashCode(); + if (ForceMatchForEachRow != false) hash ^= ForceMatchForEachRow.GetHashCode(); + if (UseMatmulGather != false) hash ^= UseMatmulGather.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MatchedThreshold != 0F) { + output.WriteRawTag(13); + output.WriteFloat(MatchedThreshold); + } + if (UnmatchedThreshold != 0F) { + output.WriteRawTag(21); + output.WriteFloat(UnmatchedThreshold); + } + if (IgnoreThresholds != false) { + output.WriteRawTag(24); + output.WriteBool(IgnoreThresholds); + } + if (NegativesLowerThanUnmatched != false) { + output.WriteRawTag(32); + output.WriteBool(NegativesLowerThanUnmatched); + } + if (ForceMatchForEachRow != false) { + output.WriteRawTag(40); + output.WriteBool(ForceMatchForEachRow); + } + if (UseMatmulGather != false) { + output.WriteRawTag(48); + output.WriteBool(UseMatmulGather); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MatchedThreshold != 0F) { + size += 1 + 4; + } + if (UnmatchedThreshold != 0F) { + size += 1 + 4; + } + if (IgnoreThresholds != false) { + size += 1 + 1; + } + if (NegativesLowerThanUnmatched != false) { + size += 1 + 1; + } + if (ForceMatchForEachRow != false) { + size += 1 + 1; + 
} + if (UseMatmulGather != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ArgMaxMatcher other) { + if (other == null) { + return; + } + if (other.MatchedThreshold != 0F) { + MatchedThreshold = other.MatchedThreshold; + } + if (other.UnmatchedThreshold != 0F) { + UnmatchedThreshold = other.UnmatchedThreshold; + } + if (other.IgnoreThresholds != false) { + IgnoreThresholds = other.IgnoreThresholds; + } + if (other.NegativesLowerThanUnmatched != false) { + NegativesLowerThanUnmatched = other.NegativesLowerThanUnmatched; + } + if (other.ForceMatchForEachRow != false) { + ForceMatchForEachRow = other.ForceMatchForEachRow; + } + if (other.UseMatmulGather != false) { + UseMatmulGather = other.UseMatmulGather; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + MatchedThreshold = input.ReadFloat(); + break; + } + case 21: { + UnmatchedThreshold = input.ReadFloat(); + break; + } + case 24: { + IgnoreThresholds = input.ReadBool(); + break; + } + case 32: { + NegativesLowerThanUnmatched = input.ReadBool(); + break; + } + case 40: { + ForceMatchForEachRow = input.ReadBool(); + break; + } + case 48: { + UseMatmulGather = input.ReadBool(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/BipartiteMatcher.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/BipartiteMatcher.cs new file mode 100644 index 00000000..5ead27f5 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/BipartiteMatcher.cs @@ -0,0 +1,181 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: object_detection/protos/bipartite_matcher.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/bipartite_matcher.proto + public static partial class BipartiteMatcherReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/bipartite_matcher.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static BipartiteMatcherReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Ci9vYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9iaXBhcnRpdGVfbWF0Y2hlci5w", + "cm90bxIXb2JqZWN0X2RldGVjdGlvbi5wcm90b3MiLQoQQmlwYXJ0aXRlTWF0", + "Y2hlchIZChF1c2VfbWF0bXVsX2dhdGhlchgGIAEoCGIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.BipartiteMatcher), global::Tensorflow.Models.ObjectDetection.Protos.BipartiteMatcher.Parser, new[]{ "UseMatmulGather" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration proto for bipartite matcher. See + /// matchers/bipartite_matcher.py for details. + /// + public sealed partial class BipartiteMatcher : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BipartiteMatcher()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.BipartiteMatcherReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BipartiteMatcher() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BipartiteMatcher(BipartiteMatcher other) : this() { + useMatmulGather_ = other.useMatmulGather_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BipartiteMatcher Clone() { + return new BipartiteMatcher(this); + } + + /// Field number for the "use_matmul_gather" field. 
+ public const int UseMatmulGatherFieldNumber = 6; + private bool useMatmulGather_; + /// + /// Force constructed match objects to use matrix multiplication based gather + /// instead of standard tf.gather + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseMatmulGather { + get { return useMatmulGather_; } + set { + useMatmulGather_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as BipartiteMatcher); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(BipartiteMatcher other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (UseMatmulGather != other.UseMatmulGather) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (UseMatmulGather != false) hash ^= UseMatmulGather.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (UseMatmulGather != false) { + output.WriteRawTag(48); + output.WriteBool(UseMatmulGather); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (UseMatmulGather != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(BipartiteMatcher other) { + if (other == null) { + return; + } + if (other.UseMatmulGather != false) { + UseMatmulGather = other.UseMatmulGather; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 48: { + UseMatmulGather = input.ReadBool(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/BoxCoder.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/BoxCoder.cs new file mode 100644 index 00000000..d13d3326 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/BoxCoder.cs @@ -0,0 +1,341 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: object_detection/protos/box_coder.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/box_coder.proto + public static partial class BoxCoderReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/box_coder.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static BoxCoderReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CidvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9ib3hfY29kZXIucHJvdG8SF29i", + "amVjdF9kZXRlY3Rpb24ucHJvdG9zGjNvYmplY3RfZGV0ZWN0aW9uL3Byb3Rv", + "cy9mYXN0ZXJfcmNubl9ib3hfY29kZXIucHJvdG8aMG9iamVjdF9kZXRlY3Rp", + "b24vcHJvdG9zL2tleXBvaW50X2JveF9jb2Rlci5wcm90bxozb2JqZWN0X2Rl", + "dGVjdGlvbi9wcm90b3MvbWVhbl9zdGRkZXZfYm94X2NvZGVyLnByb3RvGi5v", + "YmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9zcXVhcmVfYm94X2NvZGVyLnByb3Rv", + "IscCCghCb3hDb2RlchJMChVmYXN0ZXJfcmNubl9ib3hfY29kZXIYASABKAsy", + "Ky5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5GYXN0ZXJSY25uQm94Q29kZXJI", + "ABJMChVtZWFuX3N0ZGRldl9ib3hfY29kZXIYAiABKAsyKy5vYmplY3RfZGV0", + "ZWN0aW9uLnByb3Rvcy5NZWFuU3RkZGV2Qm94Q29kZXJIABJDChBzcXVhcmVf", + "Ym94X2NvZGVyGAMgASgLMicub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuU3F1", + "YXJlQm94Q29kZXJIABJHChJrZXlwb2ludF9ib3hfY29kZXIYBCABKAsyKS5v", + "YmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5LZXlwb2ludEJveENvZGVySABCEQoP", + "Ym94X2NvZGVyX29uZW9mYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnBoxCoderReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.KeypointBoxCoderReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.MeanStddevBoxCoderReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.SquareBoxCoderReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.BoxCoder), global::Tensorflow.Models.ObjectDetection.Protos.BoxCoder.Parser, new[]{ "FasterRcnnBoxCoder", "MeanStddevBoxCoder", "SquareBoxCoder", "KeypointBoxCoder" }, new[]{ "BoxCoderOneof" }, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration proto for the box coder to be used in the object detection + /// pipeline. See core/box_coder.py for details. 
+ /// + public sealed partial class BoxCoder : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BoxCoder()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.BoxCoderReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BoxCoder() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BoxCoder(BoxCoder other) : this() { + switch (other.BoxCoderOneofCase) { + case BoxCoderOneofOneofCase.FasterRcnnBoxCoder: + FasterRcnnBoxCoder = other.FasterRcnnBoxCoder.Clone(); + break; + case BoxCoderOneofOneofCase.MeanStddevBoxCoder: + MeanStddevBoxCoder = other.MeanStddevBoxCoder.Clone(); + break; + case BoxCoderOneofOneofCase.SquareBoxCoder: + SquareBoxCoder = other.SquareBoxCoder.Clone(); + break; + case BoxCoderOneofOneofCase.KeypointBoxCoder: + KeypointBoxCoder = other.KeypointBoxCoder.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BoxCoder Clone() { + return new BoxCoder(this); + } + + /// Field number for the "faster_rcnn_box_coder" field. + public const int FasterRcnnBoxCoderFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnBoxCoder FasterRcnnBoxCoder { + get { return boxCoderOneofCase_ == BoxCoderOneofOneofCase.FasterRcnnBoxCoder ? (global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnBoxCoder) boxCoderOneof_ : null; } + set { + boxCoderOneof_ = value; + boxCoderOneofCase_ = value == null ? BoxCoderOneofOneofCase.None : BoxCoderOneofOneofCase.FasterRcnnBoxCoder; + } + } + + /// Field number for the "mean_stddev_box_coder" field. + public const int MeanStddevBoxCoderFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.MeanStddevBoxCoder MeanStddevBoxCoder { + get { return boxCoderOneofCase_ == BoxCoderOneofOneofCase.MeanStddevBoxCoder ? (global::Tensorflow.Models.ObjectDetection.Protos.MeanStddevBoxCoder) boxCoderOneof_ : null; } + set { + boxCoderOneof_ = value; + boxCoderOneofCase_ = value == null ? BoxCoderOneofOneofCase.None : BoxCoderOneofOneofCase.MeanStddevBoxCoder; + } + } + + /// Field number for the "square_box_coder" field. + public const int SquareBoxCoderFieldNumber = 3; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.SquareBoxCoder SquareBoxCoder { + get { return boxCoderOneofCase_ == BoxCoderOneofOneofCase.SquareBoxCoder ? (global::Tensorflow.Models.ObjectDetection.Protos.SquareBoxCoder) boxCoderOneof_ : null; } + set { + boxCoderOneof_ = value; + boxCoderOneofCase_ = value == null ? BoxCoderOneofOneofCase.None : BoxCoderOneofOneofCase.SquareBoxCoder; + } + } + + /// Field number for the "keypoint_box_coder" field. 
+ public const int KeypointBoxCoderFieldNumber = 4; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.KeypointBoxCoder KeypointBoxCoder { + get { return boxCoderOneofCase_ == BoxCoderOneofOneofCase.KeypointBoxCoder ? (global::Tensorflow.Models.ObjectDetection.Protos.KeypointBoxCoder) boxCoderOneof_ : null; } + set { + boxCoderOneof_ = value; + boxCoderOneofCase_ = value == null ? BoxCoderOneofOneofCase.None : BoxCoderOneofOneofCase.KeypointBoxCoder; + } + } + + private object boxCoderOneof_; + /// Enum of possible cases for the "box_coder_oneof" oneof. + public enum BoxCoderOneofOneofCase { + None = 0, + FasterRcnnBoxCoder = 1, + MeanStddevBoxCoder = 2, + SquareBoxCoder = 3, + KeypointBoxCoder = 4, + } + private BoxCoderOneofOneofCase boxCoderOneofCase_ = BoxCoderOneofOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BoxCoderOneofOneofCase BoxCoderOneofCase { + get { return boxCoderOneofCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearBoxCoderOneof() { + boxCoderOneofCase_ = BoxCoderOneofOneofCase.None; + boxCoderOneof_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as BoxCoder); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(BoxCoder other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(FasterRcnnBoxCoder, other.FasterRcnnBoxCoder)) return false; + if (!object.Equals(MeanStddevBoxCoder, other.MeanStddevBoxCoder)) return false; + if (!object.Equals(SquareBoxCoder, other.SquareBoxCoder)) return false; + if (!object.Equals(KeypointBoxCoder, other.KeypointBoxCoder)) return false; + if (BoxCoderOneofCase != other.BoxCoderOneofCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (boxCoderOneofCase_ == BoxCoderOneofOneofCase.FasterRcnnBoxCoder) hash ^= FasterRcnnBoxCoder.GetHashCode(); + if (boxCoderOneofCase_ == BoxCoderOneofOneofCase.MeanStddevBoxCoder) hash ^= MeanStddevBoxCoder.GetHashCode(); + if (boxCoderOneofCase_ == BoxCoderOneofOneofCase.SquareBoxCoder) hash ^= SquareBoxCoder.GetHashCode(); + if (boxCoderOneofCase_ == BoxCoderOneofOneofCase.KeypointBoxCoder) hash ^= KeypointBoxCoder.GetHashCode(); + hash ^= (int) boxCoderOneofCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (boxCoderOneofCase_ == BoxCoderOneofOneofCase.FasterRcnnBoxCoder) { + output.WriteRawTag(10); + output.WriteMessage(FasterRcnnBoxCoder); + } + if (boxCoderOneofCase_ == BoxCoderOneofOneofCase.MeanStddevBoxCoder) { + output.WriteRawTag(18); + output.WriteMessage(MeanStddevBoxCoder); + } + if (boxCoderOneofCase_ == BoxCoderOneofOneofCase.SquareBoxCoder) { + output.WriteRawTag(26); + output.WriteMessage(SquareBoxCoder); + } + if (boxCoderOneofCase_ == BoxCoderOneofOneofCase.KeypointBoxCoder) { + output.WriteRawTag(34); + output.WriteMessage(KeypointBoxCoder); + } + if (_unknownFields 
!= null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (boxCoderOneofCase_ == BoxCoderOneofOneofCase.FasterRcnnBoxCoder) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(FasterRcnnBoxCoder); + } + if (boxCoderOneofCase_ == BoxCoderOneofOneofCase.MeanStddevBoxCoder) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(MeanStddevBoxCoder); + } + if (boxCoderOneofCase_ == BoxCoderOneofOneofCase.SquareBoxCoder) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SquareBoxCoder); + } + if (boxCoderOneofCase_ == BoxCoderOneofOneofCase.KeypointBoxCoder) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(KeypointBoxCoder); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(BoxCoder other) { + if (other == null) { + return; + } + switch (other.BoxCoderOneofCase) { + case BoxCoderOneofOneofCase.FasterRcnnBoxCoder: + if (FasterRcnnBoxCoder == null) { + FasterRcnnBoxCoder = new global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnBoxCoder(); + } + FasterRcnnBoxCoder.MergeFrom(other.FasterRcnnBoxCoder); + break; + case BoxCoderOneofOneofCase.MeanStddevBoxCoder: + if (MeanStddevBoxCoder == null) { + MeanStddevBoxCoder = new global::Tensorflow.Models.ObjectDetection.Protos.MeanStddevBoxCoder(); + } + MeanStddevBoxCoder.MergeFrom(other.MeanStddevBoxCoder); + break; + case BoxCoderOneofOneofCase.SquareBoxCoder: + if (SquareBoxCoder == null) { + SquareBoxCoder = new global::Tensorflow.Models.ObjectDetection.Protos.SquareBoxCoder(); + } + SquareBoxCoder.MergeFrom(other.SquareBoxCoder); + break; + case BoxCoderOneofOneofCase.KeypointBoxCoder: + if (KeypointBoxCoder == null) { + KeypointBoxCoder = new global::Tensorflow.Models.ObjectDetection.Protos.KeypointBoxCoder(); + } + KeypointBoxCoder.MergeFrom(other.KeypointBoxCoder); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnBoxCoder subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnBoxCoder(); + if (boxCoderOneofCase_ == BoxCoderOneofOneofCase.FasterRcnnBoxCoder) { + subBuilder.MergeFrom(FasterRcnnBoxCoder); + } + input.ReadMessage(subBuilder); + FasterRcnnBoxCoder = subBuilder; + break; + } + case 18: { + global::Tensorflow.Models.ObjectDetection.Protos.MeanStddevBoxCoder subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.MeanStddevBoxCoder(); + if (boxCoderOneofCase_ == BoxCoderOneofOneofCase.MeanStddevBoxCoder) { + subBuilder.MergeFrom(MeanStddevBoxCoder); + } + input.ReadMessage(subBuilder); + MeanStddevBoxCoder = subBuilder; + break; + } + case 26: { + global::Tensorflow.Models.ObjectDetection.Protos.SquareBoxCoder subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.SquareBoxCoder(); + if (boxCoderOneofCase_ == BoxCoderOneofOneofCase.SquareBoxCoder) { + subBuilder.MergeFrom(SquareBoxCoder); + } + input.ReadMessage(subBuilder); + SquareBoxCoder = subBuilder; + break; + } + case 34: { + 
global::Tensorflow.Models.ObjectDetection.Protos.KeypointBoxCoder subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.KeypointBoxCoder(); + if (boxCoderOneofCase_ == BoxCoderOneofOneofCase.KeypointBoxCoder) { + subBuilder.MergeFrom(KeypointBoxCoder); + } + input.ReadMessage(subBuilder); + KeypointBoxCoder = subBuilder; + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/BoxPredictor.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/BoxPredictor.cs new file mode 100644 index 00000000..382c0d21 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/BoxPredictor.cs @@ -0,0 +1,2587 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: object_detection/protos/box_predictor.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/box_predictor.proto + public static partial class BoxPredictorReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/box_predictor.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static BoxPredictorReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CitvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9ib3hfcHJlZGljdG9yLnByb3Rv", + "EhdvYmplY3RfZGV0ZWN0aW9uLnByb3Rvcxopb2JqZWN0X2RldGVjdGlvbi9w", + "cm90b3MvaHlwZXJwYXJhbXMucHJvdG8ikAMKDEJveFByZWRpY3RvchJZChtj", + "b252b2x1dGlvbmFsX2JveF9wcmVkaWN0b3IYASABKAsyMi5vYmplY3RfZGV0", + "ZWN0aW9uLnByb3Rvcy5Db252b2x1dGlvbmFsQm94UHJlZGljdG9ySAASUAoX", + "bWFza19yY25uX2JveF9wcmVkaWN0b3IYAiABKAsyLS5vYmplY3RfZGV0ZWN0", + "aW9uLnByb3Rvcy5NYXNrUkNOTkJveFByZWRpY3RvckgAEkcKEnJmY25fYm94", + "X3ByZWRpY3RvchgDIAEoCzIpLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLlJm", + "Y25Cb3hQcmVkaWN0b3JIABJzCil3ZWlnaHRfc2hhcmVkX2NvbnZvbHV0aW9u", + "YWxfYm94X3ByZWRpY3RvchgEIAEoCzI+Lm9iamVjdF9kZXRlY3Rpb24ucHJv", + "dG9zLldlaWdodFNoYXJlZENvbnZvbHV0aW9uYWxCb3hQcmVkaWN0b3JIAEIV", + "ChNib3hfcHJlZGljdG9yX29uZW9mIoQEChlDb252b2x1dGlvbmFsQm94UHJl", + "ZGljdG9yEj4KEGNvbnZfaHlwZXJwYXJhbXMYASABKAsyJC5vYmplY3RfZGV0", + "ZWN0aW9uLnByb3Rvcy5IeXBlcnBhcmFtcxIRCgltaW5fZGVwdGgYAiABKAUS", + "EQoJbWF4X2RlcHRoGAMgASgFEiMKG251bV9sYXllcnNfYmVmb3JlX3ByZWRp", + "Y3RvchgEIAEoBRITCgt1c2VfZHJvcG91dBgFIAEoCBIgChhkcm9wb3V0X2tl", + "ZXBfcHJvYmFiaWxpdHkYBiABKAISEwoLa2VybmVsX3NpemUYByABKAUSFQoN", + "Ym94X2NvZGVfc2l6ZRgIIAEoBRIfChdhcHBseV9zaWdtb2lkX3RvX3Njb3Jl", + "cxgJIAEoCBIiChpjbGFzc19wcmVkaWN0aW9uX2JpYXNfaW5pdBgKIAEoAhIV", + "Cg11c2VfZGVwdGh3aXNlGAsgASgIEmoKGGJveF9lbmNvZGluZ3NfY2xpcF9y", + "YW5nZRgMIAEoCzJILm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLkNvbnZvbHV0", + "aW9uYWxCb3hQcmVkaWN0b3IuQm94RW5jb2RpbmdzQ2xpcFJhbmdlGjEKFUJv", + "eEVuY29kaW5nc0NsaXBSYW5nZRILCgNtaW4YASABKAISCwoDbWF4GAIgASgC", + "IpkFCiVXZWlnaHRTaGFyZWRDb252b2x1dGlvbmFsQm94UHJlZGljdG9yEj4K", + "EGNvbnZfaHlwZXJwYXJhbXMYASABKAsyJC5vYmplY3RfZGV0ZWN0aW9uLnBy", + "b3Rvcy5IeXBlcnBhcmFtcxIjChtudW1fbGF5ZXJzX2JlZm9yZV9wcmVkaWN0", + "b3IYBCABKAUSDQoFZGVwdGgYAiABKAUSEwoLa2VybmVsX3NpemUYByABKAUS", + "FQoNYm94X2NvZGVfc2l6ZRgIIAEoBRIiChpjbGFzc19wcmVkaWN0aW9uX2Jp", + 
"YXNfaW5pdBgKIAEoAhITCgt1c2VfZHJvcG91dBgLIAEoCBIgChhkcm9wb3V0", + "X2tlZXBfcHJvYmFiaWxpdHkYDCABKAISHgoWc2hhcmVfcHJlZGljdGlvbl90", + "b3dlchgNIAEoCBIVCg11c2VfZGVwdGh3aXNlGA4gASgIEmYKD3Njb3JlX2Nv", + "bnZlcnRlchgQIAEoDjJNLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLldlaWdo", + "dFNoYXJlZENvbnZvbHV0aW9uYWxCb3hQcmVkaWN0b3IuU2NvcmVDb252ZXJ0", + "ZXISdgoYYm94X2VuY29kaW5nc19jbGlwX3JhbmdlGBEgASgLMlQub2JqZWN0", + "X2RldGVjdGlvbi5wcm90b3MuV2VpZ2h0U2hhcmVkQ29udm9sdXRpb25hbEJv", + "eFByZWRpY3Rvci5Cb3hFbmNvZGluZ3NDbGlwUmFuZ2UaMQoVQm94RW5jb2Rp", + "bmdzQ2xpcFJhbmdlEgsKA21pbhgBIAEoAhILCgNtYXgYAiABKAIiKwoOU2Nv", + "cmVDb252ZXJ0ZXISDAoISURFTlRJVFkQABILCgdTSUdNT0lEEAEi/QMKFE1h", + "c2tSQ05OQm94UHJlZGljdG9yEjwKDmZjX2h5cGVycGFyYW1zGAEgASgLMiQu", + "b2JqZWN0X2RldGVjdGlvbi5wcm90b3MuSHlwZXJwYXJhbXMSEwoLdXNlX2Ry", + "b3BvdXQYAiABKAgSIAoYZHJvcG91dF9rZWVwX3Byb2JhYmlsaXR5GAMgASgC", + "EhUKDWJveF9jb2RlX3NpemUYBCABKAUSPgoQY29udl9oeXBlcnBhcmFtcxgF", + "IAEoCzIkLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLkh5cGVycGFyYW1zEh4K", + "FnByZWRpY3RfaW5zdGFuY2VfbWFza3MYBiABKAgSIgoabWFza19wcmVkaWN0", + "aW9uX2NvbnZfZGVwdGgYByABKAUSGQoRcHJlZGljdF9rZXlwb2ludHMYCCAB", + "KAgSEwoLbWFza19oZWlnaHQYCSABKAUSEgoKbWFza193aWR0aBgKIAEoBRIn", + "Ch9tYXNrX3ByZWRpY3Rpb25fbnVtX2NvbnZfbGF5ZXJzGAsgASgFEiAKGG1h", + "c2tzX2FyZV9jbGFzc19hZ25vc3RpYxgMIAEoCBIgChhzaGFyZV9ib3hfYWNy", + "b3NzX2NsYXNzZXMYDSABKAgSJAocY29udm9sdmVfdGhlbl91cHNhbXBsZV9t", + "YXNrcxgOIAEoCCLiAQoQUmZjbkJveFByZWRpY3RvchI+ChBjb252X2h5cGVy", + "cGFyYW1zGAEgASgLMiQub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuSHlwZXJw", + "YXJhbXMSHwoXbnVtX3NwYXRpYWxfYmluc19oZWlnaHQYAiABKAUSHgoWbnVt", + "X3NwYXRpYWxfYmluc193aWR0aBgDIAEoBRINCgVkZXB0aBgEIAEoBRIVCg1i", + "b3hfY29kZV9zaXplGAUgASgFEhMKC2Nyb3BfaGVpZ2h0GAYgASgFEhIKCmNy", + "b3Bfd2lkdGgYByABKAViBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.Models.ObjectDetection.Protos.HyperparamsReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.BoxPredictor), global::Tensorflow.Models.ObjectDetection.Protos.BoxPredictor.Parser, new[]{ "ConvolutionalBoxPredictor", "MaskRcnnBoxPredictor", "RfcnBoxPredictor", "WeightSharedConvolutionalBoxPredictor" }, new[]{ "BoxPredictorOneof" }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.ConvolutionalBoxPredictor), global::Tensorflow.Models.ObjectDetection.Protos.ConvolutionalBoxPredictor.Parser, new[]{ "ConvHyperparams", "MinDepth", "MaxDepth", "NumLayersBeforePredictor", "UseDropout", "DropoutKeepProbability", "KernelSize", "BoxCodeSize", "ApplySigmoidToScores", "ClassPredictionBiasInit", "UseDepthwise", "BoxEncodingsClipRange" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.ConvolutionalBoxPredictor.Types.BoxEncodingsClipRange), global::Tensorflow.Models.ObjectDetection.Protos.ConvolutionalBoxPredictor.Types.BoxEncodingsClipRange.Parser, new[]{ "Min", "Max" }, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor), global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor.Parser, new[]{ "ConvHyperparams", "NumLayersBeforePredictor", "Depth", "KernelSize", "BoxCodeSize", "ClassPredictionBiasInit", "UseDropout", "DropoutKeepProbability", "SharePredictionTower", "UseDepthwise", 
"ScoreConverter", "BoxEncodingsClipRange" }, null, new[]{ typeof(global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor.Types.ScoreConverter) }, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor.Types.BoxEncodingsClipRange), global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor.Types.BoxEncodingsClipRange.Parser, new[]{ "Min", "Max" }, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.MaskRCNNBoxPredictor), global::Tensorflow.Models.ObjectDetection.Protos.MaskRCNNBoxPredictor.Parser, new[]{ "FcHyperparams", "UseDropout", "DropoutKeepProbability", "BoxCodeSize", "ConvHyperparams", "PredictInstanceMasks", "MaskPredictionConvDepth", "PredictKeypoints", "MaskHeight", "MaskWidth", "MaskPredictionNumConvLayers", "MasksAreClassAgnostic", "ShareBoxAcrossClasses", "ConvolveThenUpsampleMasks" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RfcnBoxPredictor), global::Tensorflow.Models.ObjectDetection.Protos.RfcnBoxPredictor.Parser, new[]{ "ConvHyperparams", "NumSpatialBinsHeight", "NumSpatialBinsWidth", "Depth", "BoxCodeSize", "CropHeight", "CropWidth" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration proto for box predictor. See core/box_predictor.py for details. + /// + public sealed partial class BoxPredictor : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BoxPredictor()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.BoxPredictorReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BoxPredictor() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BoxPredictor(BoxPredictor other) : this() { + switch (other.BoxPredictorOneofCase) { + case BoxPredictorOneofOneofCase.ConvolutionalBoxPredictor: + ConvolutionalBoxPredictor = other.ConvolutionalBoxPredictor.Clone(); + break; + case BoxPredictorOneofOneofCase.MaskRcnnBoxPredictor: + MaskRcnnBoxPredictor = other.MaskRcnnBoxPredictor.Clone(); + break; + case BoxPredictorOneofOneofCase.RfcnBoxPredictor: + RfcnBoxPredictor = other.RfcnBoxPredictor.Clone(); + break; + case BoxPredictorOneofOneofCase.WeightSharedConvolutionalBoxPredictor: + WeightSharedConvolutionalBoxPredictor = other.WeightSharedConvolutionalBoxPredictor.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BoxPredictor Clone() { + return new BoxPredictor(this); + } + + /// Field number for the "convolutional_box_predictor" field. 
+ public const int ConvolutionalBoxPredictorFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ConvolutionalBoxPredictor ConvolutionalBoxPredictor { + get { return boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.ConvolutionalBoxPredictor ? (global::Tensorflow.Models.ObjectDetection.Protos.ConvolutionalBoxPredictor) boxPredictorOneof_ : null; } + set { + boxPredictorOneof_ = value; + boxPredictorOneofCase_ = value == null ? BoxPredictorOneofOneofCase.None : BoxPredictorOneofOneofCase.ConvolutionalBoxPredictor; + } + } + + /// Field number for the "mask_rcnn_box_predictor" field. + public const int MaskRcnnBoxPredictorFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.MaskRCNNBoxPredictor MaskRcnnBoxPredictor { + get { return boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.MaskRcnnBoxPredictor ? (global::Tensorflow.Models.ObjectDetection.Protos.MaskRCNNBoxPredictor) boxPredictorOneof_ : null; } + set { + boxPredictorOneof_ = value; + boxPredictorOneofCase_ = value == null ? BoxPredictorOneofOneofCase.None : BoxPredictorOneofOneofCase.MaskRcnnBoxPredictor; + } + } + + /// Field number for the "rfcn_box_predictor" field. + public const int RfcnBoxPredictorFieldNumber = 3; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RfcnBoxPredictor RfcnBoxPredictor { + get { return boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.RfcnBoxPredictor ? (global::Tensorflow.Models.ObjectDetection.Protos.RfcnBoxPredictor) boxPredictorOneof_ : null; } + set { + boxPredictorOneof_ = value; + boxPredictorOneofCase_ = value == null ? BoxPredictorOneofOneofCase.None : BoxPredictorOneofOneofCase.RfcnBoxPredictor; + } + } + + /// Field number for the "weight_shared_convolutional_box_predictor" field. + public const int WeightSharedConvolutionalBoxPredictorFieldNumber = 4; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor WeightSharedConvolutionalBoxPredictor { + get { return boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.WeightSharedConvolutionalBoxPredictor ? (global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor) boxPredictorOneof_ : null; } + set { + boxPredictorOneof_ = value; + boxPredictorOneofCase_ = value == null ? BoxPredictorOneofOneofCase.None : BoxPredictorOneofOneofCase.WeightSharedConvolutionalBoxPredictor; + } + } + + private object boxPredictorOneof_; + /// Enum of possible cases for the "box_predictor_oneof" oneof. 
+ public enum BoxPredictorOneofOneofCase { + None = 0, + ConvolutionalBoxPredictor = 1, + MaskRcnnBoxPredictor = 2, + RfcnBoxPredictor = 3, + WeightSharedConvolutionalBoxPredictor = 4, + } + private BoxPredictorOneofOneofCase boxPredictorOneofCase_ = BoxPredictorOneofOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BoxPredictorOneofOneofCase BoxPredictorOneofCase { + get { return boxPredictorOneofCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearBoxPredictorOneof() { + boxPredictorOneofCase_ = BoxPredictorOneofOneofCase.None; + boxPredictorOneof_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as BoxPredictor); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(BoxPredictor other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(ConvolutionalBoxPredictor, other.ConvolutionalBoxPredictor)) return false; + if (!object.Equals(MaskRcnnBoxPredictor, other.MaskRcnnBoxPredictor)) return false; + if (!object.Equals(RfcnBoxPredictor, other.RfcnBoxPredictor)) return false; + if (!object.Equals(WeightSharedConvolutionalBoxPredictor, other.WeightSharedConvolutionalBoxPredictor)) return false; + if (BoxPredictorOneofCase != other.BoxPredictorOneofCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.ConvolutionalBoxPredictor) hash ^= ConvolutionalBoxPredictor.GetHashCode(); + if (boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.MaskRcnnBoxPredictor) hash ^= MaskRcnnBoxPredictor.GetHashCode(); + if (boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.RfcnBoxPredictor) hash ^= RfcnBoxPredictor.GetHashCode(); + if (boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.WeightSharedConvolutionalBoxPredictor) hash ^= WeightSharedConvolutionalBoxPredictor.GetHashCode(); + hash ^= (int) boxPredictorOneofCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.ConvolutionalBoxPredictor) { + output.WriteRawTag(10); + output.WriteMessage(ConvolutionalBoxPredictor); + } + if (boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.MaskRcnnBoxPredictor) { + output.WriteRawTag(18); + output.WriteMessage(MaskRcnnBoxPredictor); + } + if (boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.RfcnBoxPredictor) { + output.WriteRawTag(26); + output.WriteMessage(RfcnBoxPredictor); + } + if (boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.WeightSharedConvolutionalBoxPredictor) { + output.WriteRawTag(34); + output.WriteMessage(WeightSharedConvolutionalBoxPredictor); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.ConvolutionalBoxPredictor) { + size += 1 + 
pb::CodedOutputStream.ComputeMessageSize(ConvolutionalBoxPredictor); + } + if (boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.MaskRcnnBoxPredictor) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(MaskRcnnBoxPredictor); + } + if (boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.RfcnBoxPredictor) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RfcnBoxPredictor); + } + if (boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.WeightSharedConvolutionalBoxPredictor) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(WeightSharedConvolutionalBoxPredictor); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(BoxPredictor other) { + if (other == null) { + return; + } + switch (other.BoxPredictorOneofCase) { + case BoxPredictorOneofOneofCase.ConvolutionalBoxPredictor: + if (ConvolutionalBoxPredictor == null) { + ConvolutionalBoxPredictor = new global::Tensorflow.Models.ObjectDetection.Protos.ConvolutionalBoxPredictor(); + } + ConvolutionalBoxPredictor.MergeFrom(other.ConvolutionalBoxPredictor); + break; + case BoxPredictorOneofOneofCase.MaskRcnnBoxPredictor: + if (MaskRcnnBoxPredictor == null) { + MaskRcnnBoxPredictor = new global::Tensorflow.Models.ObjectDetection.Protos.MaskRCNNBoxPredictor(); + } + MaskRcnnBoxPredictor.MergeFrom(other.MaskRcnnBoxPredictor); + break; + case BoxPredictorOneofOneofCase.RfcnBoxPredictor: + if (RfcnBoxPredictor == null) { + RfcnBoxPredictor = new global::Tensorflow.Models.ObjectDetection.Protos.RfcnBoxPredictor(); + } + RfcnBoxPredictor.MergeFrom(other.RfcnBoxPredictor); + break; + case BoxPredictorOneofOneofCase.WeightSharedConvolutionalBoxPredictor: + if (WeightSharedConvolutionalBoxPredictor == null) { + WeightSharedConvolutionalBoxPredictor = new global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor(); + } + WeightSharedConvolutionalBoxPredictor.MergeFrom(other.WeightSharedConvolutionalBoxPredictor); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.Models.ObjectDetection.Protos.ConvolutionalBoxPredictor subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.ConvolutionalBoxPredictor(); + if (boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.ConvolutionalBoxPredictor) { + subBuilder.MergeFrom(ConvolutionalBoxPredictor); + } + input.ReadMessage(subBuilder); + ConvolutionalBoxPredictor = subBuilder; + break; + } + case 18: { + global::Tensorflow.Models.ObjectDetection.Protos.MaskRCNNBoxPredictor subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.MaskRCNNBoxPredictor(); + if (boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.MaskRcnnBoxPredictor) { + subBuilder.MergeFrom(MaskRcnnBoxPredictor); + } + input.ReadMessage(subBuilder); + MaskRcnnBoxPredictor = subBuilder; + break; + } + case 26: { + global::Tensorflow.Models.ObjectDetection.Protos.RfcnBoxPredictor subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RfcnBoxPredictor(); + if (boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.RfcnBoxPredictor) { + subBuilder.MergeFrom(RfcnBoxPredictor); + } + 
input.ReadMessage(subBuilder); + RfcnBoxPredictor = subBuilder; + break; + } + case 34: { + global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor(); + if (boxPredictorOneofCase_ == BoxPredictorOneofOneofCase.WeightSharedConvolutionalBoxPredictor) { + subBuilder.MergeFrom(WeightSharedConvolutionalBoxPredictor); + } + input.ReadMessage(subBuilder); + WeightSharedConvolutionalBoxPredictor = subBuilder; + break; + } + } + } + } + + } + + /// + /// Configuration proto for Convolutional box predictor. + /// Next id: 13 + /// + public sealed partial class ConvolutionalBoxPredictor : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ConvolutionalBoxPredictor()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.BoxPredictorReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ConvolutionalBoxPredictor() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ConvolutionalBoxPredictor(ConvolutionalBoxPredictor other) : this() { + convHyperparams_ = other.convHyperparams_ != null ? other.convHyperparams_.Clone() : null; + minDepth_ = other.minDepth_; + maxDepth_ = other.maxDepth_; + numLayersBeforePredictor_ = other.numLayersBeforePredictor_; + useDropout_ = other.useDropout_; + dropoutKeepProbability_ = other.dropoutKeepProbability_; + kernelSize_ = other.kernelSize_; + boxCodeSize_ = other.boxCodeSize_; + applySigmoidToScores_ = other.applySigmoidToScores_; + classPredictionBiasInit_ = other.classPredictionBiasInit_; + useDepthwise_ = other.useDepthwise_; + boxEncodingsClipRange_ = other.boxEncodingsClipRange_ != null ? other.boxEncodingsClipRange_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ConvolutionalBoxPredictor Clone() { + return new ConvolutionalBoxPredictor(this); + } + + /// Field number for the "conv_hyperparams" field. + public const int ConvHyperparamsFieldNumber = 1; + private global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams convHyperparams_; + /// + /// Hyperparameters for convolution ops used in the box predictor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams ConvHyperparams { + get { return convHyperparams_; } + set { + convHyperparams_ = value; + } + } + + /// Field number for the "min_depth" field. + public const int MinDepthFieldNumber = 2; + private int minDepth_; + /// + /// Minimum feature depth prior to predicting box encodings and class + /// predictions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MinDepth { + get { return minDepth_; } + set { + minDepth_ = value; + } + } + + /// Field number for the "max_depth" field. 
+ public const int MaxDepthFieldNumber = 3; + private int maxDepth_; + /// + /// Maximum feature depth prior to predicting box encodings and class + /// predictions. If max_depth is set to 0, no additional feature map will be + /// inserted before location and class predictions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxDepth { + get { return maxDepth_; } + set { + maxDepth_ = value; + } + } + + /// Field number for the "num_layers_before_predictor" field. + public const int NumLayersBeforePredictorFieldNumber = 4; + private int numLayersBeforePredictor_; + /// + /// Number of the additional conv layers before the predictor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumLayersBeforePredictor { + get { return numLayersBeforePredictor_; } + set { + numLayersBeforePredictor_ = value; + } + } + + /// Field number for the "use_dropout" field. + public const int UseDropoutFieldNumber = 5; + private bool useDropout_; + /// + /// Whether to use dropout for class prediction. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseDropout { + get { return useDropout_; } + set { + useDropout_ = value; + } + } + + /// Field number for the "dropout_keep_probability" field. + public const int DropoutKeepProbabilityFieldNumber = 6; + private float dropoutKeepProbability_; + /// + /// Keep probability for dropout + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float DropoutKeepProbability { + get { return dropoutKeepProbability_; } + set { + dropoutKeepProbability_ = value; + } + } + + /// Field number for the "kernel_size" field. + public const int KernelSizeFieldNumber = 7; + private int kernelSize_; + /// + /// Size of final convolution kernel. If the spatial resolution of the feature + /// map is smaller than the kernel size, then the kernel size is set to + /// min(feature_width, feature_height). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int KernelSize { + get { return kernelSize_; } + set { + kernelSize_ = value; + } + } + + /// Field number for the "box_code_size" field. + public const int BoxCodeSizeFieldNumber = 8; + private int boxCodeSize_; + /// + /// Size of the encoding for boxes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int BoxCodeSize { + get { return boxCodeSize_; } + set { + boxCodeSize_ = value; + } + } + + /// Field number for the "apply_sigmoid_to_scores" field. + public const int ApplySigmoidToScoresFieldNumber = 9; + private bool applySigmoidToScores_; + /// + /// Whether to apply sigmoid to the output of class predictions. + /// TODO(jonathanhuang): Do we need this since we have a post processing + /// module.? + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ApplySigmoidToScores { + get { return applySigmoidToScores_; } + set { + applySigmoidToScores_ = value; + } + } + + /// Field number for the "class_prediction_bias_init" field. + public const int ClassPredictionBiasInitFieldNumber = 10; + private float classPredictionBiasInit_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float ClassPredictionBiasInit { + get { return classPredictionBiasInit_; } + set { + classPredictionBiasInit_ = value; + } + } + + /// Field number for the "use_depthwise" field. + public const int UseDepthwiseFieldNumber = 11; + private bool useDepthwise_; + /// + /// Whether to use depthwise separable convolution for box predictor layers. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseDepthwise { + get { return useDepthwise_; } + set { + useDepthwise_ = value; + } + } + + /// Field number for the "box_encodings_clip_range" field. + public const int BoxEncodingsClipRangeFieldNumber = 12; + private global::Tensorflow.Models.ObjectDetection.Protos.ConvolutionalBoxPredictor.Types.BoxEncodingsClipRange boxEncodingsClipRange_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ConvolutionalBoxPredictor.Types.BoxEncodingsClipRange BoxEncodingsClipRange { + get { return boxEncodingsClipRange_; } + set { + boxEncodingsClipRange_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ConvolutionalBoxPredictor); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ConvolutionalBoxPredictor other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(ConvHyperparams, other.ConvHyperparams)) return false; + if (MinDepth != other.MinDepth) return false; + if (MaxDepth != other.MaxDepth) return false; + if (NumLayersBeforePredictor != other.NumLayersBeforePredictor) return false; + if (UseDropout != other.UseDropout) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(DropoutKeepProbability, other.DropoutKeepProbability)) return false; + if (KernelSize != other.KernelSize) return false; + if (BoxCodeSize != other.BoxCodeSize) return false; + if (ApplySigmoidToScores != other.ApplySigmoidToScores) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(ClassPredictionBiasInit, other.ClassPredictionBiasInit)) return false; + if (UseDepthwise != other.UseDepthwise) return false; + if (!object.Equals(BoxEncodingsClipRange, other.BoxEncodingsClipRange)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (convHyperparams_ != null) hash ^= ConvHyperparams.GetHashCode(); + if (MinDepth != 0) hash ^= MinDepth.GetHashCode(); + if (MaxDepth != 0) hash ^= MaxDepth.GetHashCode(); + if (NumLayersBeforePredictor != 0) hash ^= NumLayersBeforePredictor.GetHashCode(); + if (UseDropout != false) hash ^= UseDropout.GetHashCode(); + if (DropoutKeepProbability != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(DropoutKeepProbability); + if (KernelSize != 0) hash ^= KernelSize.GetHashCode(); + if (BoxCodeSize != 0) hash ^= BoxCodeSize.GetHashCode(); + if (ApplySigmoidToScores != false) hash ^= ApplySigmoidToScores.GetHashCode(); + if (ClassPredictionBiasInit != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(ClassPredictionBiasInit); + if (UseDepthwise != false) hash ^= UseDepthwise.GetHashCode(); + if (boxEncodingsClipRange_ != null) hash ^= BoxEncodingsClipRange.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (convHyperparams_ != 
null) { + output.WriteRawTag(10); + output.WriteMessage(ConvHyperparams); + } + if (MinDepth != 0) { + output.WriteRawTag(16); + output.WriteInt32(MinDepth); + } + if (MaxDepth != 0) { + output.WriteRawTag(24); + output.WriteInt32(MaxDepth); + } + if (NumLayersBeforePredictor != 0) { + output.WriteRawTag(32); + output.WriteInt32(NumLayersBeforePredictor); + } + if (UseDropout != false) { + output.WriteRawTag(40); + output.WriteBool(UseDropout); + } + if (DropoutKeepProbability != 0F) { + output.WriteRawTag(53); + output.WriteFloat(DropoutKeepProbability); + } + if (KernelSize != 0) { + output.WriteRawTag(56); + output.WriteInt32(KernelSize); + } + if (BoxCodeSize != 0) { + output.WriteRawTag(64); + output.WriteInt32(BoxCodeSize); + } + if (ApplySigmoidToScores != false) { + output.WriteRawTag(72); + output.WriteBool(ApplySigmoidToScores); + } + if (ClassPredictionBiasInit != 0F) { + output.WriteRawTag(85); + output.WriteFloat(ClassPredictionBiasInit); + } + if (UseDepthwise != false) { + output.WriteRawTag(88); + output.WriteBool(UseDepthwise); + } + if (boxEncodingsClipRange_ != null) { + output.WriteRawTag(98); + output.WriteMessage(BoxEncodingsClipRange); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (convHyperparams_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ConvHyperparams); + } + if (MinDepth != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MinDepth); + } + if (MaxDepth != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxDepth); + } + if (NumLayersBeforePredictor != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumLayersBeforePredictor); + } + if (UseDropout != false) { + size += 1 + 1; + } + if (DropoutKeepProbability != 0F) { + size += 1 + 4; + } + if (KernelSize != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(KernelSize); + } + if (BoxCodeSize != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(BoxCodeSize); + } + if (ApplySigmoidToScores != false) { + size += 1 + 1; + } + if (ClassPredictionBiasInit != 0F) { + size += 1 + 4; + } + if (UseDepthwise != false) { + size += 1 + 1; + } + if (boxEncodingsClipRange_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(BoxEncodingsClipRange); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ConvolutionalBoxPredictor other) { + if (other == null) { + return; + } + if (other.convHyperparams_ != null) { + if (convHyperparams_ == null) { + convHyperparams_ = new global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams(); + } + ConvHyperparams.MergeFrom(other.ConvHyperparams); + } + if (other.MinDepth != 0) { + MinDepth = other.MinDepth; + } + if (other.MaxDepth != 0) { + MaxDepth = other.MaxDepth; + } + if (other.NumLayersBeforePredictor != 0) { + NumLayersBeforePredictor = other.NumLayersBeforePredictor; + } + if (other.UseDropout != false) { + UseDropout = other.UseDropout; + } + if (other.DropoutKeepProbability != 0F) { + DropoutKeepProbability = other.DropoutKeepProbability; + } + if (other.KernelSize != 0) { + KernelSize = other.KernelSize; + } + if (other.BoxCodeSize != 0) { + BoxCodeSize = other.BoxCodeSize; + } + if (other.ApplySigmoidToScores != false) { + ApplySigmoidToScores = other.ApplySigmoidToScores; + } + if (other.ClassPredictionBiasInit != 0F) { + 
ClassPredictionBiasInit = other.ClassPredictionBiasInit; + } + if (other.UseDepthwise != false) { + UseDepthwise = other.UseDepthwise; + } + if (other.boxEncodingsClipRange_ != null) { + if (boxEncodingsClipRange_ == null) { + boxEncodingsClipRange_ = new global::Tensorflow.Models.ObjectDetection.Protos.ConvolutionalBoxPredictor.Types.BoxEncodingsClipRange(); + } + BoxEncodingsClipRange.MergeFrom(other.BoxEncodingsClipRange); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (convHyperparams_ == null) { + convHyperparams_ = new global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams(); + } + input.ReadMessage(convHyperparams_); + break; + } + case 16: { + MinDepth = input.ReadInt32(); + break; + } + case 24: { + MaxDepth = input.ReadInt32(); + break; + } + case 32: { + NumLayersBeforePredictor = input.ReadInt32(); + break; + } + case 40: { + UseDropout = input.ReadBool(); + break; + } + case 53: { + DropoutKeepProbability = input.ReadFloat(); + break; + } + case 56: { + KernelSize = input.ReadInt32(); + break; + } + case 64: { + BoxCodeSize = input.ReadInt32(); + break; + } + case 72: { + ApplySigmoidToScores = input.ReadBool(); + break; + } + case 85: { + ClassPredictionBiasInit = input.ReadFloat(); + break; + } + case 88: { + UseDepthwise = input.ReadBool(); + break; + } + case 98: { + if (boxEncodingsClipRange_ == null) { + boxEncodingsClipRange_ = new global::Tensorflow.Models.ObjectDetection.Protos.ConvolutionalBoxPredictor.Types.BoxEncodingsClipRange(); + } + input.ReadMessage(boxEncodingsClipRange_); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the ConvolutionalBoxPredictor message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + /// + /// If specified, apply clipping to box encodings. + /// + public sealed partial class BoxEncodingsClipRange : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BoxEncodingsClipRange()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.ConvolutionalBoxPredictor.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BoxEncodingsClipRange() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BoxEncodingsClipRange(BoxEncodingsClipRange other) : this() { + min_ = other.min_; + max_ = other.max_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BoxEncodingsClipRange Clone() { + return new BoxEncodingsClipRange(this); + } + + /// Field number for the "min" field. 
+ public const int MinFieldNumber = 1; + private float min_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Min { + get { return min_; } + set { + min_ = value; + } + } + + /// Field number for the "max" field. + public const int MaxFieldNumber = 2; + private float max_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Max { + get { return max_; } + set { + max_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as BoxEncodingsClipRange); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(BoxEncodingsClipRange other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Min, other.Min)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Max, other.Max)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Min != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Min); + if (Max != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Max); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Min != 0F) { + output.WriteRawTag(13); + output.WriteFloat(Min); + } + if (Max != 0F) { + output.WriteRawTag(21); + output.WriteFloat(Max); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Min != 0F) { + size += 1 + 4; + } + if (Max != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(BoxEncodingsClipRange other) { + if (other == null) { + return; + } + if (other.Min != 0F) { + Min = other.Min; + } + if (other.Max != 0F) { + Max = other.Max; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + Min = input.ReadFloat(); + break; + } + case 21: { + Max = input.ReadFloat(); + break; + } + } + } + } + + } + + } + #endregion + + } + + /// + /// Configuration proto for weight shared convolutional box predictor. 
+ /// Next id: 18 + /// + public sealed partial class WeightSharedConvolutionalBoxPredictor : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WeightSharedConvolutionalBoxPredictor()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.BoxPredictorReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightSharedConvolutionalBoxPredictor() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightSharedConvolutionalBoxPredictor(WeightSharedConvolutionalBoxPredictor other) : this() { + convHyperparams_ = other.convHyperparams_ != null ? other.convHyperparams_.Clone() : null; + numLayersBeforePredictor_ = other.numLayersBeforePredictor_; + depth_ = other.depth_; + kernelSize_ = other.kernelSize_; + boxCodeSize_ = other.boxCodeSize_; + classPredictionBiasInit_ = other.classPredictionBiasInit_; + useDropout_ = other.useDropout_; + dropoutKeepProbability_ = other.dropoutKeepProbability_; + sharePredictionTower_ = other.sharePredictionTower_; + useDepthwise_ = other.useDepthwise_; + scoreConverter_ = other.scoreConverter_; + boxEncodingsClipRange_ = other.boxEncodingsClipRange_ != null ? other.boxEncodingsClipRange_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightSharedConvolutionalBoxPredictor Clone() { + return new WeightSharedConvolutionalBoxPredictor(this); + } + + /// Field number for the "conv_hyperparams" field. + public const int ConvHyperparamsFieldNumber = 1; + private global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams convHyperparams_; + /// + /// Hyperparameters for convolution ops used in the box predictor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams ConvHyperparams { + get { return convHyperparams_; } + set { + convHyperparams_ = value; + } + } + + /// Field number for the "num_layers_before_predictor" field. + public const int NumLayersBeforePredictorFieldNumber = 4; + private int numLayersBeforePredictor_; + /// + /// Number of the additional conv layers before the predictor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumLayersBeforePredictor { + get { return numLayersBeforePredictor_; } + set { + numLayersBeforePredictor_ = value; + } + } + + /// Field number for the "depth" field. + public const int DepthFieldNumber = 2; + private int depth_; + /// + /// Output depth for the convolution ops prior to predicting box encodings + /// and class predictions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Depth { + get { return depth_; } + set { + depth_ = value; + } + } + + /// Field number for the "kernel_size" field. + public const int KernelSizeFieldNumber = 7; + private int kernelSize_; + /// + /// Size of final convolution kernel. 
If the spatial resolution of the feature + /// map is smaller than the kernel size, then the kernel size is set to + /// min(feature_width, feature_height). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int KernelSize { + get { return kernelSize_; } + set { + kernelSize_ = value; + } + } + + /// Field number for the "box_code_size" field. + public const int BoxCodeSizeFieldNumber = 8; + private int boxCodeSize_; + /// + /// Size of the encoding for boxes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int BoxCodeSize { + get { return boxCodeSize_; } + set { + boxCodeSize_ = value; + } + } + + /// Field number for the "class_prediction_bias_init" field. + public const int ClassPredictionBiasInitFieldNumber = 10; + private float classPredictionBiasInit_; + /// + /// Bias initialization for class prediction. It has been show to stabilize + /// training where there are large number of negative boxes. See + /// https://arxiv.org/abs/1708.02002 for details. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float ClassPredictionBiasInit { + get { return classPredictionBiasInit_; } + set { + classPredictionBiasInit_ = value; + } + } + + /// Field number for the "use_dropout" field. + public const int UseDropoutFieldNumber = 11; + private bool useDropout_; + /// + /// Whether to use dropout for class prediction. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseDropout { + get { return useDropout_; } + set { + useDropout_ = value; + } + } + + /// Field number for the "dropout_keep_probability" field. + public const int DropoutKeepProbabilityFieldNumber = 12; + private float dropoutKeepProbability_; + /// + /// Keep probability for dropout. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float DropoutKeepProbability { + get { return dropoutKeepProbability_; } + set { + dropoutKeepProbability_ = value; + } + } + + /// Field number for the "share_prediction_tower" field. + public const int SharePredictionTowerFieldNumber = 13; + private bool sharePredictionTower_; + /// + /// Whether to share the multi-layer tower between box prediction and class + /// prediction heads. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool SharePredictionTower { + get { return sharePredictionTower_; } + set { + sharePredictionTower_ = value; + } + } + + /// Field number for the "use_depthwise" field. + public const int UseDepthwiseFieldNumber = 14; + private bool useDepthwise_; + /// + /// Whether to use depthwise separable convolution for box predictor layers. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseDepthwise { + get { return useDepthwise_; } + set { + useDepthwise_ = value; + } + } + + /// Field number for the "score_converter" field. + public const int ScoreConverterFieldNumber = 16; + private global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor.Types.ScoreConverter scoreConverter_ = 0; + /// + /// Callable elementwise score converter at inference time. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor.Types.ScoreConverter ScoreConverter { + get { return scoreConverter_; } + set { + scoreConverter_ = value; + } + } + + /// Field number for the "box_encodings_clip_range" field. 
+ public const int BoxEncodingsClipRangeFieldNumber = 17; + private global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor.Types.BoxEncodingsClipRange boxEncodingsClipRange_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor.Types.BoxEncodingsClipRange BoxEncodingsClipRange { + get { return boxEncodingsClipRange_; } + set { + boxEncodingsClipRange_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as WeightSharedConvolutionalBoxPredictor); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(WeightSharedConvolutionalBoxPredictor other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(ConvHyperparams, other.ConvHyperparams)) return false; + if (NumLayersBeforePredictor != other.NumLayersBeforePredictor) return false; + if (Depth != other.Depth) return false; + if (KernelSize != other.KernelSize) return false; + if (BoxCodeSize != other.BoxCodeSize) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(ClassPredictionBiasInit, other.ClassPredictionBiasInit)) return false; + if (UseDropout != other.UseDropout) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(DropoutKeepProbability, other.DropoutKeepProbability)) return false; + if (SharePredictionTower != other.SharePredictionTower) return false; + if (UseDepthwise != other.UseDepthwise) return false; + if (ScoreConverter != other.ScoreConverter) return false; + if (!object.Equals(BoxEncodingsClipRange, other.BoxEncodingsClipRange)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (convHyperparams_ != null) hash ^= ConvHyperparams.GetHashCode(); + if (NumLayersBeforePredictor != 0) hash ^= NumLayersBeforePredictor.GetHashCode(); + if (Depth != 0) hash ^= Depth.GetHashCode(); + if (KernelSize != 0) hash ^= KernelSize.GetHashCode(); + if (BoxCodeSize != 0) hash ^= BoxCodeSize.GetHashCode(); + if (ClassPredictionBiasInit != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(ClassPredictionBiasInit); + if (UseDropout != false) hash ^= UseDropout.GetHashCode(); + if (DropoutKeepProbability != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(DropoutKeepProbability); + if (SharePredictionTower != false) hash ^= SharePredictionTower.GetHashCode(); + if (UseDepthwise != false) hash ^= UseDepthwise.GetHashCode(); + if (ScoreConverter != 0) hash ^= ScoreConverter.GetHashCode(); + if (boxEncodingsClipRange_ != null) hash ^= BoxEncodingsClipRange.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (convHyperparams_ != null) { + output.WriteRawTag(10); + output.WriteMessage(ConvHyperparams); + } + if (Depth != 0) { + output.WriteRawTag(16); + output.WriteInt32(Depth); + } + if 
(NumLayersBeforePredictor != 0) { + output.WriteRawTag(32); + output.WriteInt32(NumLayersBeforePredictor); + } + if (KernelSize != 0) { + output.WriteRawTag(56); + output.WriteInt32(KernelSize); + } + if (BoxCodeSize != 0) { + output.WriteRawTag(64); + output.WriteInt32(BoxCodeSize); + } + if (ClassPredictionBiasInit != 0F) { + output.WriteRawTag(85); + output.WriteFloat(ClassPredictionBiasInit); + } + if (UseDropout != false) { + output.WriteRawTag(88); + output.WriteBool(UseDropout); + } + if (DropoutKeepProbability != 0F) { + output.WriteRawTag(101); + output.WriteFloat(DropoutKeepProbability); + } + if (SharePredictionTower != false) { + output.WriteRawTag(104); + output.WriteBool(SharePredictionTower); + } + if (UseDepthwise != false) { + output.WriteRawTag(112); + output.WriteBool(UseDepthwise); + } + if (ScoreConverter != 0) { + output.WriteRawTag(128, 1); + output.WriteEnum((int) ScoreConverter); + } + if (boxEncodingsClipRange_ != null) { + output.WriteRawTag(138, 1); + output.WriteMessage(BoxEncodingsClipRange); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (convHyperparams_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ConvHyperparams); + } + if (NumLayersBeforePredictor != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumLayersBeforePredictor); + } + if (Depth != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Depth); + } + if (KernelSize != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(KernelSize); + } + if (BoxCodeSize != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(BoxCodeSize); + } + if (ClassPredictionBiasInit != 0F) { + size += 1 + 4; + } + if (UseDropout != false) { + size += 1 + 1; + } + if (DropoutKeepProbability != 0F) { + size += 1 + 4; + } + if (SharePredictionTower != false) { + size += 1 + 1; + } + if (UseDepthwise != false) { + size += 1 + 1; + } + if (ScoreConverter != 0) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) ScoreConverter); + } + if (boxEncodingsClipRange_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(BoxEncodingsClipRange); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(WeightSharedConvolutionalBoxPredictor other) { + if (other == null) { + return; + } + if (other.convHyperparams_ != null) { + if (convHyperparams_ == null) { + convHyperparams_ = new global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams(); + } + ConvHyperparams.MergeFrom(other.ConvHyperparams); + } + if (other.NumLayersBeforePredictor != 0) { + NumLayersBeforePredictor = other.NumLayersBeforePredictor; + } + if (other.Depth != 0) { + Depth = other.Depth; + } + if (other.KernelSize != 0) { + KernelSize = other.KernelSize; + } + if (other.BoxCodeSize != 0) { + BoxCodeSize = other.BoxCodeSize; + } + if (other.ClassPredictionBiasInit != 0F) { + ClassPredictionBiasInit = other.ClassPredictionBiasInit; + } + if (other.UseDropout != false) { + UseDropout = other.UseDropout; + } + if (other.DropoutKeepProbability != 0F) { + DropoutKeepProbability = other.DropoutKeepProbability; + } + if (other.SharePredictionTower != false) { + SharePredictionTower = other.SharePredictionTower; + } + if (other.UseDepthwise != false) { + UseDepthwise = other.UseDepthwise; + } + if (other.ScoreConverter != 0) { + 
ScoreConverter = other.ScoreConverter; + } + if (other.boxEncodingsClipRange_ != null) { + if (boxEncodingsClipRange_ == null) { + boxEncodingsClipRange_ = new global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor.Types.BoxEncodingsClipRange(); + } + BoxEncodingsClipRange.MergeFrom(other.BoxEncodingsClipRange); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (convHyperparams_ == null) { + convHyperparams_ = new global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams(); + } + input.ReadMessage(convHyperparams_); + break; + } + case 16: { + Depth = input.ReadInt32(); + break; + } + case 32: { + NumLayersBeforePredictor = input.ReadInt32(); + break; + } + case 56: { + KernelSize = input.ReadInt32(); + break; + } + case 64: { + BoxCodeSize = input.ReadInt32(); + break; + } + case 85: { + ClassPredictionBiasInit = input.ReadFloat(); + break; + } + case 88: { + UseDropout = input.ReadBool(); + break; + } + case 101: { + DropoutKeepProbability = input.ReadFloat(); + break; + } + case 104: { + SharePredictionTower = input.ReadBool(); + break; + } + case 112: { + UseDepthwise = input.ReadBool(); + break; + } + case 128: { + scoreConverter_ = (global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor.Types.ScoreConverter) input.ReadEnum(); + break; + } + case 138: { + if (boxEncodingsClipRange_ == null) { + boxEncodingsClipRange_ = new global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor.Types.BoxEncodingsClipRange(); + } + input.ReadMessage(boxEncodingsClipRange_); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the WeightSharedConvolutionalBoxPredictor message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + /// + /// Enum to specify how to convert the detection scores at inference time. + /// + public enum ScoreConverter { + /// + /// Input scores equals output scores. + /// + [pbr::OriginalName("IDENTITY")] Identity = 0, + /// + /// Applies a sigmoid on input scores. + /// + [pbr::OriginalName("SIGMOID")] Sigmoid = 1, + } + + /// + /// If specified, apply clipping to box encodings. 
+ /// + public sealed partial class BoxEncodingsClipRange : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BoxEncodingsClipRange()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.WeightSharedConvolutionalBoxPredictor.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BoxEncodingsClipRange() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BoxEncodingsClipRange(BoxEncodingsClipRange other) : this() { + min_ = other.min_; + max_ = other.max_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BoxEncodingsClipRange Clone() { + return new BoxEncodingsClipRange(this); + } + + /// Field number for the "min" field. + public const int MinFieldNumber = 1; + private float min_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Min { + get { return min_; } + set { + min_ = value; + } + } + + /// Field number for the "max" field. + public const int MaxFieldNumber = 2; + private float max_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Max { + get { return max_; } + set { + max_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as BoxEncodingsClipRange); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(BoxEncodingsClipRange other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Min, other.Min)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Max, other.Max)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Min != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Min); + if (Max != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Max); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Min != 0F) { + output.WriteRawTag(13); + output.WriteFloat(Min); + } + if (Max != 0F) { + output.WriteRawTag(21); + output.WriteFloat(Max); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Min != 0F) { + size += 1 + 4; + } + if (Max != 0F) { + size += 1 + 4; + } + 
if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(BoxEncodingsClipRange other) { + if (other == null) { + return; + } + if (other.Min != 0F) { + Min = other.Min; + } + if (other.Max != 0F) { + Max = other.Max; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + Min = input.ReadFloat(); + break; + } + case 21: { + Max = input.ReadFloat(); + break; + } + } + } + } + + } + + } + #endregion + + } + + /// + /// TODO(alirezafathi): Refactor the proto file to be able to configure mask rcnn + /// head easily. + /// Next id: 15 + /// + public sealed partial class MaskRCNNBoxPredictor : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MaskRCNNBoxPredictor()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.BoxPredictorReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MaskRCNNBoxPredictor() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MaskRCNNBoxPredictor(MaskRCNNBoxPredictor other) : this() { + fcHyperparams_ = other.fcHyperparams_ != null ? other.fcHyperparams_.Clone() : null; + useDropout_ = other.useDropout_; + dropoutKeepProbability_ = other.dropoutKeepProbability_; + boxCodeSize_ = other.boxCodeSize_; + convHyperparams_ = other.convHyperparams_ != null ? other.convHyperparams_.Clone() : null; + predictInstanceMasks_ = other.predictInstanceMasks_; + maskPredictionConvDepth_ = other.maskPredictionConvDepth_; + predictKeypoints_ = other.predictKeypoints_; + maskHeight_ = other.maskHeight_; + maskWidth_ = other.maskWidth_; + maskPredictionNumConvLayers_ = other.maskPredictionNumConvLayers_; + masksAreClassAgnostic_ = other.masksAreClassAgnostic_; + shareBoxAcrossClasses_ = other.shareBoxAcrossClasses_; + convolveThenUpsampleMasks_ = other.convolveThenUpsampleMasks_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MaskRCNNBoxPredictor Clone() { + return new MaskRCNNBoxPredictor(this); + } + + /// Field number for the "fc_hyperparams" field. + public const int FcHyperparamsFieldNumber = 1; + private global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams fcHyperparams_; + /// + /// Hyperparameters for fully connected ops used in the box predictor. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams FcHyperparams { + get { return fcHyperparams_; } + set { + fcHyperparams_ = value; + } + } + + /// Field number for the "use_dropout" field. + public const int UseDropoutFieldNumber = 2; + private bool useDropout_; + /// + /// Whether to use dropout op prior to the both box and class predictions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseDropout { + get { return useDropout_; } + set { + useDropout_ = value; + } + } + + /// Field number for the "dropout_keep_probability" field. + public const int DropoutKeepProbabilityFieldNumber = 3; + private float dropoutKeepProbability_; + /// + /// Keep probability for dropout. This is only used if use_dropout is true. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float DropoutKeepProbability { + get { return dropoutKeepProbability_; } + set { + dropoutKeepProbability_ = value; + } + } + + /// Field number for the "box_code_size" field. + public const int BoxCodeSizeFieldNumber = 4; + private int boxCodeSize_; + /// + /// Size of the encoding for the boxes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int BoxCodeSize { + get { return boxCodeSize_; } + set { + boxCodeSize_ = value; + } + } + + /// Field number for the "conv_hyperparams" field. + public const int ConvHyperparamsFieldNumber = 5; + private global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams convHyperparams_; + /// + /// Hyperparameters for convolution ops used in the box predictor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams ConvHyperparams { + get { return convHyperparams_; } + set { + convHyperparams_ = value; + } + } + + /// Field number for the "predict_instance_masks" field. + public const int PredictInstanceMasksFieldNumber = 6; + private bool predictInstanceMasks_; + /// + /// Whether to predict instance masks inside detection boxes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool PredictInstanceMasks { + get { return predictInstanceMasks_; } + set { + predictInstanceMasks_ = value; + } + } + + /// Field number for the "mask_prediction_conv_depth" field. + public const int MaskPredictionConvDepthFieldNumber = 7; + private int maskPredictionConvDepth_; + /// + /// The depth for the first conv2d_transpose op applied to the + /// image_features in the mask prediction branch. If set to 0, the value + /// will be set automatically based on the number of channels in the image + /// features and the number of classes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaskPredictionConvDepth { + get { return maskPredictionConvDepth_; } + set { + maskPredictionConvDepth_ = value; + } + } + + /// Field number for the "predict_keypoints" field. + public const int PredictKeypointsFieldNumber = 8; + private bool predictKeypoints_; + /// + /// Whether to predict keypoints inside detection boxes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool PredictKeypoints { + get { return predictKeypoints_; } + set { + predictKeypoints_ = value; + } + } + + /// Field number for the "mask_height" field. + public const int MaskHeightFieldNumber = 9; + private int maskHeight_; + /// + /// The height and the width of the predicted mask. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaskHeight { + get { return maskHeight_; } + set { + maskHeight_ = value; + } + } + + /// Field number for the "mask_width" field. + public const int MaskWidthFieldNumber = 10; + private int maskWidth_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaskWidth { + get { return maskWidth_; } + set { + maskWidth_ = value; + } + } + + /// Field number for the "mask_prediction_num_conv_layers" field. + public const int MaskPredictionNumConvLayersFieldNumber = 11; + private int maskPredictionNumConvLayers_; + /// + /// The number of convolutions applied to image_features in the mask prediction + /// branch. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaskPredictionNumConvLayers { + get { return maskPredictionNumConvLayers_; } + set { + maskPredictionNumConvLayers_ = value; + } + } + + /// Field number for the "masks_are_class_agnostic" field. + public const int MasksAreClassAgnosticFieldNumber = 12; + private bool masksAreClassAgnostic_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool MasksAreClassAgnostic { + get { return masksAreClassAgnostic_; } + set { + masksAreClassAgnostic_ = value; + } + } + + /// Field number for the "share_box_across_classes" field. + public const int ShareBoxAcrossClassesFieldNumber = 13; + private bool shareBoxAcrossClasses_; + /// + /// Whether to use one box for all classes rather than a different box for each + /// class. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ShareBoxAcrossClasses { + get { return shareBoxAcrossClasses_; } + set { + shareBoxAcrossClasses_ = value; + } + } + + /// Field number for the "convolve_then_upsample_masks" field. + public const int ConvolveThenUpsampleMasksFieldNumber = 14; + private bool convolveThenUpsampleMasks_; + /// + /// Whether to apply convolutions on mask features before upsampling using + /// nearest neighbor resizing. + /// By default, mask features are resized to [`mask_height`, `mask_width`] + /// before applying convolutions and predicting masks. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ConvolveThenUpsampleMasks { + get { return convolveThenUpsampleMasks_; } + set { + convolveThenUpsampleMasks_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as MaskRCNNBoxPredictor); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(MaskRCNNBoxPredictor other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(FcHyperparams, other.FcHyperparams)) return false; + if (UseDropout != other.UseDropout) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(DropoutKeepProbability, other.DropoutKeepProbability)) return false; + if (BoxCodeSize != other.BoxCodeSize) return false; + if (!object.Equals(ConvHyperparams, other.ConvHyperparams)) return false; + if (PredictInstanceMasks != other.PredictInstanceMasks) return false; + if (MaskPredictionConvDepth != other.MaskPredictionConvDepth) return false; + if (PredictKeypoints != other.PredictKeypoints) return false; + if (MaskHeight != other.MaskHeight) return false; + if (MaskWidth != other.MaskWidth) return false; + if (MaskPredictionNumConvLayers != other.MaskPredictionNumConvLayers) return false; + if (MasksAreClassAgnostic != other.MasksAreClassAgnostic) return false; + if (ShareBoxAcrossClasses != other.ShareBoxAcrossClasses) return false; + if (ConvolveThenUpsampleMasks != other.ConvolveThenUpsampleMasks) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (fcHyperparams_ != null) hash ^= FcHyperparams.GetHashCode(); + if (UseDropout != false) hash ^= UseDropout.GetHashCode(); + if (DropoutKeepProbability != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(DropoutKeepProbability); + if (BoxCodeSize != 0) hash ^= BoxCodeSize.GetHashCode(); + if (convHyperparams_ != null) hash ^= ConvHyperparams.GetHashCode(); + if (PredictInstanceMasks != false) hash ^= PredictInstanceMasks.GetHashCode(); + if (MaskPredictionConvDepth != 0) hash ^= MaskPredictionConvDepth.GetHashCode(); + if (PredictKeypoints != false) hash ^= PredictKeypoints.GetHashCode(); + if (MaskHeight != 0) hash ^= MaskHeight.GetHashCode(); + if (MaskWidth != 0) hash ^= MaskWidth.GetHashCode(); + if (MaskPredictionNumConvLayers != 0) hash ^= MaskPredictionNumConvLayers.GetHashCode(); + if (MasksAreClassAgnostic != false) hash ^= MasksAreClassAgnostic.GetHashCode(); + if (ShareBoxAcrossClasses != false) hash ^= ShareBoxAcrossClasses.GetHashCode(); + if (ConvolveThenUpsampleMasks != false) hash ^= ConvolveThenUpsampleMasks.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (fcHyperparams_ != null) { + output.WriteRawTag(10); + output.WriteMessage(FcHyperparams); + } + if (UseDropout != false) { + output.WriteRawTag(16); + output.WriteBool(UseDropout); + } + if (DropoutKeepProbability != 0F) { + output.WriteRawTag(29); + 
output.WriteFloat(DropoutKeepProbability); + } + if (BoxCodeSize != 0) { + output.WriteRawTag(32); + output.WriteInt32(BoxCodeSize); + } + if (convHyperparams_ != null) { + output.WriteRawTag(42); + output.WriteMessage(ConvHyperparams); + } + if (PredictInstanceMasks != false) { + output.WriteRawTag(48); + output.WriteBool(PredictInstanceMasks); + } + if (MaskPredictionConvDepth != 0) { + output.WriteRawTag(56); + output.WriteInt32(MaskPredictionConvDepth); + } + if (PredictKeypoints != false) { + output.WriteRawTag(64); + output.WriteBool(PredictKeypoints); + } + if (MaskHeight != 0) { + output.WriteRawTag(72); + output.WriteInt32(MaskHeight); + } + if (MaskWidth != 0) { + output.WriteRawTag(80); + output.WriteInt32(MaskWidth); + } + if (MaskPredictionNumConvLayers != 0) { + output.WriteRawTag(88); + output.WriteInt32(MaskPredictionNumConvLayers); + } + if (MasksAreClassAgnostic != false) { + output.WriteRawTag(96); + output.WriteBool(MasksAreClassAgnostic); + } + if (ShareBoxAcrossClasses != false) { + output.WriteRawTag(104); + output.WriteBool(ShareBoxAcrossClasses); + } + if (ConvolveThenUpsampleMasks != false) { + output.WriteRawTag(112); + output.WriteBool(ConvolveThenUpsampleMasks); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (fcHyperparams_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(FcHyperparams); + } + if (UseDropout != false) { + size += 1 + 1; + } + if (DropoutKeepProbability != 0F) { + size += 1 + 4; + } + if (BoxCodeSize != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(BoxCodeSize); + } + if (convHyperparams_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ConvHyperparams); + } + if (PredictInstanceMasks != false) { + size += 1 + 1; + } + if (MaskPredictionConvDepth != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaskPredictionConvDepth); + } + if (PredictKeypoints != false) { + size += 1 + 1; + } + if (MaskHeight != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaskHeight); + } + if (MaskWidth != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaskWidth); + } + if (MaskPredictionNumConvLayers != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaskPredictionNumConvLayers); + } + if (MasksAreClassAgnostic != false) { + size += 1 + 1; + } + if (ShareBoxAcrossClasses != false) { + size += 1 + 1; + } + if (ConvolveThenUpsampleMasks != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(MaskRCNNBoxPredictor other) { + if (other == null) { + return; + } + if (other.fcHyperparams_ != null) { + if (fcHyperparams_ == null) { + fcHyperparams_ = new global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams(); + } + FcHyperparams.MergeFrom(other.FcHyperparams); + } + if (other.UseDropout != false) { + UseDropout = other.UseDropout; + } + if (other.DropoutKeepProbability != 0F) { + DropoutKeepProbability = other.DropoutKeepProbability; + } + if (other.BoxCodeSize != 0) { + BoxCodeSize = other.BoxCodeSize; + } + if (other.convHyperparams_ != null) { + if (convHyperparams_ == null) { + convHyperparams_ = new global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams(); + } + ConvHyperparams.MergeFrom(other.ConvHyperparams); + } + if (other.PredictInstanceMasks != false) { + 
PredictInstanceMasks = other.PredictInstanceMasks; + } + if (other.MaskPredictionConvDepth != 0) { + MaskPredictionConvDepth = other.MaskPredictionConvDepth; + } + if (other.PredictKeypoints != false) { + PredictKeypoints = other.PredictKeypoints; + } + if (other.MaskHeight != 0) { + MaskHeight = other.MaskHeight; + } + if (other.MaskWidth != 0) { + MaskWidth = other.MaskWidth; + } + if (other.MaskPredictionNumConvLayers != 0) { + MaskPredictionNumConvLayers = other.MaskPredictionNumConvLayers; + } + if (other.MasksAreClassAgnostic != false) { + MasksAreClassAgnostic = other.MasksAreClassAgnostic; + } + if (other.ShareBoxAcrossClasses != false) { + ShareBoxAcrossClasses = other.ShareBoxAcrossClasses; + } + if (other.ConvolveThenUpsampleMasks != false) { + ConvolveThenUpsampleMasks = other.ConvolveThenUpsampleMasks; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (fcHyperparams_ == null) { + fcHyperparams_ = new global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams(); + } + input.ReadMessage(fcHyperparams_); + break; + } + case 16: { + UseDropout = input.ReadBool(); + break; + } + case 29: { + DropoutKeepProbability = input.ReadFloat(); + break; + } + case 32: { + BoxCodeSize = input.ReadInt32(); + break; + } + case 42: { + if (convHyperparams_ == null) { + convHyperparams_ = new global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams(); + } + input.ReadMessage(convHyperparams_); + break; + } + case 48: { + PredictInstanceMasks = input.ReadBool(); + break; + } + case 56: { + MaskPredictionConvDepth = input.ReadInt32(); + break; + } + case 64: { + PredictKeypoints = input.ReadBool(); + break; + } + case 72: { + MaskHeight = input.ReadInt32(); + break; + } + case 80: { + MaskWidth = input.ReadInt32(); + break; + } + case 88: { + MaskPredictionNumConvLayers = input.ReadInt32(); + break; + } + case 96: { + MasksAreClassAgnostic = input.ReadBool(); + break; + } + case 104: { + ShareBoxAcrossClasses = input.ReadBool(); + break; + } + case 112: { + ConvolveThenUpsampleMasks = input.ReadBool(); + break; + } + } + } + } + + } + + public sealed partial class RfcnBoxPredictor : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RfcnBoxPredictor()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.BoxPredictorReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RfcnBoxPredictor() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RfcnBoxPredictor(RfcnBoxPredictor other) : this() { + convHyperparams_ = other.convHyperparams_ != null ? 
other.convHyperparams_.Clone() : null; + numSpatialBinsHeight_ = other.numSpatialBinsHeight_; + numSpatialBinsWidth_ = other.numSpatialBinsWidth_; + depth_ = other.depth_; + boxCodeSize_ = other.boxCodeSize_; + cropHeight_ = other.cropHeight_; + cropWidth_ = other.cropWidth_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RfcnBoxPredictor Clone() { + return new RfcnBoxPredictor(this); + } + + /// Field number for the "conv_hyperparams" field. + public const int ConvHyperparamsFieldNumber = 1; + private global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams convHyperparams_; + /// + /// Hyperparameters for convolution ops used in the box predictor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams ConvHyperparams { + get { return convHyperparams_; } + set { + convHyperparams_ = value; + } + } + + /// Field number for the "num_spatial_bins_height" field. + public const int NumSpatialBinsHeightFieldNumber = 2; + private int numSpatialBinsHeight_; + /// + /// Bin sizes for RFCN crops. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumSpatialBinsHeight { + get { return numSpatialBinsHeight_; } + set { + numSpatialBinsHeight_ = value; + } + } + + /// Field number for the "num_spatial_bins_width" field. + public const int NumSpatialBinsWidthFieldNumber = 3; + private int numSpatialBinsWidth_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumSpatialBinsWidth { + get { return numSpatialBinsWidth_; } + set { + numSpatialBinsWidth_ = value; + } + } + + /// Field number for the "depth" field. + public const int DepthFieldNumber = 4; + private int depth_; + /// + /// Target depth to reduce the input image features to. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Depth { + get { return depth_; } + set { + depth_ = value; + } + } + + /// Field number for the "box_code_size" field. + public const int BoxCodeSizeFieldNumber = 5; + private int boxCodeSize_; + /// + /// Size of the encoding for the boxes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int BoxCodeSize { + get { return boxCodeSize_; } + set { + boxCodeSize_ = value; + } + } + + /// Field number for the "crop_height" field. + public const int CropHeightFieldNumber = 6; + private int cropHeight_; + /// + /// Size to resize the rfcn crops to. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CropHeight { + get { return cropHeight_; } + set { + cropHeight_ = value; + } + } + + /// Field number for the "crop_width" field. 
+ public const int CropWidthFieldNumber = 7; + private int cropWidth_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CropWidth { + get { return cropWidth_; } + set { + cropWidth_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RfcnBoxPredictor); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RfcnBoxPredictor other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(ConvHyperparams, other.ConvHyperparams)) return false; + if (NumSpatialBinsHeight != other.NumSpatialBinsHeight) return false; + if (NumSpatialBinsWidth != other.NumSpatialBinsWidth) return false; + if (Depth != other.Depth) return false; + if (BoxCodeSize != other.BoxCodeSize) return false; + if (CropHeight != other.CropHeight) return false; + if (CropWidth != other.CropWidth) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (convHyperparams_ != null) hash ^= ConvHyperparams.GetHashCode(); + if (NumSpatialBinsHeight != 0) hash ^= NumSpatialBinsHeight.GetHashCode(); + if (NumSpatialBinsWidth != 0) hash ^= NumSpatialBinsWidth.GetHashCode(); + if (Depth != 0) hash ^= Depth.GetHashCode(); + if (BoxCodeSize != 0) hash ^= BoxCodeSize.GetHashCode(); + if (CropHeight != 0) hash ^= CropHeight.GetHashCode(); + if (CropWidth != 0) hash ^= CropWidth.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (convHyperparams_ != null) { + output.WriteRawTag(10); + output.WriteMessage(ConvHyperparams); + } + if (NumSpatialBinsHeight != 0) { + output.WriteRawTag(16); + output.WriteInt32(NumSpatialBinsHeight); + } + if (NumSpatialBinsWidth != 0) { + output.WriteRawTag(24); + output.WriteInt32(NumSpatialBinsWidth); + } + if (Depth != 0) { + output.WriteRawTag(32); + output.WriteInt32(Depth); + } + if (BoxCodeSize != 0) { + output.WriteRawTag(40); + output.WriteInt32(BoxCodeSize); + } + if (CropHeight != 0) { + output.WriteRawTag(48); + output.WriteInt32(CropHeight); + } + if (CropWidth != 0) { + output.WriteRawTag(56); + output.WriteInt32(CropWidth); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (convHyperparams_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ConvHyperparams); + } + if (NumSpatialBinsHeight != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumSpatialBinsHeight); + } + if (NumSpatialBinsWidth != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumSpatialBinsWidth); + } + if (Depth != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Depth); + } + if (BoxCodeSize != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(BoxCodeSize); + } + if (CropHeight != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(CropHeight); + } + if (CropWidth != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(CropWidth); + } 
+ if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RfcnBoxPredictor other) { + if (other == null) { + return; + } + if (other.convHyperparams_ != null) { + if (convHyperparams_ == null) { + convHyperparams_ = new global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams(); + } + ConvHyperparams.MergeFrom(other.ConvHyperparams); + } + if (other.NumSpatialBinsHeight != 0) { + NumSpatialBinsHeight = other.NumSpatialBinsHeight; + } + if (other.NumSpatialBinsWidth != 0) { + NumSpatialBinsWidth = other.NumSpatialBinsWidth; + } + if (other.Depth != 0) { + Depth = other.Depth; + } + if (other.BoxCodeSize != 0) { + BoxCodeSize = other.BoxCodeSize; + } + if (other.CropHeight != 0) { + CropHeight = other.CropHeight; + } + if (other.CropWidth != 0) { + CropWidth = other.CropWidth; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (convHyperparams_ == null) { + convHyperparams_ = new global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams(); + } + input.ReadMessage(convHyperparams_); + break; + } + case 16: { + NumSpatialBinsHeight = input.ReadInt32(); + break; + } + case 24: { + NumSpatialBinsWidth = input.ReadInt32(); + break; + } + case 32: { + Depth = input.ReadInt32(); + break; + } + case 40: { + BoxCodeSize = input.ReadInt32(); + break; + } + case 48: { + CropHeight = input.ReadInt32(); + break; + } + case 56: { + CropWidth = input.ReadInt32(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/Calibration.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/Calibration.cs new file mode 100644 index 00000000..7bd62280 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/Calibration.cs @@ -0,0 +1,1413 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: object_detection/protos/calibration.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/calibration.proto + public static partial class CalibrationReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/calibration.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static CalibrationReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CilvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9jYWxpYnJhdGlvbi5wcm90bxIX", + "b2JqZWN0X2RldGVjdGlvbi5wcm90b3MigQMKEUNhbGlicmF0aW9uQ29uZmln", + "ElAKFmZ1bmN0aW9uX2FwcHJveGltYXRpb24YASABKAsyLi5vYmplY3RfZGV0", + "ZWN0aW9uLnByb3Rvcy5GdW5jdGlvbkFwcHJveGltYXRpb25IABJiCiBjbGFz", + "c19pZF9mdW5jdGlvbl9hcHByb3hpbWF0aW9ucxgCIAEoCzI2Lm9iamVjdF9k", + "ZXRlY3Rpb24ucHJvdG9zLkNsYXNzSWRGdW5jdGlvbkFwcHJveGltYXRpb25z", + "SAASSgoTc2lnbW9pZF9jYWxpYnJhdGlvbhgDIAEoCzIrLm9iamVjdF9kZXRl", + "Y3Rpb24ucHJvdG9zLlNpZ21vaWRDYWxpYnJhdGlvbkgAElwKHWNsYXNzX2lk", + "X3NpZ21vaWRfY2FsaWJyYXRpb25zGAQgASgLMjMub2JqZWN0X2RldGVjdGlv", + "bi5wcm90b3MuQ2xhc3NJZFNpZ21vaWRDYWxpYnJhdGlvbnNIAEIMCgpjYWxp", + "YnJhdG9yIkwKFUZ1bmN0aW9uQXBwcm94aW1hdGlvbhIzCgl4X3lfcGFpcnMY", + "ASABKAsyIC5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5YWVBhaXJzIukBCh1D", + "bGFzc0lkRnVuY3Rpb25BcHByb3hpbWF0aW9ucxJsChVjbGFzc19pZF94eV9w", + "YWlyc19tYXAYASADKAsyTS5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5DbGFz", + "c0lkRnVuY3Rpb25BcHByb3hpbWF0aW9ucy5DbGFzc0lkWHlQYWlyc01hcEVu", + "dHJ5GloKFkNsYXNzSWRYeVBhaXJzTWFwRW50cnkSCwoDa2V5GAEgASgFEi8K", + "BXZhbHVlGAIgASgLMiAub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuWFlQYWly", + "czoCOAEiXAoSU2lnbW9pZENhbGlicmF0aW9uEkYKEnNpZ21vaWRfcGFyYW1l", + "dGVycxgBIAEoCzIqLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLlNpZ21vaWRQ", + "YXJhbWV0ZXJzIosCChpDbGFzc0lkU2lnbW9pZENhbGlicmF0aW9ucxJ9Ch9j", + "bGFzc19pZF9zaWdtb2lkX3BhcmFtZXRlcnNfbWFwGAEgAygLMlQub2JqZWN0", + "X2RldGVjdGlvbi5wcm90b3MuQ2xhc3NJZFNpZ21vaWRDYWxpYnJhdGlvbnMu", + "Q2xhc3NJZFNpZ21vaWRQYXJhbWV0ZXJzTWFwRW50cnkabgogQ2xhc3NJZFNp", + "Z21vaWRQYXJhbWV0ZXJzTWFwRW50cnkSCwoDa2V5GAEgASgFEjkKBXZhbHVl", + "GAIgASgLMioub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuU2lnbW9pZFBhcmFt", + "ZXRlcnM6AjgBIqsBCgdYWVBhaXJzEjkKCHhfeV9wYWlyGAEgAygLMicub2Jq", + "ZWN0X2RldGVjdGlvbi5wcm90b3MuWFlQYWlycy5YWVBhaXISRQoSdHJhaW5p", + "bmdfZGF0YV90eXBlGAIgASgOMikub2JqZWN0X2RldGVjdGlvbi5wcm90b3Mu", + "VHJhaW5pbmdEYXRhVHlwZRoeCgZYWVBhaXISCQoBeBgBIAEoAhIJCgF5GAIg", + "ASgCIikKEVNpZ21vaWRQYXJhbWV0ZXJzEgkKAWEYASABKAISCQoBYhgCIAEo", + "AipOChBUcmFpbmluZ0RhdGFUeXBlEhUKEURBVEFfVFlQRV9VTktOT1dOEAAS", + "DwoLQUxMX0NMQVNTRVMQARISCg5DTEFTU19TUEVDSUZJQxACYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Tensorflow.Models.ObjectDetection.Protos.TrainingDataType), }, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.CalibrationConfig), global::Tensorflow.Models.ObjectDetection.Protos.CalibrationConfig.Parser, new[]{ "FunctionApproximation", "ClassIdFunctionApproximations", "SigmoidCalibration", "ClassIdSigmoidCalibrations" }, new[]{ 
"Calibrator" }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.FunctionApproximation), global::Tensorflow.Models.ObjectDetection.Protos.FunctionApproximation.Parser, new[]{ "XYPairs" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.ClassIdFunctionApproximations), global::Tensorflow.Models.ObjectDetection.Protos.ClassIdFunctionApproximations.Parser, new[]{ "ClassIdXyPairsMap" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.SigmoidCalibration), global::Tensorflow.Models.ObjectDetection.Protos.SigmoidCalibration.Parser, new[]{ "SigmoidParameters" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.ClassIdSigmoidCalibrations), global::Tensorflow.Models.ObjectDetection.Protos.ClassIdSigmoidCalibrations.Parser, new[]{ "ClassIdSigmoidParametersMap" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.XYPairs), global::Tensorflow.Models.ObjectDetection.Protos.XYPairs.Parser, new[]{ "XYPair", "TrainingDataType" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.XYPairs.Types.XYPair), global::Tensorflow.Models.ObjectDetection.Protos.XYPairs.Types.XYPair.Parser, new[]{ "X", "Y" }, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.SigmoidParameters), global::Tensorflow.Models.ObjectDetection.Protos.SigmoidParameters.Parser, new[]{ "A", "B" }, null, null, null) + })); + } + #endregion + + } + #region Enums + /// + /// Description of data used to fit the calibration model. CLASS_SPECIFIC + /// indicates that the calibration parameters are derived from detections + /// pertaining to a single class. ALL_CLASSES indicates that parameters were + /// obtained by fitting a model on detections from all classes (including the + /// background class). + /// + public enum TrainingDataType { + [pbr::OriginalName("DATA_TYPE_UNKNOWN")] DataTypeUnknown = 0, + [pbr::OriginalName("ALL_CLASSES")] AllClasses = 1, + [pbr::OriginalName("CLASS_SPECIFIC")] ClassSpecific = 2, + } + + #endregion + + #region Messages + /// + /// Message wrapper for various calibration configurations. 
+ /// + public sealed partial class CalibrationConfig : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CalibrationConfig()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.CalibrationReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CalibrationConfig() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CalibrationConfig(CalibrationConfig other) : this() { + switch (other.CalibratorCase) { + case CalibratorOneofCase.FunctionApproximation: + FunctionApproximation = other.FunctionApproximation.Clone(); + break; + case CalibratorOneofCase.ClassIdFunctionApproximations: + ClassIdFunctionApproximations = other.ClassIdFunctionApproximations.Clone(); + break; + case CalibratorOneofCase.SigmoidCalibration: + SigmoidCalibration = other.SigmoidCalibration.Clone(); + break; + case CalibratorOneofCase.ClassIdSigmoidCalibrations: + ClassIdSigmoidCalibrations = other.ClassIdSigmoidCalibrations.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CalibrationConfig Clone() { + return new CalibrationConfig(this); + } + + /// Field number for the "function_approximation" field. + public const int FunctionApproximationFieldNumber = 1; + /// + /// Class-agnostic calibration via linear interpolation (usually output from + /// isotonic regression). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.FunctionApproximation FunctionApproximation { + get { return calibratorCase_ == CalibratorOneofCase.FunctionApproximation ? (global::Tensorflow.Models.ObjectDetection.Protos.FunctionApproximation) calibrator_ : null; } + set { + calibrator_ = value; + calibratorCase_ = value == null ? CalibratorOneofCase.None : CalibratorOneofCase.FunctionApproximation; + } + } + + /// Field number for the "class_id_function_approximations" field. + public const int ClassIdFunctionApproximationsFieldNumber = 2; + /// + /// Per-class calibration via linear interpolation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ClassIdFunctionApproximations ClassIdFunctionApproximations { + get { return calibratorCase_ == CalibratorOneofCase.ClassIdFunctionApproximations ? (global::Tensorflow.Models.ObjectDetection.Protos.ClassIdFunctionApproximations) calibrator_ : null; } + set { + calibrator_ = value; + calibratorCase_ = value == null ? CalibratorOneofCase.None : CalibratorOneofCase.ClassIdFunctionApproximations; + } + } + + /// Field number for the "sigmoid_calibration" field. + public const int SigmoidCalibrationFieldNumber = 3; + /// + /// Class-agnostic sigmoid calibration. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.SigmoidCalibration SigmoidCalibration { + get { return calibratorCase_ == CalibratorOneofCase.SigmoidCalibration ? (global::Tensorflow.Models.ObjectDetection.Protos.SigmoidCalibration) calibrator_ : null; } + set { + calibrator_ = value; + calibratorCase_ = value == null ? CalibratorOneofCase.None : CalibratorOneofCase.SigmoidCalibration; + } + } + + /// Field number for the "class_id_sigmoid_calibrations" field. + public const int ClassIdSigmoidCalibrationsFieldNumber = 4; + /// + /// Per-class sigmoid calibration. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ClassIdSigmoidCalibrations ClassIdSigmoidCalibrations { + get { return calibratorCase_ == CalibratorOneofCase.ClassIdSigmoidCalibrations ? (global::Tensorflow.Models.ObjectDetection.Protos.ClassIdSigmoidCalibrations) calibrator_ : null; } + set { + calibrator_ = value; + calibratorCase_ = value == null ? CalibratorOneofCase.None : CalibratorOneofCase.ClassIdSigmoidCalibrations; + } + } + + private object calibrator_; + /// Enum of possible cases for the "calibrator" oneof. + public enum CalibratorOneofCase { + None = 0, + FunctionApproximation = 1, + ClassIdFunctionApproximations = 2, + SigmoidCalibration = 3, + ClassIdSigmoidCalibrations = 4, + } + private CalibratorOneofCase calibratorCase_ = CalibratorOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CalibratorOneofCase CalibratorCase { + get { return calibratorCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearCalibrator() { + calibratorCase_ = CalibratorOneofCase.None; + calibrator_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as CalibrationConfig); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(CalibrationConfig other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(FunctionApproximation, other.FunctionApproximation)) return false; + if (!object.Equals(ClassIdFunctionApproximations, other.ClassIdFunctionApproximations)) return false; + if (!object.Equals(SigmoidCalibration, other.SigmoidCalibration)) return false; + if (!object.Equals(ClassIdSigmoidCalibrations, other.ClassIdSigmoidCalibrations)) return false; + if (CalibratorCase != other.CalibratorCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (calibratorCase_ == CalibratorOneofCase.FunctionApproximation) hash ^= FunctionApproximation.GetHashCode(); + if (calibratorCase_ == CalibratorOneofCase.ClassIdFunctionApproximations) hash ^= ClassIdFunctionApproximations.GetHashCode(); + if (calibratorCase_ == CalibratorOneofCase.SigmoidCalibration) hash ^= SigmoidCalibration.GetHashCode(); + if (calibratorCase_ == CalibratorOneofCase.ClassIdSigmoidCalibrations) hash ^= ClassIdSigmoidCalibrations.GetHashCode(); + hash ^= (int) calibratorCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); 
+ } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (calibratorCase_ == CalibratorOneofCase.FunctionApproximation) { + output.WriteRawTag(10); + output.WriteMessage(FunctionApproximation); + } + if (calibratorCase_ == CalibratorOneofCase.ClassIdFunctionApproximations) { + output.WriteRawTag(18); + output.WriteMessage(ClassIdFunctionApproximations); + } + if (calibratorCase_ == CalibratorOneofCase.SigmoidCalibration) { + output.WriteRawTag(26); + output.WriteMessage(SigmoidCalibration); + } + if (calibratorCase_ == CalibratorOneofCase.ClassIdSigmoidCalibrations) { + output.WriteRawTag(34); + output.WriteMessage(ClassIdSigmoidCalibrations); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (calibratorCase_ == CalibratorOneofCase.FunctionApproximation) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(FunctionApproximation); + } + if (calibratorCase_ == CalibratorOneofCase.ClassIdFunctionApproximations) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ClassIdFunctionApproximations); + } + if (calibratorCase_ == CalibratorOneofCase.SigmoidCalibration) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SigmoidCalibration); + } + if (calibratorCase_ == CalibratorOneofCase.ClassIdSigmoidCalibrations) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ClassIdSigmoidCalibrations); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(CalibrationConfig other) { + if (other == null) { + return; + } + switch (other.CalibratorCase) { + case CalibratorOneofCase.FunctionApproximation: + if (FunctionApproximation == null) { + FunctionApproximation = new global::Tensorflow.Models.ObjectDetection.Protos.FunctionApproximation(); + } + FunctionApproximation.MergeFrom(other.FunctionApproximation); + break; + case CalibratorOneofCase.ClassIdFunctionApproximations: + if (ClassIdFunctionApproximations == null) { + ClassIdFunctionApproximations = new global::Tensorflow.Models.ObjectDetection.Protos.ClassIdFunctionApproximations(); + } + ClassIdFunctionApproximations.MergeFrom(other.ClassIdFunctionApproximations); + break; + case CalibratorOneofCase.SigmoidCalibration: + if (SigmoidCalibration == null) { + SigmoidCalibration = new global::Tensorflow.Models.ObjectDetection.Protos.SigmoidCalibration(); + } + SigmoidCalibration.MergeFrom(other.SigmoidCalibration); + break; + case CalibratorOneofCase.ClassIdSigmoidCalibrations: + if (ClassIdSigmoidCalibrations == null) { + ClassIdSigmoidCalibrations = new global::Tensorflow.Models.ObjectDetection.Protos.ClassIdSigmoidCalibrations(); + } + ClassIdSigmoidCalibrations.MergeFrom(other.ClassIdSigmoidCalibrations); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.Models.ObjectDetection.Protos.FunctionApproximation subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.FunctionApproximation(); + if (calibratorCase_ == 
CalibratorOneofCase.FunctionApproximation) { + subBuilder.MergeFrom(FunctionApproximation); + } + input.ReadMessage(subBuilder); + FunctionApproximation = subBuilder; + break; + } + case 18: { + global::Tensorflow.Models.ObjectDetection.Protos.ClassIdFunctionApproximations subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.ClassIdFunctionApproximations(); + if (calibratorCase_ == CalibratorOneofCase.ClassIdFunctionApproximations) { + subBuilder.MergeFrom(ClassIdFunctionApproximations); + } + input.ReadMessage(subBuilder); + ClassIdFunctionApproximations = subBuilder; + break; + } + case 26: { + global::Tensorflow.Models.ObjectDetection.Protos.SigmoidCalibration subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.SigmoidCalibration(); + if (calibratorCase_ == CalibratorOneofCase.SigmoidCalibration) { + subBuilder.MergeFrom(SigmoidCalibration); + } + input.ReadMessage(subBuilder); + SigmoidCalibration = subBuilder; + break; + } + case 34: { + global::Tensorflow.Models.ObjectDetection.Protos.ClassIdSigmoidCalibrations subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.ClassIdSigmoidCalibrations(); + if (calibratorCase_ == CalibratorOneofCase.ClassIdSigmoidCalibrations) { + subBuilder.MergeFrom(ClassIdSigmoidCalibrations); + } + input.ReadMessage(subBuilder); + ClassIdSigmoidCalibrations = subBuilder; + break; + } + } + } + } + + } + + /// + /// Message for class-agnostic domain/range mapping for function + /// approximations. + /// + public sealed partial class FunctionApproximation : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FunctionApproximation()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.CalibrationReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FunctionApproximation() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FunctionApproximation(FunctionApproximation other) : this() { + xYPairs_ = other.xYPairs_ != null ? other.xYPairs_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FunctionApproximation Clone() { + return new FunctionApproximation(this); + } + + /// Field number for the "x_y_pairs" field. 
+ public const int XYPairsFieldNumber = 1; + private global::Tensorflow.Models.ObjectDetection.Protos.XYPairs xYPairs_; + /// + /// Message mapping class labels to indices + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.XYPairs XYPairs { + get { return xYPairs_; } + set { + xYPairs_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as FunctionApproximation); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(FunctionApproximation other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(XYPairs, other.XYPairs)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (xYPairs_ != null) hash ^= XYPairs.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (xYPairs_ != null) { + output.WriteRawTag(10); + output.WriteMessage(XYPairs); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (xYPairs_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(XYPairs); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(FunctionApproximation other) { + if (other == null) { + return; + } + if (other.xYPairs_ != null) { + if (xYPairs_ == null) { + xYPairs_ = new global::Tensorflow.Models.ObjectDetection.Protos.XYPairs(); + } + XYPairs.MergeFrom(other.XYPairs); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (xYPairs_ == null) { + xYPairs_ = new global::Tensorflow.Models.ObjectDetection.Protos.XYPairs(); + } + input.ReadMessage(xYPairs_); + break; + } + } + } + } + + } + + /// + /// Message for class-specific domain/range mapping for function + /// approximations. 
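+ /// (Hand-written illustration, not part of the original .proto comment: in text-format
+ /// config this message might be filled in as something like
+ ///   class_id_xy_pairs_map { key: 1 value { x_y_pair { x: 0.0 y: 0.0 } x_y_pair { x: 1.0 y: 0.95 } } }
+ /// where the class id and the pair values are placeholders.)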
+ /// + public sealed partial class ClassIdFunctionApproximations : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ClassIdFunctionApproximations()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.CalibrationReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ClassIdFunctionApproximations() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ClassIdFunctionApproximations(ClassIdFunctionApproximations other) : this() { + classIdXyPairsMap_ = other.classIdXyPairsMap_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ClassIdFunctionApproximations Clone() { + return new ClassIdFunctionApproximations(this); + } + + /// Field number for the "class_id_xy_pairs_map" field. + public const int ClassIdXyPairsMapFieldNumber = 1; + private static readonly pbc::MapField.Codec _map_classIdXyPairsMap_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForInt32(8), pb::FieldCodec.ForMessage(18, global::Tensorflow.Models.ObjectDetection.Protos.XYPairs.Parser), 10); + private readonly pbc::MapField classIdXyPairsMap_ = new pbc::MapField(); + /// + /// Message mapping class ids to indices. 
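+ /// (Hand-written clarification: each entry maps a class id to the XYPairs used to
+ /// calibrate that class's scores.)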
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::MapField ClassIdXyPairsMap { + get { return classIdXyPairsMap_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ClassIdFunctionApproximations); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ClassIdFunctionApproximations other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!ClassIdXyPairsMap.Equals(other.ClassIdXyPairsMap)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= ClassIdXyPairsMap.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + classIdXyPairsMap_.WriteTo(output, _map_classIdXyPairsMap_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += classIdXyPairsMap_.CalculateSize(_map_classIdXyPairsMap_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ClassIdFunctionApproximations other) { + if (other == null) { + return; + } + classIdXyPairsMap_.Add(other.classIdXyPairsMap_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + classIdXyPairsMap_.AddEntriesFrom(input, _map_classIdXyPairsMap_codec); + break; + } + } + } + } + + } + + /// + /// Message for class-agnostic Sigmoid Calibration. + /// + public sealed partial class SigmoidCalibration : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SigmoidCalibration()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.CalibrationReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SigmoidCalibration() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SigmoidCalibration(SigmoidCalibration other) : this() { + sigmoidParameters_ = other.sigmoidParameters_ != null ? 
other.sigmoidParameters_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SigmoidCalibration Clone() { + return new SigmoidCalibration(this); + } + + /// Field number for the "sigmoid_parameters" field. + public const int SigmoidParametersFieldNumber = 1; + private global::Tensorflow.Models.ObjectDetection.Protos.SigmoidParameters sigmoidParameters_; + /// + /// Message mapping class index to Sigmoid Parameters + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.SigmoidParameters SigmoidParameters { + get { return sigmoidParameters_; } + set { + sigmoidParameters_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SigmoidCalibration); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SigmoidCalibration other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(SigmoidParameters, other.SigmoidParameters)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (sigmoidParameters_ != null) hash ^= SigmoidParameters.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (sigmoidParameters_ != null) { + output.WriteRawTag(10); + output.WriteMessage(SigmoidParameters); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (sigmoidParameters_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SigmoidParameters); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SigmoidCalibration other) { + if (other == null) { + return; + } + if (other.sigmoidParameters_ != null) { + if (sigmoidParameters_ == null) { + sigmoidParameters_ = new global::Tensorflow.Models.ObjectDetection.Protos.SigmoidParameters(); + } + SigmoidParameters.MergeFrom(other.SigmoidParameters); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (sigmoidParameters_ == null) { + sigmoidParameters_ = new global::Tensorflow.Models.ObjectDetection.Protos.SigmoidParameters(); + } + input.ReadMessage(sigmoidParameters_); + break; + } + } + } + } + + } + + /// + /// Message for class-specific Sigmoid Calibration. 
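+ /// (Hand-written illustration: in text format this might look like
+ ///   class_id_sigmoid_parameters_map { key: 1 value { a: -1.0 b: 0.0 } }
+ /// with placeholder values for the per-class sigmoid parameters a and b.)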
+ /// + public sealed partial class ClassIdSigmoidCalibrations : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ClassIdSigmoidCalibrations()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.CalibrationReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ClassIdSigmoidCalibrations() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ClassIdSigmoidCalibrations(ClassIdSigmoidCalibrations other) : this() { + classIdSigmoidParametersMap_ = other.classIdSigmoidParametersMap_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ClassIdSigmoidCalibrations Clone() { + return new ClassIdSigmoidCalibrations(this); + } + + /// Field number for the "class_id_sigmoid_parameters_map" field. + public const int ClassIdSigmoidParametersMapFieldNumber = 1; + private static readonly pbc::MapField.Codec _map_classIdSigmoidParametersMap_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForInt32(8), pb::FieldCodec.ForMessage(18, global::Tensorflow.Models.ObjectDetection.Protos.SigmoidParameters.Parser), 10); + private readonly pbc::MapField classIdSigmoidParametersMap_ = new pbc::MapField(); + /// + /// Message mapping class index to Sigmoid Parameters. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::MapField ClassIdSigmoidParametersMap { + get { return classIdSigmoidParametersMap_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ClassIdSigmoidCalibrations); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ClassIdSigmoidCalibrations other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!ClassIdSigmoidParametersMap.Equals(other.ClassIdSigmoidParametersMap)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= ClassIdSigmoidParametersMap.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + classIdSigmoidParametersMap_.WriteTo(output, _map_classIdSigmoidParametersMap_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += classIdSigmoidParametersMap_.CalculateSize(_map_classIdSigmoidParametersMap_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ClassIdSigmoidCalibrations other) { + if (other == null) { + return; + } + classIdSigmoidParametersMap_.Add(other.classIdSigmoidParametersMap_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + classIdSigmoidParametersMap_.AddEntriesFrom(input, _map_classIdSigmoidParametersMap_codec); + break; + } + } + } + } + + } + + /// + /// Message to store a domain/range pair for function to be approximated. 
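+ /// (Hand-written clarification: each XYPair gives one sample of the function, with x the
+ /// input, uncalibrated score and y the corresponding output, calibrated score.)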
+ /// + public sealed partial class XYPairs : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new XYPairs()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.CalibrationReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public XYPairs() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public XYPairs(XYPairs other) : this() { + xYPair_ = other.xYPair_.Clone(); + trainingDataType_ = other.trainingDataType_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public XYPairs Clone() { + return new XYPairs(this); + } + + /// Field number for the "x_y_pair" field. + public const int XYPairFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_xYPair_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.Models.ObjectDetection.Protos.XYPairs.Types.XYPair.Parser); + private readonly pbc::RepeatedField xYPair_ = new pbc::RepeatedField(); + /// + /// Sequence of x/y pairs for function approximation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField XYPair { + get { return xYPair_; } + } + + /// Field number for the "training_data_type" field. + public const int TrainingDataTypeFieldNumber = 2; + private global::Tensorflow.Models.ObjectDetection.Protos.TrainingDataType trainingDataType_ = 0; + /// + /// Description of data used to fit the calibration model. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.TrainingDataType TrainingDataType { + get { return trainingDataType_; } + set { + trainingDataType_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as XYPairs); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(XYPairs other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!xYPair_.Equals(other.xYPair_)) return false; + if (TrainingDataType != other.TrainingDataType) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= xYPair_.GetHashCode(); + if (TrainingDataType != 0) hash ^= TrainingDataType.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + xYPair_.WriteTo(output, _repeated_xYPair_codec); + if (TrainingDataType != 0) { + output.WriteRawTag(16); + output.WriteEnum((int) TrainingDataType); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += xYPair_.CalculateSize(_repeated_xYPair_codec); + if (TrainingDataType != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) TrainingDataType); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(XYPairs other) { + if (other == null) { + return; + } + xYPair_.Add(other.xYPair_); + if (other.TrainingDataType != 0) { + TrainingDataType = other.TrainingDataType; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + xYPair_.AddEntriesFrom(input, _repeated_xYPair_codec); + break; + } + case 16: { + trainingDataType_ = (global::Tensorflow.Models.ObjectDetection.Protos.TrainingDataType) input.ReadEnum(); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the XYPairs message type. 
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + public sealed partial class XYPair : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new XYPair()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.XYPairs.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public XYPair() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public XYPair(XYPair other) : this() { + x_ = other.x_; + y_ = other.y_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public XYPair Clone() { + return new XYPair(this); + } + + /// Field number for the "x" field. + public const int XFieldNumber = 1; + private float x_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float X { + get { return x_; } + set { + x_ = value; + } + } + + /// Field number for the "y" field. + public const int YFieldNumber = 2; + private float y_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Y { + get { return y_; } + set { + y_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as XYPair); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(XYPair other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(X, other.X)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Y, other.Y)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (X != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(X); + if (Y != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Y); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (X != 0F) { + output.WriteRawTag(13); + output.WriteFloat(X); + } + if (Y != 0F) { + output.WriteRawTag(21); + output.WriteFloat(Y); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (X != 0F) { + size += 1 + 4; + } + if (Y != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + 
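+ // Hand-written sketch, not emitted by protoc: shows how the generated XYPairs/XYPair
+ // types compose into a simple two-point mapping. The helper name and the sample
+ // values are arbitrary placeholders.
+ internal static XYPairs ExampleTwoPointMapping() {
+   var pairs = new XYPairs();
+   // Map a raw score of 0.0 to 0.0 and a raw score of 1.0 to 0.9 (placeholder values).
+   pairs.XYPair.Add(new XYPair { X = 0.0F, Y = 0.0F });
+   pairs.XYPair.Add(new XYPair { X = 1.0F, Y = 0.9F });
+   return pairs;
+ }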
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(XYPair other) { + if (other == null) { + return; + } + if (other.X != 0F) { + X = other.X; + } + if (other.Y != 0F) { + Y = other.Y; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + X = input.ReadFloat(); + break; + } + case 21: { + Y = input.ReadFloat(); + break; + } + } + } + } + + } + + } + #endregion + + } + + /// + /// Message defining parameters for sigmoid calibration. + /// + public sealed partial class SigmoidParameters : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SigmoidParameters()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.CalibrationReflection.Descriptor.MessageTypes[6]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SigmoidParameters() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SigmoidParameters(SigmoidParameters other) : this() { + a_ = other.a_; + b_ = other.b_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SigmoidParameters Clone() { + return new SigmoidParameters(this); + } + + /// Field number for the "a" field. + public const int AFieldNumber = 1; + private float a_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float A { + get { return a_; } + set { + a_ = value; + } + } + + /// Field number for the "b" field. 
+ public const int BFieldNumber = 2; + private float b_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float B { + get { return b_; } + set { + b_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SigmoidParameters); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SigmoidParameters other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(A, other.A)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(B, other.B)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (A != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(A); + if (B != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(B); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (A != 0F) { + output.WriteRawTag(13); + output.WriteFloat(A); + } + if (B != 0F) { + output.WriteRawTag(21); + output.WriteFloat(B); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (A != 0F) { + size += 1 + 4; + } + if (B != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SigmoidParameters other) { + if (other == null) { + return; + } + if (other.A != 0F) { + A = other.A; + } + if (other.B != 0F) { + B = other.B; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + A = input.ReadFloat(); + break; + } + case 21: { + B = input.ReadFloat(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/Eval.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/Eval.cs new file mode 100644 index 00000000..b33b5869 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/Eval.cs @@ -0,0 +1,901 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: object_detection/protos/eval.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/eval.proto + public static partial class EvalReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/eval.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static EvalReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiJvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9ldmFsLnByb3RvEhdvYmplY3Rf", + "ZGV0ZWN0aW9uLnByb3RvcyK3BQoKRXZhbENvbmZpZxISCgpiYXRjaF9zaXpl", + "GBkgASgNEhoKEm51bV92aXN1YWxpemF0aW9ucxgBIAEoDRIUCgxudW1fZXhh", + "bXBsZXMYAiABKA0SGgoSZXZhbF9pbnRlcnZhbF9zZWNzGAMgASgNEhEKCW1h", + "eF9ldmFscxgEIAEoDRISCgpzYXZlX2dyYXBoGAUgASgIEiAKGHZpc3VhbGl6", + "YXRpb25fZXhwb3J0X2RpchgGIAEoCRITCgtldmFsX21hc3RlchgHIAEoCRIT", + "CgttZXRyaWNzX3NldBgIIAMoCRITCgtleHBvcnRfcGF0aBgJIAEoCRIaChJp", + "Z25vcmVfZ3JvdW5kdHJ1dGgYCiABKAgSGwoTdXNlX21vdmluZ19hdmVyYWdl", + "cxgLIAEoCBIbChNldmFsX2luc3RhbmNlX21hc2tzGAwgASgIEhsKE21pbl9z", + "Y29yZV90aHJlc2hvbGQYDSABKAISIgoabWF4X251bV9ib3hlc190b192aXN1", + "YWxpemUYDiABKAUSEwoLc2tpcF9zY29yZXMYDyABKAgSEwoLc2tpcF9sYWJl", + "bHMYECABKAgSIwobdmlzdWFsaXplX2dyb3VuZHRydXRoX2JveGVzGBEgASgI", + "EisKI2dyb3VuZHRydXRoX2JveF92aXN1YWxpemF0aW9uX2NvbG9yGBIgASgJ", + "Ei4KJmtlZXBfaW1hZ2VfaWRfZm9yX3Zpc3VhbGl6YXRpb25fZXhwb3J0GBMg", + "ASgIEh4KFnJldGFpbl9vcmlnaW5hbF9pbWFnZXMYFyABKAgSJAocaW5jbHVk", + "ZV9tZXRyaWNzX3Blcl9jYXRlZ29yeRgYIAEoCBIaChJyZWNhbGxfbG93ZXJf", + "Ym91bmQYGiABKAISGgoScmVjYWxsX3VwcGVyX2JvdW5kGBsgASgCYgZwcm90", + "bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.EvalConfig), global::Tensorflow.Models.ObjectDetection.Protos.EvalConfig.Parser, new[]{ "BatchSize", "NumVisualizations", "NumExamples", "EvalIntervalSecs", "MaxEvals", "SaveGraph", "VisualizationExportDir", "EvalMaster", "MetricsSet", "ExportPath", "IgnoreGroundtruth", "UseMovingAverages", "EvalInstanceMasks", "MinScoreThreshold", "MaxNumBoxesToVisualize", "SkipScores", "SkipLabels", "VisualizeGroundtruthBoxes", "GroundtruthBoxVisualizationColor", "KeepImageIdForVisualizationExport", "RetainOriginalImages", "IncludeMetricsPerCategory", "RecallLowerBound", "RecallUpperBound" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Message for configuring DetectionModel evaluation jobs (eval.py). 
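+ /// (Hand-written illustration: from C# such a config might be assembled as, e.g.,
+ ///   var eval = new EvalConfig { NumExamples = 8000, EvalIntervalSecs = 300, MetricsSet = { "coco_detection_metrics" } };
+ /// where the counts and the metrics-set name are placeholder values.)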
+ /// + public sealed partial class EvalConfig : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new EvalConfig()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.EvalReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EvalConfig() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EvalConfig(EvalConfig other) : this() { + batchSize_ = other.batchSize_; + numVisualizations_ = other.numVisualizations_; + numExamples_ = other.numExamples_; + evalIntervalSecs_ = other.evalIntervalSecs_; + maxEvals_ = other.maxEvals_; + saveGraph_ = other.saveGraph_; + visualizationExportDir_ = other.visualizationExportDir_; + evalMaster_ = other.evalMaster_; + metricsSet_ = other.metricsSet_.Clone(); + exportPath_ = other.exportPath_; + ignoreGroundtruth_ = other.ignoreGroundtruth_; + useMovingAverages_ = other.useMovingAverages_; + evalInstanceMasks_ = other.evalInstanceMasks_; + minScoreThreshold_ = other.minScoreThreshold_; + maxNumBoxesToVisualize_ = other.maxNumBoxesToVisualize_; + skipScores_ = other.skipScores_; + skipLabels_ = other.skipLabels_; + visualizeGroundtruthBoxes_ = other.visualizeGroundtruthBoxes_; + groundtruthBoxVisualizationColor_ = other.groundtruthBoxVisualizationColor_; + keepImageIdForVisualizationExport_ = other.keepImageIdForVisualizationExport_; + retainOriginalImages_ = other.retainOriginalImages_; + includeMetricsPerCategory_ = other.includeMetricsPerCategory_; + recallLowerBound_ = other.recallLowerBound_; + recallUpperBound_ = other.recallUpperBound_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EvalConfig Clone() { + return new EvalConfig(this); + } + + /// Field number for the "batch_size" field. + public const int BatchSizeFieldNumber = 25; + private uint batchSize_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint BatchSize { + get { return batchSize_; } + set { + batchSize_ = value; + } + } + + /// Field number for the "num_visualizations" field. + public const int NumVisualizationsFieldNumber = 1; + private uint numVisualizations_; + /// + /// Number of visualization images to generate. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint NumVisualizations { + get { return numVisualizations_; } + set { + numVisualizations_ = value; + } + } + + /// Field number for the "num_examples" field. + public const int NumExamplesFieldNumber = 2; + private uint numExamples_; + /// + /// Number of examples to process of evaluation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint NumExamples { + get { return numExamples_; } + set { + numExamples_ = value; + } + } + + /// Field number for the "eval_interval_secs" field. + public const int EvalIntervalSecsFieldNumber = 3; + private uint evalIntervalSecs_; + /// + /// How often to run evaluation. 
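+ /// (Hand-written clarification: the value is in seconds, e.g. 300 re-runs evaluation
+ /// roughly every five minutes.)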
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint EvalIntervalSecs { + get { return evalIntervalSecs_; } + set { + evalIntervalSecs_ = value; + } + } + + /// Field number for the "max_evals" field. + public const int MaxEvalsFieldNumber = 4; + private uint maxEvals_; + /// + /// Maximum number of times to run evaluation. If set to 0, will run forever. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint MaxEvals { + get { return maxEvals_; } + set { + maxEvals_ = value; + } + } + + /// Field number for the "save_graph" field. + public const int SaveGraphFieldNumber = 5; + private bool saveGraph_; + /// + /// Whether the TensorFlow graph used for evaluation should be saved to disk. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool SaveGraph { + get { return saveGraph_; } + set { + saveGraph_ = value; + } + } + + /// Field number for the "visualization_export_dir" field. + public const int VisualizationExportDirFieldNumber = 6; + private string visualizationExportDir_ = ""; + /// + /// Path to directory to store visualizations in. If empty, visualization + /// images are not exported (only shown on Tensorboard). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string VisualizationExportDir { + get { return visualizationExportDir_; } + set { + visualizationExportDir_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "eval_master" field. + public const int EvalMasterFieldNumber = 7; + private string evalMaster_ = ""; + /// + /// BNS name of the TensorFlow master. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string EvalMaster { + get { return evalMaster_; } + set { + evalMaster_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "metrics_set" field. + public const int MetricsSetFieldNumber = 8; + private static readonly pb::FieldCodec _repeated_metricsSet_codec + = pb::FieldCodec.ForString(66); + private readonly pbc::RepeatedField metricsSet_ = new pbc::RepeatedField(); + /// + /// Type of metrics to use for evaluation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField MetricsSet { + get { return metricsSet_; } + } + + /// Field number for the "export_path" field. + public const int ExportPathFieldNumber = 9; + private string exportPath_ = ""; + /// + /// Path to export detections to COCO compatible JSON format. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string ExportPath { + get { return exportPath_; } + set { + exportPath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "ignore_groundtruth" field. + public const int IgnoreGroundtruthFieldNumber = 10; + private bool ignoreGroundtruth_; + /// + /// Option to not read groundtruth labels and only export detections to + /// COCO-compatible JSON file. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool IgnoreGroundtruth { + get { return ignoreGroundtruth_; } + set { + ignoreGroundtruth_ = value; + } + } + + /// Field number for the "use_moving_averages" field. + public const int UseMovingAveragesFieldNumber = 11; + private bool useMovingAverages_; + /// + /// Use exponential moving averages of variables for evaluation. + /// TODO(rathodv): When this is false make sure the model is constructed + /// without moving averages in restore_fn. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseMovingAverages { + get { return useMovingAverages_; } + set { + useMovingAverages_ = value; + } + } + + /// Field number for the "eval_instance_masks" field. + public const int EvalInstanceMasksFieldNumber = 12; + private bool evalInstanceMasks_; + /// + /// Whether to evaluate instance masks. + /// Note that since there is no evaluation code currently for instance + /// segmenation this option is unused. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool EvalInstanceMasks { + get { return evalInstanceMasks_; } + set { + evalInstanceMasks_ = value; + } + } + + /// Field number for the "min_score_threshold" field. + public const int MinScoreThresholdFieldNumber = 13; + private float minScoreThreshold_; + /// + /// Minimum score threshold for a detected object box to be visualized + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinScoreThreshold { + get { return minScoreThreshold_; } + set { + minScoreThreshold_ = value; + } + } + + /// Field number for the "max_num_boxes_to_visualize" field. + public const int MaxNumBoxesToVisualizeFieldNumber = 14; + private int maxNumBoxesToVisualize_; + /// + /// Maximum number of detections to visualize + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxNumBoxesToVisualize { + get { return maxNumBoxesToVisualize_; } + set { + maxNumBoxesToVisualize_ = value; + } + } + + /// Field number for the "skip_scores" field. + public const int SkipScoresFieldNumber = 15; + private bool skipScores_; + /// + /// When drawing a single detection, each label is by default visualized as + /// <label name> : <label score>. One can skip the name or/and score using the + /// following fields: + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool SkipScores { + get { return skipScores_; } + set { + skipScores_ = value; + } + } + + /// Field number for the "skip_labels" field. + public const int SkipLabelsFieldNumber = 16; + private bool skipLabels_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool SkipLabels { + get { return skipLabels_; } + set { + skipLabels_ = value; + } + } + + /// Field number for the "visualize_groundtruth_boxes" field. + public const int VisualizeGroundtruthBoxesFieldNumber = 17; + private bool visualizeGroundtruthBoxes_; + /// + /// Whether to show groundtruth boxes in addition to detected boxes in + /// visualizations. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool VisualizeGroundtruthBoxes { + get { return visualizeGroundtruthBoxes_; } + set { + visualizeGroundtruthBoxes_ = value; + } + } + + /// Field number for the "groundtruth_box_visualization_color" field. + public const int GroundtruthBoxVisualizationColorFieldNumber = 18; + private string groundtruthBoxVisualizationColor_ = ""; + /// + /// Box color for visualizing groundtruth boxes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string GroundtruthBoxVisualizationColor { + get { return groundtruthBoxVisualizationColor_; } + set { + groundtruthBoxVisualizationColor_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "keep_image_id_for_visualization_export" field. 
+ public const int KeepImageIdForVisualizationExportFieldNumber = 19; + private bool keepImageIdForVisualizationExport_; + /// + /// Whether to keep image identifier in filename when exported to + /// visualization_export_dir. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool KeepImageIdForVisualizationExport { + get { return keepImageIdForVisualizationExport_; } + set { + keepImageIdForVisualizationExport_ = value; + } + } + + /// Field number for the "retain_original_images" field. + public const int RetainOriginalImagesFieldNumber = 23; + private bool retainOriginalImages_; + /// + /// Whether to retain original images (i.e. not pre-processed) in the tensor + /// dictionary, so that they can be displayed in Tensorboard. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool RetainOriginalImages { + get { return retainOriginalImages_; } + set { + retainOriginalImages_ = value; + } + } + + /// Field number for the "include_metrics_per_category" field. + public const int IncludeMetricsPerCategoryFieldNumber = 24; + private bool includeMetricsPerCategory_; + /// + /// If True, additionally include per-category metrics. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool IncludeMetricsPerCategory { + get { return includeMetricsPerCategory_; } + set { + includeMetricsPerCategory_ = value; + } + } + + /// Field number for the "recall_lower_bound" field. + public const int RecallLowerBoundFieldNumber = 26; + private float recallLowerBound_; + /// + /// Recall range within which precision should be computed. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float RecallLowerBound { + get { return recallLowerBound_; } + set { + recallLowerBound_ = value; + } + } + + /// Field number for the "recall_upper_bound" field. 
+ public const int RecallUpperBoundFieldNumber = 27; + private float recallUpperBound_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float RecallUpperBound { + get { return recallUpperBound_; } + set { + recallUpperBound_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as EvalConfig); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(EvalConfig other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (BatchSize != other.BatchSize) return false; + if (NumVisualizations != other.NumVisualizations) return false; + if (NumExamples != other.NumExamples) return false; + if (EvalIntervalSecs != other.EvalIntervalSecs) return false; + if (MaxEvals != other.MaxEvals) return false; + if (SaveGraph != other.SaveGraph) return false; + if (VisualizationExportDir != other.VisualizationExportDir) return false; + if (EvalMaster != other.EvalMaster) return false; + if(!metricsSet_.Equals(other.metricsSet_)) return false; + if (ExportPath != other.ExportPath) return false; + if (IgnoreGroundtruth != other.IgnoreGroundtruth) return false; + if (UseMovingAverages != other.UseMovingAverages) return false; + if (EvalInstanceMasks != other.EvalInstanceMasks) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinScoreThreshold, other.MinScoreThreshold)) return false; + if (MaxNumBoxesToVisualize != other.MaxNumBoxesToVisualize) return false; + if (SkipScores != other.SkipScores) return false; + if (SkipLabels != other.SkipLabels) return false; + if (VisualizeGroundtruthBoxes != other.VisualizeGroundtruthBoxes) return false; + if (GroundtruthBoxVisualizationColor != other.GroundtruthBoxVisualizationColor) return false; + if (KeepImageIdForVisualizationExport != other.KeepImageIdForVisualizationExport) return false; + if (RetainOriginalImages != other.RetainOriginalImages) return false; + if (IncludeMetricsPerCategory != other.IncludeMetricsPerCategory) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(RecallLowerBound, other.RecallLowerBound)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(RecallUpperBound, other.RecallUpperBound)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (BatchSize != 0) hash ^= BatchSize.GetHashCode(); + if (NumVisualizations != 0) hash ^= NumVisualizations.GetHashCode(); + if (NumExamples != 0) hash ^= NumExamples.GetHashCode(); + if (EvalIntervalSecs != 0) hash ^= EvalIntervalSecs.GetHashCode(); + if (MaxEvals != 0) hash ^= MaxEvals.GetHashCode(); + if (SaveGraph != false) hash ^= SaveGraph.GetHashCode(); + if (VisualizationExportDir.Length != 0) hash ^= VisualizationExportDir.GetHashCode(); + if (EvalMaster.Length != 0) hash ^= EvalMaster.GetHashCode(); + hash ^= metricsSet_.GetHashCode(); + if (ExportPath.Length != 0) hash ^= ExportPath.GetHashCode(); + if (IgnoreGroundtruth != false) hash ^= IgnoreGroundtruth.GetHashCode(); + if (UseMovingAverages != false) hash ^= UseMovingAverages.GetHashCode(); + if (EvalInstanceMasks != false) hash ^= EvalInstanceMasks.GetHashCode(); + if (MinScoreThreshold != 0F) hash ^= 
pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinScoreThreshold); + if (MaxNumBoxesToVisualize != 0) hash ^= MaxNumBoxesToVisualize.GetHashCode(); + if (SkipScores != false) hash ^= SkipScores.GetHashCode(); + if (SkipLabels != false) hash ^= SkipLabels.GetHashCode(); + if (VisualizeGroundtruthBoxes != false) hash ^= VisualizeGroundtruthBoxes.GetHashCode(); + if (GroundtruthBoxVisualizationColor.Length != 0) hash ^= GroundtruthBoxVisualizationColor.GetHashCode(); + if (KeepImageIdForVisualizationExport != false) hash ^= KeepImageIdForVisualizationExport.GetHashCode(); + if (RetainOriginalImages != false) hash ^= RetainOriginalImages.GetHashCode(); + if (IncludeMetricsPerCategory != false) hash ^= IncludeMetricsPerCategory.GetHashCode(); + if (RecallLowerBound != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(RecallLowerBound); + if (RecallUpperBound != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(RecallUpperBound); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (NumVisualizations != 0) { + output.WriteRawTag(8); + output.WriteUInt32(NumVisualizations); + } + if (NumExamples != 0) { + output.WriteRawTag(16); + output.WriteUInt32(NumExamples); + } + if (EvalIntervalSecs != 0) { + output.WriteRawTag(24); + output.WriteUInt32(EvalIntervalSecs); + } + if (MaxEvals != 0) { + output.WriteRawTag(32); + output.WriteUInt32(MaxEvals); + } + if (SaveGraph != false) { + output.WriteRawTag(40); + output.WriteBool(SaveGraph); + } + if (VisualizationExportDir.Length != 0) { + output.WriteRawTag(50); + output.WriteString(VisualizationExportDir); + } + if (EvalMaster.Length != 0) { + output.WriteRawTag(58); + output.WriteString(EvalMaster); + } + metricsSet_.WriteTo(output, _repeated_metricsSet_codec); + if (ExportPath.Length != 0) { + output.WriteRawTag(74); + output.WriteString(ExportPath); + } + if (IgnoreGroundtruth != false) { + output.WriteRawTag(80); + output.WriteBool(IgnoreGroundtruth); + } + if (UseMovingAverages != false) { + output.WriteRawTag(88); + output.WriteBool(UseMovingAverages); + } + if (EvalInstanceMasks != false) { + output.WriteRawTag(96); + output.WriteBool(EvalInstanceMasks); + } + if (MinScoreThreshold != 0F) { + output.WriteRawTag(109); + output.WriteFloat(MinScoreThreshold); + } + if (MaxNumBoxesToVisualize != 0) { + output.WriteRawTag(112); + output.WriteInt32(MaxNumBoxesToVisualize); + } + if (SkipScores != false) { + output.WriteRawTag(120); + output.WriteBool(SkipScores); + } + if (SkipLabels != false) { + output.WriteRawTag(128, 1); + output.WriteBool(SkipLabels); + } + if (VisualizeGroundtruthBoxes != false) { + output.WriteRawTag(136, 1); + output.WriteBool(VisualizeGroundtruthBoxes); + } + if (GroundtruthBoxVisualizationColor.Length != 0) { + output.WriteRawTag(146, 1); + output.WriteString(GroundtruthBoxVisualizationColor); + } + if (KeepImageIdForVisualizationExport != false) { + output.WriteRawTag(152, 1); + output.WriteBool(KeepImageIdForVisualizationExport); + } + if (RetainOriginalImages != false) { + output.WriteRawTag(184, 1); + output.WriteBool(RetainOriginalImages); + } + if (IncludeMetricsPerCategory != false) { + output.WriteRawTag(192, 
1); + output.WriteBool(IncludeMetricsPerCategory); + } + if (BatchSize != 0) { + output.WriteRawTag(200, 1); + output.WriteUInt32(BatchSize); + } + if (RecallLowerBound != 0F) { + output.WriteRawTag(213, 1); + output.WriteFloat(RecallLowerBound); + } + if (RecallUpperBound != 0F) { + output.WriteRawTag(221, 1); + output.WriteFloat(RecallUpperBound); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (BatchSize != 0) { + size += 2 + pb::CodedOutputStream.ComputeUInt32Size(BatchSize); + } + if (NumVisualizations != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(NumVisualizations); + } + if (NumExamples != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(NumExamples); + } + if (EvalIntervalSecs != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(EvalIntervalSecs); + } + if (MaxEvals != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(MaxEvals); + } + if (SaveGraph != false) { + size += 1 + 1; + } + if (VisualizationExportDir.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(VisualizationExportDir); + } + if (EvalMaster.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(EvalMaster); + } + size += metricsSet_.CalculateSize(_repeated_metricsSet_codec); + if (ExportPath.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ExportPath); + } + if (IgnoreGroundtruth != false) { + size += 1 + 1; + } + if (UseMovingAverages != false) { + size += 1 + 1; + } + if (EvalInstanceMasks != false) { + size += 1 + 1; + } + if (MinScoreThreshold != 0F) { + size += 1 + 4; + } + if (MaxNumBoxesToVisualize != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxNumBoxesToVisualize); + } + if (SkipScores != false) { + size += 1 + 1; + } + if (SkipLabels != false) { + size += 2 + 1; + } + if (VisualizeGroundtruthBoxes != false) { + size += 2 + 1; + } + if (GroundtruthBoxVisualizationColor.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeStringSize(GroundtruthBoxVisualizationColor); + } + if (KeepImageIdForVisualizationExport != false) { + size += 2 + 1; + } + if (RetainOriginalImages != false) { + size += 2 + 1; + } + if (IncludeMetricsPerCategory != false) { + size += 2 + 1; + } + if (RecallLowerBound != 0F) { + size += 2 + 4; + } + if (RecallUpperBound != 0F) { + size += 2 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(EvalConfig other) { + if (other == null) { + return; + } + if (other.BatchSize != 0) { + BatchSize = other.BatchSize; + } + if (other.NumVisualizations != 0) { + NumVisualizations = other.NumVisualizations; + } + if (other.NumExamples != 0) { + NumExamples = other.NumExamples; + } + if (other.EvalIntervalSecs != 0) { + EvalIntervalSecs = other.EvalIntervalSecs; + } + if (other.MaxEvals != 0) { + MaxEvals = other.MaxEvals; + } + if (other.SaveGraph != false) { + SaveGraph = other.SaveGraph; + } + if (other.VisualizationExportDir.Length != 0) { + VisualizationExportDir = other.VisualizationExportDir; + } + if (other.EvalMaster.Length != 0) { + EvalMaster = other.EvalMaster; + } + metricsSet_.Add(other.metricsSet_); + if (other.ExportPath.Length != 0) { + ExportPath = other.ExportPath; + } + if (other.IgnoreGroundtruth != false) { + IgnoreGroundtruth = other.IgnoreGroundtruth; + } + if (other.UseMovingAverages != 
false) { + UseMovingAverages = other.UseMovingAverages; + } + if (other.EvalInstanceMasks != false) { + EvalInstanceMasks = other.EvalInstanceMasks; + } + if (other.MinScoreThreshold != 0F) { + MinScoreThreshold = other.MinScoreThreshold; + } + if (other.MaxNumBoxesToVisualize != 0) { + MaxNumBoxesToVisualize = other.MaxNumBoxesToVisualize; + } + if (other.SkipScores != false) { + SkipScores = other.SkipScores; + } + if (other.SkipLabels != false) { + SkipLabels = other.SkipLabels; + } + if (other.VisualizeGroundtruthBoxes != false) { + VisualizeGroundtruthBoxes = other.VisualizeGroundtruthBoxes; + } + if (other.GroundtruthBoxVisualizationColor.Length != 0) { + GroundtruthBoxVisualizationColor = other.GroundtruthBoxVisualizationColor; + } + if (other.KeepImageIdForVisualizationExport != false) { + KeepImageIdForVisualizationExport = other.KeepImageIdForVisualizationExport; + } + if (other.RetainOriginalImages != false) { + RetainOriginalImages = other.RetainOriginalImages; + } + if (other.IncludeMetricsPerCategory != false) { + IncludeMetricsPerCategory = other.IncludeMetricsPerCategory; + } + if (other.RecallLowerBound != 0F) { + RecallLowerBound = other.RecallLowerBound; + } + if (other.RecallUpperBound != 0F) { + RecallUpperBound = other.RecallUpperBound; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + NumVisualizations = input.ReadUInt32(); + break; + } + case 16: { + NumExamples = input.ReadUInt32(); + break; + } + case 24: { + EvalIntervalSecs = input.ReadUInt32(); + break; + } + case 32: { + MaxEvals = input.ReadUInt32(); + break; + } + case 40: { + SaveGraph = input.ReadBool(); + break; + } + case 50: { + VisualizationExportDir = input.ReadString(); + break; + } + case 58: { + EvalMaster = input.ReadString(); + break; + } + case 66: { + metricsSet_.AddEntriesFrom(input, _repeated_metricsSet_codec); + break; + } + case 74: { + ExportPath = input.ReadString(); + break; + } + case 80: { + IgnoreGroundtruth = input.ReadBool(); + break; + } + case 88: { + UseMovingAverages = input.ReadBool(); + break; + } + case 96: { + EvalInstanceMasks = input.ReadBool(); + break; + } + case 109: { + MinScoreThreshold = input.ReadFloat(); + break; + } + case 112: { + MaxNumBoxesToVisualize = input.ReadInt32(); + break; + } + case 120: { + SkipScores = input.ReadBool(); + break; + } + case 128: { + SkipLabels = input.ReadBool(); + break; + } + case 136: { + VisualizeGroundtruthBoxes = input.ReadBool(); + break; + } + case 146: { + GroundtruthBoxVisualizationColor = input.ReadString(); + break; + } + case 152: { + KeepImageIdForVisualizationExport = input.ReadBool(); + break; + } + case 184: { + RetainOriginalImages = input.ReadBool(); + break; + } + case 192: { + IncludeMetricsPerCategory = input.ReadBool(); + break; + } + case 200: { + BatchSize = input.ReadUInt32(); + break; + } + case 213: { + RecallLowerBound = input.ReadFloat(); + break; + } + case 221: { + RecallUpperBound = input.ReadFloat(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/FasterRcnn.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/FasterRcnn.cs new file mode 100644 index 
00000000..f5ed1f2c --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/FasterRcnn.cs @@ -0,0 +1,1592 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: object_detection/protos/faster_rcnn.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/faster_rcnn.proto + public static partial class FasterRcnnReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/faster_rcnn.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static FasterRcnnReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CilvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9mYXN0ZXJfcmNubi5wcm90bxIX", + "b2JqZWN0X2RldGVjdGlvbi5wcm90b3MaLm9iamVjdF9kZXRlY3Rpb24vcHJv", + "dG9zL2FuY2hvcl9nZW5lcmF0b3IucHJvdG8aK29iamVjdF9kZXRlY3Rpb24v", + "cHJvdG9zL2JveF9wcmVkaWN0b3IucHJvdG8aKW9iamVjdF9kZXRlY3Rpb24v", + "cHJvdG9zL2h5cGVycGFyYW1zLnByb3RvGitvYmplY3RfZGV0ZWN0aW9uL3By", + "b3Rvcy9pbWFnZV9yZXNpemVyLnByb3RvGiRvYmplY3RfZGV0ZWN0aW9uL3By", + "b3Rvcy9sb3NzZXMucHJvdG8aLW9iamVjdF9kZXRlY3Rpb24vcHJvdG9zL3Bv", + "c3RfcHJvY2Vzc2luZy5wcm90byL5DAoKRmFzdGVyUmNubhIYChBudW1iZXJf", + "b2Zfc3RhZ2VzGAEgASgFEhMKC251bV9jbGFzc2VzGAMgASgFEjwKDWltYWdl", + "X3Jlc2l6ZXIYBCABKAsyJS5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5JbWFn", + "ZVJlc2l6ZXISTgoRZmVhdHVyZV9leHRyYWN0b3IYBSABKAsyMy5vYmplY3Rf", + "ZGV0ZWN0aW9uLnByb3Rvcy5GYXN0ZXJSY25uRmVhdHVyZUV4dHJhY3RvchJO", + "ChxmaXJzdF9zdGFnZV9hbmNob3JfZ2VuZXJhdG9yGAYgASgLMigub2JqZWN0", + "X2RldGVjdGlvbi5wcm90b3MuQW5jaG9yR2VuZXJhdG9yEh8KF2ZpcnN0X3N0", + "YWdlX2F0cm91c19yYXRlGAcgASgFElgKKmZpcnN0X3N0YWdlX2JveF9wcmVk", + "aWN0b3JfY29udl9oeXBlcnBhcmFtcxgIIAEoCzIkLm9iamVjdF9kZXRlY3Rp", + "b24ucHJvdG9zLkh5cGVycGFyYW1zEi0KJWZpcnN0X3N0YWdlX2JveF9wcmVk", + "aWN0b3Jfa2VybmVsX3NpemUYCSABKAUSJwofZmlyc3Rfc3RhZ2VfYm94X3By", + "ZWRpY3Rvcl9kZXB0aBgKIAEoBRIiChpmaXJzdF9zdGFnZV9taW5pYmF0Y2hf", + "c2l6ZRgLIAEoBRItCiVmaXJzdF9zdGFnZV9wb3NpdGl2ZV9iYWxhbmNlX2Zy", + "YWN0aW9uGAwgASgCEicKH2ZpcnN0X3N0YWdlX25tc19zY29yZV90aHJlc2hv", + "bGQYDSABKAISJQodZmlyc3Rfc3RhZ2Vfbm1zX2lvdV90aHJlc2hvbGQYDiAB", + "KAISIQoZZmlyc3Rfc3RhZ2VfbWF4X3Byb3Bvc2FscxgPIAEoBRIsCiRmaXJz", + "dF9zdGFnZV9sb2NhbGl6YXRpb25fbG9zc193ZWlnaHQYECABKAISKgoiZmly", + "c3Rfc3RhZ2Vfb2JqZWN0bmVzc19sb3NzX3dlaWdodBgRIAEoAhIZChFpbml0", + "aWFsX2Nyb3Bfc2l6ZRgSIAEoBRIbChNtYXhwb29sX2tlcm5lbF9zaXplGBMg", + "ASgFEhYKDm1heHBvb2xfc3RyaWRlGBQgASgFEkkKGnNlY29uZF9zdGFnZV9i", + "b3hfcHJlZGljdG9yGBUgASgLMiUub2JqZWN0X2RldGVjdGlvbi5wcm90b3Mu", + "Qm94UHJlZGljdG9yEh8KF3NlY29uZF9zdGFnZV9iYXRjaF9zaXplGBYgASgF", + "EiUKHXNlY29uZF9zdGFnZV9iYWxhbmNlX2ZyYWN0aW9uGBcgASgCEk0KHHNl", + "Y29uZF9zdGFnZV9wb3N0X3Byb2Nlc3NpbmcYGCABKAsyJy5vYmplY3RfZGV0", + "ZWN0aW9uLnByb3Rvcy5Qb3N0UHJvY2Vzc2luZxItCiVzZWNvbmRfc3RhZ2Vf", + "bG9jYWxpemF0aW9uX2xvc3Nfd2VpZ2h0GBkgASgCEi8KJ3NlY29uZF9zdGFn", + "ZV9jbGFzc2lmaWNhdGlvbl9sb3NzX3dlaWdodBgaIAEoAhIwCihzZWNvbmRf", + "c3RhZ2VfbWFza19wcmVkaWN0aW9uX2xvc3Nfd2VpZ2h0GBsgASgCEkUKEmhh", + "cmRfZXhhbXBsZV9taW5lchgcIAEoCzIpLm9iamVjdF9kZXRlY3Rpb24ucHJv", + "dG9zLkhhcmRFeGFtcGxlTWluZXISVQogc2Vjb25kX3N0YWdlX2NsYXNzaWZp", + 
"Y2F0aW9uX2xvc3MYHSABKAsyKy5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5D", + "bGFzc2lmaWNhdGlvbkxvc3MSIAoYaW5wbGFjZV9iYXRjaG5vcm1fdXBkYXRl", + "GB4gASgIEiIKGnVzZV9tYXRtdWxfY3JvcF9hbmRfcmVzaXplGB8gASgIEh0K", + "FWNsaXBfYW5jaG9yc190b19pbWFnZRggIAEoCBIkChx1c2VfbWF0bXVsX2dh", + "dGhlcl9pbl9tYXRjaGVyGCEgASgIEikKIXVzZV9zdGF0aWNfYmFsYW5jZWRf", + "bGFiZWxfc2FtcGxlchgiIAEoCBIZChF1c2Vfc3RhdGljX3NoYXBlcxgjIAEo", + "CBIUCgxyZXNpemVfbWFza3MYJCABKAgSIgoadXNlX3N0YXRpY19zaGFwZXNf", + "Zm9yX2V2YWwYJSABKAgibQoaRmFzdGVyUmNubkZlYXR1cmVFeHRyYWN0b3IS", + "DAoEdHlwZRgBIAEoCRIjChtmaXJzdF9zdGFnZV9mZWF0dXJlc19zdHJpZGUY", + "AiABKAUSHAoUYmF0Y2hfbm9ybV90cmFpbmFibGUYAyABKAhiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.Models.ObjectDetection.Protos.AnchorGeneratorReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.BoxPredictorReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.HyperparamsReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.ImageResizerReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.LossesReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.PostProcessingReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnn), global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnn.Parser, new[]{ "NumberOfStages", "NumClasses", "ImageResizer", "FeatureExtractor", "FirstStageAnchorGenerator", "FirstStageAtrousRate", "FirstStageBoxPredictorConvHyperparams", "FirstStageBoxPredictorKernelSize", "FirstStageBoxPredictorDepth", "FirstStageMinibatchSize", "FirstStagePositiveBalanceFraction", "FirstStageNmsScoreThreshold", "FirstStageNmsIouThreshold", "FirstStageMaxProposals", "FirstStageLocalizationLossWeight", "FirstStageObjectnessLossWeight", "InitialCropSize", "MaxpoolKernelSize", "MaxpoolStride", "SecondStageBoxPredictor", "SecondStageBatchSize", "SecondStageBalanceFraction", "SecondStagePostProcessing", "SecondStageLocalizationLossWeight", "SecondStageClassificationLossWeight", "SecondStageMaskPredictionLossWeight", "HardExampleMiner", "SecondStageClassificationLoss", "InplaceBatchnormUpdate", "UseMatmulCropAndResize", "ClipAnchorsToImage", "UseMatmulGatherInMatcher", "UseStaticBalancedLabelSampler", "UseStaticShapes", "ResizeMasks", "UseStaticShapesForEval" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnFeatureExtractor), global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnFeatureExtractor.Parser, new[]{ "Type", "FirstStageFeaturesStride", "BatchNormTrainable" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration for Faster R-CNN models. + /// See meta_architectures/faster_rcnn_meta_arch.py and models/model_builder.py + /// + /// Naming conventions: + /// Faster R-CNN models have two stages: a first stage region proposal network + /// (or RPN) and a second stage box classifier. We thus use the prefixes + /// `first_stage_` and `second_stage_` to indicate the stage to which each + /// parameter pertains when relevant. 
+ /// + public sealed partial class FasterRcnn : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FasterRcnn()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FasterRcnn() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FasterRcnn(FasterRcnn other) : this() { + numberOfStages_ = other.numberOfStages_; + numClasses_ = other.numClasses_; + imageResizer_ = other.imageResizer_ != null ? other.imageResizer_.Clone() : null; + featureExtractor_ = other.featureExtractor_ != null ? other.featureExtractor_.Clone() : null; + firstStageAnchorGenerator_ = other.firstStageAnchorGenerator_ != null ? other.firstStageAnchorGenerator_.Clone() : null; + firstStageAtrousRate_ = other.firstStageAtrousRate_; + firstStageBoxPredictorConvHyperparams_ = other.firstStageBoxPredictorConvHyperparams_ != null ? other.firstStageBoxPredictorConvHyperparams_.Clone() : null; + firstStageBoxPredictorKernelSize_ = other.firstStageBoxPredictorKernelSize_; + firstStageBoxPredictorDepth_ = other.firstStageBoxPredictorDepth_; + firstStageMinibatchSize_ = other.firstStageMinibatchSize_; + firstStagePositiveBalanceFraction_ = other.firstStagePositiveBalanceFraction_; + firstStageNmsScoreThreshold_ = other.firstStageNmsScoreThreshold_; + firstStageNmsIouThreshold_ = other.firstStageNmsIouThreshold_; + firstStageMaxProposals_ = other.firstStageMaxProposals_; + firstStageLocalizationLossWeight_ = other.firstStageLocalizationLossWeight_; + firstStageObjectnessLossWeight_ = other.firstStageObjectnessLossWeight_; + initialCropSize_ = other.initialCropSize_; + maxpoolKernelSize_ = other.maxpoolKernelSize_; + maxpoolStride_ = other.maxpoolStride_; + secondStageBoxPredictor_ = other.secondStageBoxPredictor_ != null ? other.secondStageBoxPredictor_.Clone() : null; + secondStageBatchSize_ = other.secondStageBatchSize_; + secondStageBalanceFraction_ = other.secondStageBalanceFraction_; + secondStagePostProcessing_ = other.secondStagePostProcessing_ != null ? other.secondStagePostProcessing_.Clone() : null; + secondStageLocalizationLossWeight_ = other.secondStageLocalizationLossWeight_; + secondStageClassificationLossWeight_ = other.secondStageClassificationLossWeight_; + secondStageMaskPredictionLossWeight_ = other.secondStageMaskPredictionLossWeight_; + hardExampleMiner_ = other.hardExampleMiner_ != null ? other.hardExampleMiner_.Clone() : null; + secondStageClassificationLoss_ = other.secondStageClassificationLoss_ != null ? 
other.secondStageClassificationLoss_.Clone() : null; + inplaceBatchnormUpdate_ = other.inplaceBatchnormUpdate_; + useMatmulCropAndResize_ = other.useMatmulCropAndResize_; + clipAnchorsToImage_ = other.clipAnchorsToImage_; + useMatmulGatherInMatcher_ = other.useMatmulGatherInMatcher_; + useStaticBalancedLabelSampler_ = other.useStaticBalancedLabelSampler_; + useStaticShapes_ = other.useStaticShapes_; + resizeMasks_ = other.resizeMasks_; + useStaticShapesForEval_ = other.useStaticShapesForEval_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FasterRcnn Clone() { + return new FasterRcnn(this); + } + + /// Field number for the "number_of_stages" field. + public const int NumberOfStagesFieldNumber = 1; + private int numberOfStages_; + /// + /// Whether to construct only the Region Proposal Network (RPN). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumberOfStages { + get { return numberOfStages_; } + set { + numberOfStages_ = value; + } + } + + /// Field number for the "num_classes" field. + public const int NumClassesFieldNumber = 3; + private int numClasses_; + /// + /// Number of classes to predict. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumClasses { + get { return numClasses_; } + set { + numClasses_ = value; + } + } + + /// Field number for the "image_resizer" field. + public const int ImageResizerFieldNumber = 4; + private global::Tensorflow.Models.ObjectDetection.Protos.ImageResizer imageResizer_; + /// + /// Image resizer for preprocessing the input image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ImageResizer ImageResizer { + get { return imageResizer_; } + set { + imageResizer_ = value; + } + } + + /// Field number for the "feature_extractor" field. + public const int FeatureExtractorFieldNumber = 5; + private global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnFeatureExtractor featureExtractor_; + /// + /// Feature extractor config. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnFeatureExtractor FeatureExtractor { + get { return featureExtractor_; } + set { + featureExtractor_ = value; + } + } + + /// Field number for the "first_stage_anchor_generator" field. + public const int FirstStageAnchorGeneratorFieldNumber = 6; + private global::Tensorflow.Models.ObjectDetection.Protos.AnchorGenerator firstStageAnchorGenerator_; + /// + /// Anchor generator to compute RPN anchors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.AnchorGenerator FirstStageAnchorGenerator { + get { return firstStageAnchorGenerator_; } + set { + firstStageAnchorGenerator_ = value; + } + } + + /// Field number for the "first_stage_atrous_rate" field. + public const int FirstStageAtrousRateFieldNumber = 7; + private int firstStageAtrousRate_; + /// + /// Atrous rate for the convolution op applied to the + /// `first_stage_features_to_crop` tensor to obtain box predictions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int FirstStageAtrousRate { + get { return firstStageAtrousRate_; } + set { + firstStageAtrousRate_ = value; + } + } + + /// Field number for the "first_stage_box_predictor_conv_hyperparams" field. 
+ public const int FirstStageBoxPredictorConvHyperparamsFieldNumber = 8; + private global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams firstStageBoxPredictorConvHyperparams_; + /// + /// Hyperparameters for the convolutional RPN box predictor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams FirstStageBoxPredictorConvHyperparams { + get { return firstStageBoxPredictorConvHyperparams_; } + set { + firstStageBoxPredictorConvHyperparams_ = value; + } + } + + /// Field number for the "first_stage_box_predictor_kernel_size" field. + public const int FirstStageBoxPredictorKernelSizeFieldNumber = 9; + private int firstStageBoxPredictorKernelSize_; + /// + /// Kernel size to use for the convolution op just prior to RPN box + /// predictions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int FirstStageBoxPredictorKernelSize { + get { return firstStageBoxPredictorKernelSize_; } + set { + firstStageBoxPredictorKernelSize_ = value; + } + } + + /// Field number for the "first_stage_box_predictor_depth" field. + public const int FirstStageBoxPredictorDepthFieldNumber = 10; + private int firstStageBoxPredictorDepth_; + /// + /// Output depth for the convolution op just prior to RPN box predictions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int FirstStageBoxPredictorDepth { + get { return firstStageBoxPredictorDepth_; } + set { + firstStageBoxPredictorDepth_ = value; + } + } + + /// Field number for the "first_stage_minibatch_size" field. + public const int FirstStageMinibatchSizeFieldNumber = 11; + private int firstStageMinibatchSize_; + /// + /// The batch size to use for computing the first stage objectness and + /// location losses. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int FirstStageMinibatchSize { + get { return firstStageMinibatchSize_; } + set { + firstStageMinibatchSize_ = value; + } + } + + /// Field number for the "first_stage_positive_balance_fraction" field. + public const int FirstStagePositiveBalanceFractionFieldNumber = 12; + private float firstStagePositiveBalanceFraction_; + /// + /// Fraction of positive examples per image for the RPN. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float FirstStagePositiveBalanceFraction { + get { return firstStagePositiveBalanceFraction_; } + set { + firstStagePositiveBalanceFraction_ = value; + } + } + + /// Field number for the "first_stage_nms_score_threshold" field. + public const int FirstStageNmsScoreThresholdFieldNumber = 13; + private float firstStageNmsScoreThreshold_; + /// + /// Non max suppression score threshold applied to first stage RPN proposals. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float FirstStageNmsScoreThreshold { + get { return firstStageNmsScoreThreshold_; } + set { + firstStageNmsScoreThreshold_ = value; + } + } + + /// Field number for the "first_stage_nms_iou_threshold" field. + public const int FirstStageNmsIouThresholdFieldNumber = 14; + private float firstStageNmsIouThreshold_; + /// + /// Non max suppression IOU threshold applied to first stage RPN proposals. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float FirstStageNmsIouThreshold { + get { return firstStageNmsIouThreshold_; } + set { + firstStageNmsIouThreshold_ = value; + } + } + + /// Field number for the "first_stage_max_proposals" field. 
+ public const int FirstStageMaxProposalsFieldNumber = 15; + private int firstStageMaxProposals_; + /// + /// Maximum number of RPN proposals retained after first stage postprocessing. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int FirstStageMaxProposals { + get { return firstStageMaxProposals_; } + set { + firstStageMaxProposals_ = value; + } + } + + /// Field number for the "first_stage_localization_loss_weight" field. + public const int FirstStageLocalizationLossWeightFieldNumber = 16; + private float firstStageLocalizationLossWeight_; + /// + /// First stage RPN localization loss weight. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float FirstStageLocalizationLossWeight { + get { return firstStageLocalizationLossWeight_; } + set { + firstStageLocalizationLossWeight_ = value; + } + } + + /// Field number for the "first_stage_objectness_loss_weight" field. + public const int FirstStageObjectnessLossWeightFieldNumber = 17; + private float firstStageObjectnessLossWeight_; + /// + /// First stage RPN objectness loss weight. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float FirstStageObjectnessLossWeight { + get { return firstStageObjectnessLossWeight_; } + set { + firstStageObjectnessLossWeight_ = value; + } + } + + /// Field number for the "initial_crop_size" field. + public const int InitialCropSizeFieldNumber = 18; + private int initialCropSize_; + /// + /// Output size (width and height are set to be the same) of the initial + /// bilinear interpolation based cropping during ROI pooling. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int InitialCropSize { + get { return initialCropSize_; } + set { + initialCropSize_ = value; + } + } + + /// Field number for the "maxpool_kernel_size" field. + public const int MaxpoolKernelSizeFieldNumber = 19; + private int maxpoolKernelSize_; + /// + /// Kernel size of the max pool op on the cropped feature map during + /// ROI pooling. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxpoolKernelSize { + get { return maxpoolKernelSize_; } + set { + maxpoolKernelSize_ = value; + } + } + + /// Field number for the "maxpool_stride" field. + public const int MaxpoolStrideFieldNumber = 20; + private int maxpoolStride_; + /// + /// Stride of the max pool op on the cropped feature map during ROI pooling. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxpoolStride { + get { return maxpoolStride_; } + set { + maxpoolStride_ = value; + } + } + + /// Field number for the "second_stage_box_predictor" field. + public const int SecondStageBoxPredictorFieldNumber = 21; + private global::Tensorflow.Models.ObjectDetection.Protos.BoxPredictor secondStageBoxPredictor_; + /// + /// Hyperparameters for the second stage box predictor. If box predictor type + /// is set to rfcn_box_predictor, a R-FCN model is constructed, otherwise a + /// Faster R-CNN model is constructed. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.BoxPredictor SecondStageBoxPredictor { + get { return secondStageBoxPredictor_; } + set { + secondStageBoxPredictor_ = value; + } + } + + /// Field number for the "second_stage_batch_size" field. 
+ public const int SecondStageBatchSizeFieldNumber = 22; + private int secondStageBatchSize_; + /// + /// The batch size per image used for computing the classification and refined + /// location loss of the box classifier. + /// Note that this field is ignored if `hard_example_miner` is configured. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int SecondStageBatchSize { + get { return secondStageBatchSize_; } + set { + secondStageBatchSize_ = value; + } + } + + /// Field number for the "second_stage_balance_fraction" field. + public const int SecondStageBalanceFractionFieldNumber = 23; + private float secondStageBalanceFraction_; + /// + /// Fraction of positive examples to use per image for the box classifier. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float SecondStageBalanceFraction { + get { return secondStageBalanceFraction_; } + set { + secondStageBalanceFraction_ = value; + } + } + + /// Field number for the "second_stage_post_processing" field. + public const int SecondStagePostProcessingFieldNumber = 24; + private global::Tensorflow.Models.ObjectDetection.Protos.PostProcessing secondStagePostProcessing_; + /// + /// Post processing to apply on the second stage box classifier predictions. + /// Note: the `score_converter` provided to the FasterRCNNMetaArch constructor + /// is taken from this `second_stage_post_processing` proto. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.PostProcessing SecondStagePostProcessing { + get { return secondStagePostProcessing_; } + set { + secondStagePostProcessing_ = value; + } + } + + /// Field number for the "second_stage_localization_loss_weight" field. + public const int SecondStageLocalizationLossWeightFieldNumber = 25; + private float secondStageLocalizationLossWeight_; + /// + /// Second stage refined localization loss weight. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float SecondStageLocalizationLossWeight { + get { return secondStageLocalizationLossWeight_; } + set { + secondStageLocalizationLossWeight_ = value; + } + } + + /// Field number for the "second_stage_classification_loss_weight" field. + public const int SecondStageClassificationLossWeightFieldNumber = 26; + private float secondStageClassificationLossWeight_; + /// + /// Second stage classification loss weight + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float SecondStageClassificationLossWeight { + get { return secondStageClassificationLossWeight_; } + set { + secondStageClassificationLossWeight_ = value; + } + } + + /// Field number for the "second_stage_mask_prediction_loss_weight" field. + public const int SecondStageMaskPredictionLossWeightFieldNumber = 27; + private float secondStageMaskPredictionLossWeight_; + /// + /// Second stage instance mask loss weight. Note that this is only applicable + /// when `MaskRCNNBoxPredictor` is selected for second stage and configured to + /// predict instance masks. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float SecondStageMaskPredictionLossWeight { + get { return secondStageMaskPredictionLossWeight_; } + set { + secondStageMaskPredictionLossWeight_ = value; + } + } + + /// Field number for the "hard_example_miner" field. 
+ public const int HardExampleMinerFieldNumber = 28; + private global::Tensorflow.Models.ObjectDetection.Protos.HardExampleMiner hardExampleMiner_; + /// + /// If not left to default, applies hard example mining only to classification + /// and localization loss.. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.HardExampleMiner HardExampleMiner { + get { return hardExampleMiner_; } + set { + hardExampleMiner_ = value; + } + } + + /// Field number for the "second_stage_classification_loss" field. + public const int SecondStageClassificationLossFieldNumber = 29; + private global::Tensorflow.Models.ObjectDetection.Protos.ClassificationLoss secondStageClassificationLoss_; + /// + /// Loss for second stage box classifers, supports Softmax and Sigmoid. + /// Note that score converter must be consistent with loss type. + /// When there are multiple labels assigned to the same boxes, recommend + /// to use sigmoid loss and enable merge_multiple_label_boxes. + /// If not specified, Softmax loss is used as default. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ClassificationLoss SecondStageClassificationLoss { + get { return secondStageClassificationLoss_; } + set { + secondStageClassificationLoss_ = value; + } + } + + /// Field number for the "inplace_batchnorm_update" field. + public const int InplaceBatchnormUpdateFieldNumber = 30; + private bool inplaceBatchnormUpdate_; + /// + /// Whether to update batch_norm inplace during training. This is required + /// for batch norm to work correctly on TPUs. When this is false, user must add + /// a control dependency on tf.GraphKeys.UPDATE_OPS for train/loss op in order + /// to update the batch norm moving average parameters. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool InplaceBatchnormUpdate { + get { return inplaceBatchnormUpdate_; } + set { + inplaceBatchnormUpdate_ = value; + } + } + + /// Field number for the "use_matmul_crop_and_resize" field. + public const int UseMatmulCropAndResizeFieldNumber = 31; + private bool useMatmulCropAndResize_; + /// + /// Force the use of matrix multiplication based crop and resize instead of + /// standard tf.image.crop_and_resize while computing second stage input + /// feature maps. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseMatmulCropAndResize { + get { return useMatmulCropAndResize_; } + set { + useMatmulCropAndResize_ = value; + } + } + + /// Field number for the "clip_anchors_to_image" field. + public const int ClipAnchorsToImageFieldNumber = 32; + private bool clipAnchorsToImage_; + /// + /// Normally, anchors generated for a given image size are pruned during + /// training if they lie outside the image window. Setting this option to true, + /// clips the anchors to be within the image instead of pruning. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ClipAnchorsToImage { + get { return clipAnchorsToImage_; } + set { + clipAnchorsToImage_ = value; + } + } + + /// Field number for the "use_matmul_gather_in_matcher" field. + public const int UseMatmulGatherInMatcherFieldNumber = 33; + private bool useMatmulGatherInMatcher_; + /// + /// After peforming matching between anchors and targets, in order to pull out + /// targets for training Faster R-CNN meta architecture we perform a gather + /// operation. 
This options specifies whether to use an alternate + /// implementation of tf.gather that is faster on TPUs. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseMatmulGatherInMatcher { + get { return useMatmulGatherInMatcher_; } + set { + useMatmulGatherInMatcher_ = value; + } + } + + /// Field number for the "use_static_balanced_label_sampler" field. + public const int UseStaticBalancedLabelSamplerFieldNumber = 34; + private bool useStaticBalancedLabelSampler_; + /// + /// Whether to use the balanced positive negative sampler implementation with + /// static shape guarantees. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseStaticBalancedLabelSampler { + get { return useStaticBalancedLabelSampler_; } + set { + useStaticBalancedLabelSampler_ = value; + } + } + + /// Field number for the "use_static_shapes" field. + public const int UseStaticShapesFieldNumber = 35; + private bool useStaticShapes_; + /// + /// If True, uses implementation of ops with static shape guarantees. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseStaticShapes { + get { return useStaticShapes_; } + set { + useStaticShapes_ = value; + } + } + + /// Field number for the "resize_masks" field. + public const int ResizeMasksFieldNumber = 36; + private bool resizeMasks_; + /// + /// Whether the masks present in groundtruth should be resized in the model to + /// match the image size. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ResizeMasks { + get { return resizeMasks_; } + set { + resizeMasks_ = value; + } + } + + /// Field number for the "use_static_shapes_for_eval" field. + public const int UseStaticShapesForEvalFieldNumber = 37; + private bool useStaticShapesForEval_; + /// + /// If True, uses implementation of ops with static shape guarantees when + /// running evaluation (specifically not is_training if False). 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseStaticShapesForEval { + get { return useStaticShapesForEval_; } + set { + useStaticShapesForEval_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as FasterRcnn); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(FasterRcnn other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (NumberOfStages != other.NumberOfStages) return false; + if (NumClasses != other.NumClasses) return false; + if (!object.Equals(ImageResizer, other.ImageResizer)) return false; + if (!object.Equals(FeatureExtractor, other.FeatureExtractor)) return false; + if (!object.Equals(FirstStageAnchorGenerator, other.FirstStageAnchorGenerator)) return false; + if (FirstStageAtrousRate != other.FirstStageAtrousRate) return false; + if (!object.Equals(FirstStageBoxPredictorConvHyperparams, other.FirstStageBoxPredictorConvHyperparams)) return false; + if (FirstStageBoxPredictorKernelSize != other.FirstStageBoxPredictorKernelSize) return false; + if (FirstStageBoxPredictorDepth != other.FirstStageBoxPredictorDepth) return false; + if (FirstStageMinibatchSize != other.FirstStageMinibatchSize) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(FirstStagePositiveBalanceFraction, other.FirstStagePositiveBalanceFraction)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(FirstStageNmsScoreThreshold, other.FirstStageNmsScoreThreshold)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(FirstStageNmsIouThreshold, other.FirstStageNmsIouThreshold)) return false; + if (FirstStageMaxProposals != other.FirstStageMaxProposals) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(FirstStageLocalizationLossWeight, other.FirstStageLocalizationLossWeight)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(FirstStageObjectnessLossWeight, other.FirstStageObjectnessLossWeight)) return false; + if (InitialCropSize != other.InitialCropSize) return false; + if (MaxpoolKernelSize != other.MaxpoolKernelSize) return false; + if (MaxpoolStride != other.MaxpoolStride) return false; + if (!object.Equals(SecondStageBoxPredictor, other.SecondStageBoxPredictor)) return false; + if (SecondStageBatchSize != other.SecondStageBatchSize) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(SecondStageBalanceFraction, other.SecondStageBalanceFraction)) return false; + if (!object.Equals(SecondStagePostProcessing, other.SecondStagePostProcessing)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(SecondStageLocalizationLossWeight, other.SecondStageLocalizationLossWeight)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(SecondStageClassificationLossWeight, other.SecondStageClassificationLossWeight)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(SecondStageMaskPredictionLossWeight, other.SecondStageMaskPredictionLossWeight)) return false; + if (!object.Equals(HardExampleMiner, other.HardExampleMiner)) return false; + if (!object.Equals(SecondStageClassificationLoss, other.SecondStageClassificationLoss)) return false; + if 
(InplaceBatchnormUpdate != other.InplaceBatchnormUpdate) return false; + if (UseMatmulCropAndResize != other.UseMatmulCropAndResize) return false; + if (ClipAnchorsToImage != other.ClipAnchorsToImage) return false; + if (UseMatmulGatherInMatcher != other.UseMatmulGatherInMatcher) return false; + if (UseStaticBalancedLabelSampler != other.UseStaticBalancedLabelSampler) return false; + if (UseStaticShapes != other.UseStaticShapes) return false; + if (ResizeMasks != other.ResizeMasks) return false; + if (UseStaticShapesForEval != other.UseStaticShapesForEval) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (NumberOfStages != 0) hash ^= NumberOfStages.GetHashCode(); + if (NumClasses != 0) hash ^= NumClasses.GetHashCode(); + if (imageResizer_ != null) hash ^= ImageResizer.GetHashCode(); + if (featureExtractor_ != null) hash ^= FeatureExtractor.GetHashCode(); + if (firstStageAnchorGenerator_ != null) hash ^= FirstStageAnchorGenerator.GetHashCode(); + if (FirstStageAtrousRate != 0) hash ^= FirstStageAtrousRate.GetHashCode(); + if (firstStageBoxPredictorConvHyperparams_ != null) hash ^= FirstStageBoxPredictorConvHyperparams.GetHashCode(); + if (FirstStageBoxPredictorKernelSize != 0) hash ^= FirstStageBoxPredictorKernelSize.GetHashCode(); + if (FirstStageBoxPredictorDepth != 0) hash ^= FirstStageBoxPredictorDepth.GetHashCode(); + if (FirstStageMinibatchSize != 0) hash ^= FirstStageMinibatchSize.GetHashCode(); + if (FirstStagePositiveBalanceFraction != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(FirstStagePositiveBalanceFraction); + if (FirstStageNmsScoreThreshold != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(FirstStageNmsScoreThreshold); + if (FirstStageNmsIouThreshold != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(FirstStageNmsIouThreshold); + if (FirstStageMaxProposals != 0) hash ^= FirstStageMaxProposals.GetHashCode(); + if (FirstStageLocalizationLossWeight != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(FirstStageLocalizationLossWeight); + if (FirstStageObjectnessLossWeight != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(FirstStageObjectnessLossWeight); + if (InitialCropSize != 0) hash ^= InitialCropSize.GetHashCode(); + if (MaxpoolKernelSize != 0) hash ^= MaxpoolKernelSize.GetHashCode(); + if (MaxpoolStride != 0) hash ^= MaxpoolStride.GetHashCode(); + if (secondStageBoxPredictor_ != null) hash ^= SecondStageBoxPredictor.GetHashCode(); + if (SecondStageBatchSize != 0) hash ^= SecondStageBatchSize.GetHashCode(); + if (SecondStageBalanceFraction != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(SecondStageBalanceFraction); + if (secondStagePostProcessing_ != null) hash ^= SecondStagePostProcessing.GetHashCode(); + if (SecondStageLocalizationLossWeight != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(SecondStageLocalizationLossWeight); + if (SecondStageClassificationLossWeight != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(SecondStageClassificationLossWeight); + if (SecondStageMaskPredictionLossWeight != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(SecondStageMaskPredictionLossWeight); + if 
(hardExampleMiner_ != null) hash ^= HardExampleMiner.GetHashCode(); + if (secondStageClassificationLoss_ != null) hash ^= SecondStageClassificationLoss.GetHashCode(); + if (InplaceBatchnormUpdate != false) hash ^= InplaceBatchnormUpdate.GetHashCode(); + if (UseMatmulCropAndResize != false) hash ^= UseMatmulCropAndResize.GetHashCode(); + if (ClipAnchorsToImage != false) hash ^= ClipAnchorsToImage.GetHashCode(); + if (UseMatmulGatherInMatcher != false) hash ^= UseMatmulGatherInMatcher.GetHashCode(); + if (UseStaticBalancedLabelSampler != false) hash ^= UseStaticBalancedLabelSampler.GetHashCode(); + if (UseStaticShapes != false) hash ^= UseStaticShapes.GetHashCode(); + if (ResizeMasks != false) hash ^= ResizeMasks.GetHashCode(); + if (UseStaticShapesForEval != false) hash ^= UseStaticShapesForEval.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (NumberOfStages != 0) { + output.WriteRawTag(8); + output.WriteInt32(NumberOfStages); + } + if (NumClasses != 0) { + output.WriteRawTag(24); + output.WriteInt32(NumClasses); + } + if (imageResizer_ != null) { + output.WriteRawTag(34); + output.WriteMessage(ImageResizer); + } + if (featureExtractor_ != null) { + output.WriteRawTag(42); + output.WriteMessage(FeatureExtractor); + } + if (firstStageAnchorGenerator_ != null) { + output.WriteRawTag(50); + output.WriteMessage(FirstStageAnchorGenerator); + } + if (FirstStageAtrousRate != 0) { + output.WriteRawTag(56); + output.WriteInt32(FirstStageAtrousRate); + } + if (firstStageBoxPredictorConvHyperparams_ != null) { + output.WriteRawTag(66); + output.WriteMessage(FirstStageBoxPredictorConvHyperparams); + } + if (FirstStageBoxPredictorKernelSize != 0) { + output.WriteRawTag(72); + output.WriteInt32(FirstStageBoxPredictorKernelSize); + } + if (FirstStageBoxPredictorDepth != 0) { + output.WriteRawTag(80); + output.WriteInt32(FirstStageBoxPredictorDepth); + } + if (FirstStageMinibatchSize != 0) { + output.WriteRawTag(88); + output.WriteInt32(FirstStageMinibatchSize); + } + if (FirstStagePositiveBalanceFraction != 0F) { + output.WriteRawTag(101); + output.WriteFloat(FirstStagePositiveBalanceFraction); + } + if (FirstStageNmsScoreThreshold != 0F) { + output.WriteRawTag(109); + output.WriteFloat(FirstStageNmsScoreThreshold); + } + if (FirstStageNmsIouThreshold != 0F) { + output.WriteRawTag(117); + output.WriteFloat(FirstStageNmsIouThreshold); + } + if (FirstStageMaxProposals != 0) { + output.WriteRawTag(120); + output.WriteInt32(FirstStageMaxProposals); + } + if (FirstStageLocalizationLossWeight != 0F) { + output.WriteRawTag(133, 1); + output.WriteFloat(FirstStageLocalizationLossWeight); + } + if (FirstStageObjectnessLossWeight != 0F) { + output.WriteRawTag(141, 1); + output.WriteFloat(FirstStageObjectnessLossWeight); + } + if (InitialCropSize != 0) { + output.WriteRawTag(144, 1); + output.WriteInt32(InitialCropSize); + } + if (MaxpoolKernelSize != 0) { + output.WriteRawTag(152, 1); + output.WriteInt32(MaxpoolKernelSize); + } + if (MaxpoolStride != 0) { + output.WriteRawTag(160, 1); + output.WriteInt32(MaxpoolStride); + } + if (secondStageBoxPredictor_ != null) { + output.WriteRawTag(170, 1); + output.WriteMessage(SecondStageBoxPredictor); + } + if (SecondStageBatchSize != 0) { + 
output.WriteRawTag(176, 1); + output.WriteInt32(SecondStageBatchSize); + } + if (SecondStageBalanceFraction != 0F) { + output.WriteRawTag(189, 1); + output.WriteFloat(SecondStageBalanceFraction); + } + if (secondStagePostProcessing_ != null) { + output.WriteRawTag(194, 1); + output.WriteMessage(SecondStagePostProcessing); + } + if (SecondStageLocalizationLossWeight != 0F) { + output.WriteRawTag(205, 1); + output.WriteFloat(SecondStageLocalizationLossWeight); + } + if (SecondStageClassificationLossWeight != 0F) { + output.WriteRawTag(213, 1); + output.WriteFloat(SecondStageClassificationLossWeight); + } + if (SecondStageMaskPredictionLossWeight != 0F) { + output.WriteRawTag(221, 1); + output.WriteFloat(SecondStageMaskPredictionLossWeight); + } + if (hardExampleMiner_ != null) { + output.WriteRawTag(226, 1); + output.WriteMessage(HardExampleMiner); + } + if (secondStageClassificationLoss_ != null) { + output.WriteRawTag(234, 1); + output.WriteMessage(SecondStageClassificationLoss); + } + if (InplaceBatchnormUpdate != false) { + output.WriteRawTag(240, 1); + output.WriteBool(InplaceBatchnormUpdate); + } + if (UseMatmulCropAndResize != false) { + output.WriteRawTag(248, 1); + output.WriteBool(UseMatmulCropAndResize); + } + if (ClipAnchorsToImage != false) { + output.WriteRawTag(128, 2); + output.WriteBool(ClipAnchorsToImage); + } + if (UseMatmulGatherInMatcher != false) { + output.WriteRawTag(136, 2); + output.WriteBool(UseMatmulGatherInMatcher); + } + if (UseStaticBalancedLabelSampler != false) { + output.WriteRawTag(144, 2); + output.WriteBool(UseStaticBalancedLabelSampler); + } + if (UseStaticShapes != false) { + output.WriteRawTag(152, 2); + output.WriteBool(UseStaticShapes); + } + if (ResizeMasks != false) { + output.WriteRawTag(160, 2); + output.WriteBool(ResizeMasks); + } + if (UseStaticShapesForEval != false) { + output.WriteRawTag(168, 2); + output.WriteBool(UseStaticShapesForEval); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (NumberOfStages != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumberOfStages); + } + if (NumClasses != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumClasses); + } + if (imageResizer_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ImageResizer); + } + if (featureExtractor_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(FeatureExtractor); + } + if (firstStageAnchorGenerator_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(FirstStageAnchorGenerator); + } + if (FirstStageAtrousRate != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(FirstStageAtrousRate); + } + if (firstStageBoxPredictorConvHyperparams_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(FirstStageBoxPredictorConvHyperparams); + } + if (FirstStageBoxPredictorKernelSize != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(FirstStageBoxPredictorKernelSize); + } + if (FirstStageBoxPredictorDepth != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(FirstStageBoxPredictorDepth); + } + if (FirstStageMinibatchSize != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(FirstStageMinibatchSize); + } + if (FirstStagePositiveBalanceFraction != 0F) { + size += 1 + 4; + } + if (FirstStageNmsScoreThreshold != 0F) { + size += 1 + 4; + } + if (FirstStageNmsIouThreshold != 0F) { + size += 1 + 4; + } + if (FirstStageMaxProposals != 0) { + size 
+= 1 + pb::CodedOutputStream.ComputeInt32Size(FirstStageMaxProposals); + } + if (FirstStageLocalizationLossWeight != 0F) { + size += 2 + 4; + } + if (FirstStageObjectnessLossWeight != 0F) { + size += 2 + 4; + } + if (InitialCropSize != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(InitialCropSize); + } + if (MaxpoolKernelSize != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(MaxpoolKernelSize); + } + if (MaxpoolStride != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(MaxpoolStride); + } + if (secondStageBoxPredictor_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(SecondStageBoxPredictor); + } + if (SecondStageBatchSize != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(SecondStageBatchSize); + } + if (SecondStageBalanceFraction != 0F) { + size += 2 + 4; + } + if (secondStagePostProcessing_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(SecondStagePostProcessing); + } + if (SecondStageLocalizationLossWeight != 0F) { + size += 2 + 4; + } + if (SecondStageClassificationLossWeight != 0F) { + size += 2 + 4; + } + if (SecondStageMaskPredictionLossWeight != 0F) { + size += 2 + 4; + } + if (hardExampleMiner_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(HardExampleMiner); + } + if (secondStageClassificationLoss_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(SecondStageClassificationLoss); + } + if (InplaceBatchnormUpdate != false) { + size += 2 + 1; + } + if (UseMatmulCropAndResize != false) { + size += 2 + 1; + } + if (ClipAnchorsToImage != false) { + size += 2 + 1; + } + if (UseMatmulGatherInMatcher != false) { + size += 2 + 1; + } + if (UseStaticBalancedLabelSampler != false) { + size += 2 + 1; + } + if (UseStaticShapes != false) { + size += 2 + 1; + } + if (ResizeMasks != false) { + size += 2 + 1; + } + if (UseStaticShapesForEval != false) { + size += 2 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(FasterRcnn other) { + if (other == null) { + return; + } + if (other.NumberOfStages != 0) { + NumberOfStages = other.NumberOfStages; + } + if (other.NumClasses != 0) { + NumClasses = other.NumClasses; + } + if (other.imageResizer_ != null) { + if (imageResizer_ == null) { + imageResizer_ = new global::Tensorflow.Models.ObjectDetection.Protos.ImageResizer(); + } + ImageResizer.MergeFrom(other.ImageResizer); + } + if (other.featureExtractor_ != null) { + if (featureExtractor_ == null) { + featureExtractor_ = new global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnFeatureExtractor(); + } + FeatureExtractor.MergeFrom(other.FeatureExtractor); + } + if (other.firstStageAnchorGenerator_ != null) { + if (firstStageAnchorGenerator_ == null) { + firstStageAnchorGenerator_ = new global::Tensorflow.Models.ObjectDetection.Protos.AnchorGenerator(); + } + FirstStageAnchorGenerator.MergeFrom(other.FirstStageAnchorGenerator); + } + if (other.FirstStageAtrousRate != 0) { + FirstStageAtrousRate = other.FirstStageAtrousRate; + } + if (other.firstStageBoxPredictorConvHyperparams_ != null) { + if (firstStageBoxPredictorConvHyperparams_ == null) { + firstStageBoxPredictorConvHyperparams_ = new global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams(); + } + FirstStageBoxPredictorConvHyperparams.MergeFrom(other.FirstStageBoxPredictorConvHyperparams); + } + if (other.FirstStageBoxPredictorKernelSize != 0) { + FirstStageBoxPredictorKernelSize = 
other.FirstStageBoxPredictorKernelSize; + } + if (other.FirstStageBoxPredictorDepth != 0) { + FirstStageBoxPredictorDepth = other.FirstStageBoxPredictorDepth; + } + if (other.FirstStageMinibatchSize != 0) { + FirstStageMinibatchSize = other.FirstStageMinibatchSize; + } + if (other.FirstStagePositiveBalanceFraction != 0F) { + FirstStagePositiveBalanceFraction = other.FirstStagePositiveBalanceFraction; + } + if (other.FirstStageNmsScoreThreshold != 0F) { + FirstStageNmsScoreThreshold = other.FirstStageNmsScoreThreshold; + } + if (other.FirstStageNmsIouThreshold != 0F) { + FirstStageNmsIouThreshold = other.FirstStageNmsIouThreshold; + } + if (other.FirstStageMaxProposals != 0) { + FirstStageMaxProposals = other.FirstStageMaxProposals; + } + if (other.FirstStageLocalizationLossWeight != 0F) { + FirstStageLocalizationLossWeight = other.FirstStageLocalizationLossWeight; + } + if (other.FirstStageObjectnessLossWeight != 0F) { + FirstStageObjectnessLossWeight = other.FirstStageObjectnessLossWeight; + } + if (other.InitialCropSize != 0) { + InitialCropSize = other.InitialCropSize; + } + if (other.MaxpoolKernelSize != 0) { + MaxpoolKernelSize = other.MaxpoolKernelSize; + } + if (other.MaxpoolStride != 0) { + MaxpoolStride = other.MaxpoolStride; + } + if (other.secondStageBoxPredictor_ != null) { + if (secondStageBoxPredictor_ == null) { + secondStageBoxPredictor_ = new global::Tensorflow.Models.ObjectDetection.Protos.BoxPredictor(); + } + SecondStageBoxPredictor.MergeFrom(other.SecondStageBoxPredictor); + } + if (other.SecondStageBatchSize != 0) { + SecondStageBatchSize = other.SecondStageBatchSize; + } + if (other.SecondStageBalanceFraction != 0F) { + SecondStageBalanceFraction = other.SecondStageBalanceFraction; + } + if (other.secondStagePostProcessing_ != null) { + if (secondStagePostProcessing_ == null) { + secondStagePostProcessing_ = new global::Tensorflow.Models.ObjectDetection.Protos.PostProcessing(); + } + SecondStagePostProcessing.MergeFrom(other.SecondStagePostProcessing); + } + if (other.SecondStageLocalizationLossWeight != 0F) { + SecondStageLocalizationLossWeight = other.SecondStageLocalizationLossWeight; + } + if (other.SecondStageClassificationLossWeight != 0F) { + SecondStageClassificationLossWeight = other.SecondStageClassificationLossWeight; + } + if (other.SecondStageMaskPredictionLossWeight != 0F) { + SecondStageMaskPredictionLossWeight = other.SecondStageMaskPredictionLossWeight; + } + if (other.hardExampleMiner_ != null) { + if (hardExampleMiner_ == null) { + hardExampleMiner_ = new global::Tensorflow.Models.ObjectDetection.Protos.HardExampleMiner(); + } + HardExampleMiner.MergeFrom(other.HardExampleMiner); + } + if (other.secondStageClassificationLoss_ != null) { + if (secondStageClassificationLoss_ == null) { + secondStageClassificationLoss_ = new global::Tensorflow.Models.ObjectDetection.Protos.ClassificationLoss(); + } + SecondStageClassificationLoss.MergeFrom(other.SecondStageClassificationLoss); + } + if (other.InplaceBatchnormUpdate != false) { + InplaceBatchnormUpdate = other.InplaceBatchnormUpdate; + } + if (other.UseMatmulCropAndResize != false) { + UseMatmulCropAndResize = other.UseMatmulCropAndResize; + } + if (other.ClipAnchorsToImage != false) { + ClipAnchorsToImage = other.ClipAnchorsToImage; + } + if (other.UseMatmulGatherInMatcher != false) { + UseMatmulGatherInMatcher = other.UseMatmulGatherInMatcher; + } + if (other.UseStaticBalancedLabelSampler != false) { + UseStaticBalancedLabelSampler = other.UseStaticBalancedLabelSampler; + } + if 
(other.UseStaticShapes != false) { + UseStaticShapes = other.UseStaticShapes; + } + if (other.ResizeMasks != false) { + ResizeMasks = other.ResizeMasks; + } + if (other.UseStaticShapesForEval != false) { + UseStaticShapesForEval = other.UseStaticShapesForEval; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + NumberOfStages = input.ReadInt32(); + break; + } + case 24: { + NumClasses = input.ReadInt32(); + break; + } + case 34: { + if (imageResizer_ == null) { + imageResizer_ = new global::Tensorflow.Models.ObjectDetection.Protos.ImageResizer(); + } + input.ReadMessage(imageResizer_); + break; + } + case 42: { + if (featureExtractor_ == null) { + featureExtractor_ = new global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnFeatureExtractor(); + } + input.ReadMessage(featureExtractor_); + break; + } + case 50: { + if (firstStageAnchorGenerator_ == null) { + firstStageAnchorGenerator_ = new global::Tensorflow.Models.ObjectDetection.Protos.AnchorGenerator(); + } + input.ReadMessage(firstStageAnchorGenerator_); + break; + } + case 56: { + FirstStageAtrousRate = input.ReadInt32(); + break; + } + case 66: { + if (firstStageBoxPredictorConvHyperparams_ == null) { + firstStageBoxPredictorConvHyperparams_ = new global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams(); + } + input.ReadMessage(firstStageBoxPredictorConvHyperparams_); + break; + } + case 72: { + FirstStageBoxPredictorKernelSize = input.ReadInt32(); + break; + } + case 80: { + FirstStageBoxPredictorDepth = input.ReadInt32(); + break; + } + case 88: { + FirstStageMinibatchSize = input.ReadInt32(); + break; + } + case 101: { + FirstStagePositiveBalanceFraction = input.ReadFloat(); + break; + } + case 109: { + FirstStageNmsScoreThreshold = input.ReadFloat(); + break; + } + case 117: { + FirstStageNmsIouThreshold = input.ReadFloat(); + break; + } + case 120: { + FirstStageMaxProposals = input.ReadInt32(); + break; + } + case 133: { + FirstStageLocalizationLossWeight = input.ReadFloat(); + break; + } + case 141: { + FirstStageObjectnessLossWeight = input.ReadFloat(); + break; + } + case 144: { + InitialCropSize = input.ReadInt32(); + break; + } + case 152: { + MaxpoolKernelSize = input.ReadInt32(); + break; + } + case 160: { + MaxpoolStride = input.ReadInt32(); + break; + } + case 170: { + if (secondStageBoxPredictor_ == null) { + secondStageBoxPredictor_ = new global::Tensorflow.Models.ObjectDetection.Protos.BoxPredictor(); + } + input.ReadMessage(secondStageBoxPredictor_); + break; + } + case 176: { + SecondStageBatchSize = input.ReadInt32(); + break; + } + case 189: { + SecondStageBalanceFraction = input.ReadFloat(); + break; + } + case 194: { + if (secondStagePostProcessing_ == null) { + secondStagePostProcessing_ = new global::Tensorflow.Models.ObjectDetection.Protos.PostProcessing(); + } + input.ReadMessage(secondStagePostProcessing_); + break; + } + case 205: { + SecondStageLocalizationLossWeight = input.ReadFloat(); + break; + } + case 213: { + SecondStageClassificationLossWeight = input.ReadFloat(); + break; + } + case 221: { + SecondStageMaskPredictionLossWeight = input.ReadFloat(); + break; + } + case 226: { + if (hardExampleMiner_ == null) { + hardExampleMiner_ = new 
global::Tensorflow.Models.ObjectDetection.Protos.HardExampleMiner(); + } + input.ReadMessage(hardExampleMiner_); + break; + } + case 234: { + if (secondStageClassificationLoss_ == null) { + secondStageClassificationLoss_ = new global::Tensorflow.Models.ObjectDetection.Protos.ClassificationLoss(); + } + input.ReadMessage(secondStageClassificationLoss_); + break; + } + case 240: { + InplaceBatchnormUpdate = input.ReadBool(); + break; + } + case 248: { + UseMatmulCropAndResize = input.ReadBool(); + break; + } + case 256: { + ClipAnchorsToImage = input.ReadBool(); + break; + } + case 264: { + UseMatmulGatherInMatcher = input.ReadBool(); + break; + } + case 272: { + UseStaticBalancedLabelSampler = input.ReadBool(); + break; + } + case 280: { + UseStaticShapes = input.ReadBool(); + break; + } + case 288: { + ResizeMasks = input.ReadBool(); + break; + } + case 296: { + UseStaticShapesForEval = input.ReadBool(); + break; + } + } + } + } + + } + + public sealed partial class FasterRcnnFeatureExtractor : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FasterRcnnFeatureExtractor()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FasterRcnnFeatureExtractor() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FasterRcnnFeatureExtractor(FasterRcnnFeatureExtractor other) : this() { + type_ = other.type_; + firstStageFeaturesStride_ = other.firstStageFeaturesStride_; + batchNormTrainable_ = other.batchNormTrainable_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FasterRcnnFeatureExtractor Clone() { + return new FasterRcnnFeatureExtractor(this); + } + + /// Field number for the "type" field. + public const int TypeFieldNumber = 1; + private string type_ = ""; + /// + /// Type of Faster R-CNN model (e.g., 'faster_rcnn_resnet101'; + /// See builders/model_builder.py for expected types). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Type { + get { return type_; } + set { + type_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "first_stage_features_stride" field. + public const int FirstStageFeaturesStrideFieldNumber = 2; + private int firstStageFeaturesStride_; + /// + /// Output stride of extracted RPN feature map. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int FirstStageFeaturesStride { + get { return firstStageFeaturesStride_; } + set { + firstStageFeaturesStride_ = value; + } + } + + /// Field number for the "batch_norm_trainable" field. + public const int BatchNormTrainableFieldNumber = 3; + private bool batchNormTrainable_; + /// + /// Whether to update batch norm parameters during training or not. + /// When training with a relative large batch size (e.g. 
8), it could be + /// desirable to enable batch norm update. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool BatchNormTrainable { + get { return batchNormTrainable_; } + set { + batchNormTrainable_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as FasterRcnnFeatureExtractor); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(FasterRcnnFeatureExtractor other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Type != other.Type) return false; + if (FirstStageFeaturesStride != other.FirstStageFeaturesStride) return false; + if (BatchNormTrainable != other.BatchNormTrainable) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Type.Length != 0) hash ^= Type.GetHashCode(); + if (FirstStageFeaturesStride != 0) hash ^= FirstStageFeaturesStride.GetHashCode(); + if (BatchNormTrainable != false) hash ^= BatchNormTrainable.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Type.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Type); + } + if (FirstStageFeaturesStride != 0) { + output.WriteRawTag(16); + output.WriteInt32(FirstStageFeaturesStride); + } + if (BatchNormTrainable != false) { + output.WriteRawTag(24); + output.WriteBool(BatchNormTrainable); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Type.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Type); + } + if (FirstStageFeaturesStride != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(FirstStageFeaturesStride); + } + if (BatchNormTrainable != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(FasterRcnnFeatureExtractor other) { + if (other == null) { + return; + } + if (other.Type.Length != 0) { + Type = other.Type; + } + if (other.FirstStageFeaturesStride != 0) { + FirstStageFeaturesStride = other.FirstStageFeaturesStride; + } + if (other.BatchNormTrainable != false) { + BatchNormTrainable = other.BatchNormTrainable; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Type = input.ReadString(); + break; + } + case 16: { + FirstStageFeaturesStride = input.ReadInt32(); + break; + } + case 24: { + BatchNormTrainable = input.ReadBool(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git 
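For orientation, a minimal sketch of how the FasterRcnnFeatureExtractor message generated above might be built and round-tripped. It is not part of the generated FasterRcnn.cs; it assumes the standard Google.Protobuf 3.x runtime helpers (ToByteArray, Parser.ParseFrom) and the Tensorflow.Models.ObjectDetection.Protos namespace shown in this diff, and the field values are illustrative only.

using Google.Protobuf;
using Tensorflow.Models.ObjectDetection.Protos;

class FeatureExtractorDemo
{
    static void Main()
    {
        // Populate the message through the generated properties.
        var extractor = new FasterRcnnFeatureExtractor
        {
            Type = "faster_rcnn_resnet101",     // see builders/model_builder.py for expected types
            FirstStageFeaturesStride = 16,      // output stride of the extracted RPN feature map
            BatchNormTrainable = false
        };

        // Serialize to the proto wire format and parse it back.
        byte[] bytes = extractor.ToByteArray();
        var parsed = FasterRcnnFeatureExtractor.Parser.ParseFrom(bytes);

        System.Console.WriteLine(parsed.Type);                      // faster_rcnn_resnet101
        System.Console.WriteLine(parsed.FirstStageFeaturesStride);  // 16
    }
}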
a/src/TensorFlowNET.Models/ObjectDetection/Protos/FasterRcnnBoxCoder.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/FasterRcnnBoxCoder.cs new file mode 100644 index 00000000..bf9dbe92 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/FasterRcnnBoxCoder.cs @@ -0,0 +1,272 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: object_detection/protos/faster_rcnn_box_coder.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/faster_rcnn_box_coder.proto + public static partial class FasterRcnnBoxCoderReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/faster_rcnn_box_coder.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static FasterRcnnBoxCoderReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjNvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9mYXN0ZXJfcmNubl9ib3hfY29k", + "ZXIucHJvdG8SF29iamVjdF9kZXRlY3Rpb24ucHJvdG9zImEKEkZhc3RlclJj", + "bm5Cb3hDb2RlchIPCgd5X3NjYWxlGAEgASgCEg8KB3hfc2NhbGUYAiABKAIS", + "FAoMaGVpZ2h0X3NjYWxlGAMgASgCEhMKC3dpZHRoX3NjYWxlGAQgASgCYgZw", + "cm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnBoxCoder), global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnBoxCoder.Parser, new[]{ "YScale", "XScale", "HeightScale", "WidthScale" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration proto for FasterRCNNBoxCoder. See + /// box_coders/faster_rcnn_box_coder.py for details. 
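As a side note, the FasterRcnnBoxCoderReflection class above only carries descriptor metadata for the file. A hypothetical sketch of reading that metadata back through the standard Google.Protobuf reflection API (MessageDescriptor/FieldDescriptor are runtime types, not part of this patch):

using Google.Protobuf.Reflection;
using Tensorflow.Models.ObjectDetection.Protos;

class DescriptorDemo
{
    static void Main()
    {
        // The message descriptor registered by the reflection class exposes the
        // proto field names and numbers declared in faster_rcnn_box_coder.proto.
        MessageDescriptor descriptor = FasterRcnnBoxCoder.Descriptor;
        foreach (FieldDescriptor field in descriptor.Fields.InDeclarationOrder())
        {
            System.Console.WriteLine($"{field.FieldNumber}: {field.Name}");
        }
        // Expected: 1: y_scale, 2: x_scale, 3: height_scale, 4: width_scale
    }
}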
+ /// + public sealed partial class FasterRcnnBoxCoder : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FasterRcnnBoxCoder()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnBoxCoderReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FasterRcnnBoxCoder() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FasterRcnnBoxCoder(FasterRcnnBoxCoder other) : this() { + yScale_ = other.yScale_; + xScale_ = other.xScale_; + heightScale_ = other.heightScale_; + widthScale_ = other.widthScale_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FasterRcnnBoxCoder Clone() { + return new FasterRcnnBoxCoder(this); + } + + /// Field number for the "y_scale" field. + public const int YScaleFieldNumber = 1; + private float yScale_; + /// + /// Scale factor for anchor encoded box center. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float YScale { + get { return yScale_; } + set { + yScale_ = value; + } + } + + /// Field number for the "x_scale" field. + public const int XScaleFieldNumber = 2; + private float xScale_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float XScale { + get { return xScale_; } + set { + xScale_ = value; + } + } + + /// Field number for the "height_scale" field. + public const int HeightScaleFieldNumber = 3; + private float heightScale_; + /// + /// Scale factor for anchor encoded box height. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float HeightScale { + get { return heightScale_; } + set { + heightScale_ = value; + } + } + + /// Field number for the "width_scale" field. + public const int WidthScaleFieldNumber = 4; + private float widthScale_; + /// + /// Scale factor for anchor encoded box width. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float WidthScale { + get { return widthScale_; } + set { + widthScale_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as FasterRcnnBoxCoder); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(FasterRcnnBoxCoder other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(YScale, other.YScale)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(XScale, other.XScale)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(HeightScale, other.HeightScale)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(WidthScale, other.WidthScale)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (YScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(YScale); + if (XScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(XScale); + if (HeightScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(HeightScale); + if (WidthScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(WidthScale); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (YScale != 0F) { + output.WriteRawTag(13); + output.WriteFloat(YScale); + } + if (XScale != 0F) { + output.WriteRawTag(21); + output.WriteFloat(XScale); + } + if (HeightScale != 0F) { + output.WriteRawTag(29); + output.WriteFloat(HeightScale); + } + if (WidthScale != 0F) { + output.WriteRawTag(37); + output.WriteFloat(WidthScale); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (YScale != 0F) { + size += 1 + 4; + } + if (XScale != 0F) { + size += 1 + 4; + } + if (HeightScale != 0F) { + size += 1 + 4; + } + if (WidthScale != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(FasterRcnnBoxCoder other) { + if (other == null) { + return; + } + if (other.YScale != 0F) { + YScale = other.YScale; + } + if (other.XScale != 0F) { + XScale = other.XScale; + } + if (other.HeightScale != 0F) { + HeightScale = other.HeightScale; + } + if (other.WidthScale != 0F) { + WidthScale = other.WidthScale; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = 
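The WriteTo/CalculateSize pair above follows the fixed32 wire format: each set float field costs one tag byte ((field_number << 3) | 5, hence the raw tags 13/21/29/37) plus four payload bytes. A small sketch that checks this; the scale values (10, 10, 5, 5) are commonly used Faster R-CNN defaults and purely illustrative here.

using Google.Protobuf;
using Tensorflow.Models.ObjectDetection.Protos;

class BoxCoderWireFormatDemo
{
    static void Main()
    {
        var coder = new FasterRcnnBoxCoder
        {
            YScale = 10f, XScale = 10f, HeightScale = 5f, WidthScale = 5f
        };

        // Four set float fields -> 4 * (1 tag byte + 4 fixed32 bytes) = 20 bytes,
        // matching the "1 + 4" terms added per field in CalculateSize above.
        System.Console.WriteLine(coder.CalculateSize());        // 20
        System.Console.WriteLine(coder.ToByteArray().Length);   // 20

        // Tag bytes match the WriteRawTag calls in WriteTo above.
        System.Console.WriteLine((1 << 3) | 5);                 // 13 -> y_scale
        System.Console.WriteLine((4 << 3) | 5);                 // 37 -> width_scale
    }
}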
pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + YScale = input.ReadFloat(); + break; + } + case 21: { + XScale = input.ReadFloat(); + break; + } + case 29: { + HeightScale = input.ReadFloat(); + break; + } + case 37: { + WidthScale = input.ReadFloat(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/FlexibleGridAnchorGenerator.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/FlexibleGridAnchorGenerator.cs new file mode 100644 index 00000000..2847b5fd --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/FlexibleGridAnchorGenerator.cs @@ -0,0 +1,476 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: object_detection/protos/flexible_grid_anchor_generator.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/flexible_grid_anchor_generator.proto + public static partial class FlexibleGridAnchorGeneratorReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/flexible_grid_anchor_generator.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static FlexibleGridAnchorGeneratorReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjxvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9mbGV4aWJsZV9ncmlkX2FuY2hv", + "cl9nZW5lcmF0b3IucHJvdG8SF29iamVjdF9kZXRlY3Rpb24ucHJvdG9zInYK", + "G0ZsZXhpYmxlR3JpZEFuY2hvckdlbmVyYXRvchI4CgthbmNob3JfZ3JpZBgB", + "IAMoCzIjLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLkFuY2hvckdyaWQSHQoV", + "bm9ybWFsaXplX2Nvb3JkaW5hdGVzGAIgASgIIpEBCgpBbmNob3JHcmlkEhIK", + "CmJhc2Vfc2l6ZXMYASADKAISFQoNYXNwZWN0X3JhdGlvcxgCIAMoAhIVCg1o", + "ZWlnaHRfc3RyaWRlGAMgASgNEhQKDHdpZHRoX3N0cmlkZRgEIAEoDRIVCg1o", + "ZWlnaHRfb2Zmc2V0GAUgASgNEhQKDHdpZHRoX29mZnNldBgGIAEoDWIGcHJv", + "dG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.FlexibleGridAnchorGenerator), global::Tensorflow.Models.ObjectDetection.Protos.FlexibleGridAnchorGenerator.Parser, new[]{ "AnchorGrid", "NormalizeCoordinates" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.AnchorGrid), global::Tensorflow.Models.ObjectDetection.Protos.AnchorGrid.Parser, new[]{ "BaseSizes", "AspectRatios", "HeightStride", "WidthStride", "HeightOffset", "WidthOffset" }, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class FlexibleGridAnchorGenerator : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FlexibleGridAnchorGenerator()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return 
global::Tensorflow.Models.ObjectDetection.Protos.FlexibleGridAnchorGeneratorReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FlexibleGridAnchorGenerator() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FlexibleGridAnchorGenerator(FlexibleGridAnchorGenerator other) : this() { + anchorGrid_ = other.anchorGrid_.Clone(); + normalizeCoordinates_ = other.normalizeCoordinates_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FlexibleGridAnchorGenerator Clone() { + return new FlexibleGridAnchorGenerator(this); + } + + /// Field number for the "anchor_grid" field. + public const int AnchorGridFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_anchorGrid_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.Models.ObjectDetection.Protos.AnchorGrid.Parser); + private readonly pbc::RepeatedField anchorGrid_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField AnchorGrid { + get { return anchorGrid_; } + } + + /// Field number for the "normalize_coordinates" field. + public const int NormalizeCoordinatesFieldNumber = 2; + private bool normalizeCoordinates_; + /// + /// Whether to produce anchors in normalized coordinates. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool NormalizeCoordinates { + get { return normalizeCoordinates_; } + set { + normalizeCoordinates_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as FlexibleGridAnchorGenerator); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(FlexibleGridAnchorGenerator other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!anchorGrid_.Equals(other.anchorGrid_)) return false; + if (NormalizeCoordinates != other.NormalizeCoordinates) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= anchorGrid_.GetHashCode(); + if (NormalizeCoordinates != false) hash ^= NormalizeCoordinates.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + anchorGrid_.WriteTo(output, _repeated_anchorGrid_codec); + if (NormalizeCoordinates != false) { + output.WriteRawTag(16); + output.WriteBool(NormalizeCoordinates); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += anchorGrid_.CalculateSize(_repeated_anchorGrid_codec); + if (NormalizeCoordinates != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + 
return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(FlexibleGridAnchorGenerator other) { + if (other == null) { + return; + } + anchorGrid_.Add(other.anchorGrid_); + if (other.NormalizeCoordinates != false) { + NormalizeCoordinates = other.NormalizeCoordinates; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + anchorGrid_.AddEntriesFrom(input, _repeated_anchorGrid_codec); + break; + } + case 16: { + NormalizeCoordinates = input.ReadBool(); + break; + } + } + } + } + + } + + public sealed partial class AnchorGrid : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AnchorGrid()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.FlexibleGridAnchorGeneratorReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AnchorGrid() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AnchorGrid(AnchorGrid other) : this() { + baseSizes_ = other.baseSizes_.Clone(); + aspectRatios_ = other.aspectRatios_.Clone(); + heightStride_ = other.heightStride_; + widthStride_ = other.widthStride_; + heightOffset_ = other.heightOffset_; + widthOffset_ = other.widthOffset_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AnchorGrid Clone() { + return new AnchorGrid(this); + } + + /// Field number for the "base_sizes" field. + public const int BaseSizesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_baseSizes_codec + = pb::FieldCodec.ForFloat(10); + private readonly pbc::RepeatedField baseSizes_ = new pbc::RepeatedField(); + /// + /// The base sizes in pixels for each anchor in this anchor layer. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField BaseSizes { + get { return baseSizes_; } + } + + /// Field number for the "aspect_ratios" field. + public const int AspectRatiosFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_aspectRatios_codec + = pb::FieldCodec.ForFloat(18); + private readonly pbc::RepeatedField aspectRatios_ = new pbc::RepeatedField(); + /// + /// The aspect ratios for each anchor in this anchor layer. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField AspectRatios { + get { return aspectRatios_; } + } + + /// Field number for the "height_stride" field. + public const int HeightStrideFieldNumber = 3; + private uint heightStride_; + /// + /// The anchor height stride in pixels. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint HeightStride { + get { return heightStride_; } + set { + heightStride_ = value; + } + } + + /// Field number for the "width_stride" field. + public const int WidthStrideFieldNumber = 4; + private uint widthStride_; + /// + /// The anchor width stride in pixels. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint WidthStride { + get { return widthStride_; } + set { + widthStride_ = value; + } + } + + /// Field number for the "height_offset" field. + public const int HeightOffsetFieldNumber = 5; + private uint heightOffset_; + /// + /// The anchor height offset in pixels. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint HeightOffset { + get { return heightOffset_; } + set { + heightOffset_ = value; + } + } + + /// Field number for the "width_offset" field. + public const int WidthOffsetFieldNumber = 6; + private uint widthOffset_; + /// + /// The anchor width offset in pixels. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint WidthOffset { + get { return widthOffset_; } + set { + widthOffset_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as AnchorGrid); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(AnchorGrid other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!baseSizes_.Equals(other.baseSizes_)) return false; + if(!aspectRatios_.Equals(other.aspectRatios_)) return false; + if (HeightStride != other.HeightStride) return false; + if (WidthStride != other.WidthStride) return false; + if (HeightOffset != other.HeightOffset) return false; + if (WidthOffset != other.WidthOffset) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= baseSizes_.GetHashCode(); + hash ^= aspectRatios_.GetHashCode(); + if (HeightStride != 0) hash ^= HeightStride.GetHashCode(); + if (WidthStride != 0) hash ^= WidthStride.GetHashCode(); + if (HeightOffset != 0) hash ^= HeightOffset.GetHashCode(); + if (WidthOffset != 0) hash ^= WidthOffset.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + baseSizes_.WriteTo(output, _repeated_baseSizes_codec); + aspectRatios_.WriteTo(output, _repeated_aspectRatios_codec); + if (HeightStride != 0) { + output.WriteRawTag(24); + output.WriteUInt32(HeightStride); + } + if (WidthStride != 0) { + output.WriteRawTag(32); + output.WriteUInt32(WidthStride); + } + if (HeightOffset != 0) { + output.WriteRawTag(40); + output.WriteUInt32(HeightOffset); + } + if (WidthOffset != 0) { + output.WriteRawTag(48); + output.WriteUInt32(WidthOffset); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += baseSizes_.CalculateSize(_repeated_baseSizes_codec); + size += 
aspectRatios_.CalculateSize(_repeated_aspectRatios_codec); + if (HeightStride != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(HeightStride); + } + if (WidthStride != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(WidthStride); + } + if (HeightOffset != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(HeightOffset); + } + if (WidthOffset != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(WidthOffset); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(AnchorGrid other) { + if (other == null) { + return; + } + baseSizes_.Add(other.baseSizes_); + aspectRatios_.Add(other.aspectRatios_); + if (other.HeightStride != 0) { + HeightStride = other.HeightStride; + } + if (other.WidthStride != 0) { + WidthStride = other.WidthStride; + } + if (other.HeightOffset != 0) { + HeightOffset = other.HeightOffset; + } + if (other.WidthOffset != 0) { + WidthOffset = other.WidthOffset; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 13: { + baseSizes_.AddEntriesFrom(input, _repeated_baseSizes_codec); + break; + } + case 18: + case 21: { + aspectRatios_.AddEntriesFrom(input, _repeated_aspectRatios_codec); + break; + } + case 24: { + HeightStride = input.ReadUInt32(); + break; + } + case 32: { + WidthStride = input.ReadUInt32(); + break; + } + case 40: { + HeightOffset = input.ReadUInt32(); + break; + } + case 48: { + WidthOffset = input.ReadUInt32(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/GraphRewriter.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/GraphRewriter.cs new file mode 100644 index 00000000..04d3530c --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/GraphRewriter.cs @@ -0,0 +1,417 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
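For the FlexibleGridAnchorGenerator/AnchorGrid messages completed above: anchor_grid is a repeated message field, so the generated AnchorGrid property is a read-only RepeatedField populated with Add or a collection initializer. A minimal sketch, with illustrative sizes and strides only:

using Tensorflow.Models.ObjectDetection.Protos;

class FlexibleGridAnchorDemo
{
    static void Main()
    {
        var generator = new FlexibleGridAnchorGenerator
        {
            NormalizeCoordinates = true,
            AnchorGrid =
            {
                new AnchorGrid
                {
                    BaseSizes = { 32f, 64f },        // pixels
                    AspectRatios = { 0.5f, 1f, 2f },
                    HeightStride = 16,
                    WidthStride = 16,
                    HeightOffset = 0,
                    WidthOffset = 0
                }
            }
        };

        System.Console.WriteLine(generator.AnchorGrid.Count);               // 1
        System.Console.WriteLine(generator.AnchorGrid[0].BaseSizes.Count);  // 2
    }
}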
+// source: object_detection/protos/graph_rewriter.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/graph_rewriter.proto + public static partial class GraphRewriterReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/graph_rewriter.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static GraphRewriterReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CixvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9ncmFwaF9yZXdyaXRlci5wcm90", + "bxIXb2JqZWN0X2RldGVjdGlvbi5wcm90b3MiTAoNR3JhcGhSZXdyaXRlchI7", + "CgxxdWFudGl6YXRpb24YASABKAsyJS5vYmplY3RfZGV0ZWN0aW9uLnByb3Rv", + "cy5RdWFudGl6YXRpb24iXgoMUXVhbnRpemF0aW9uEg0KBWRlbGF5GAEgASgF", + "EhMKC3dlaWdodF9iaXRzGAIgASgFEhcKD2FjdGl2YXRpb25fYml0cxgDIAEo", + "BRIRCglzeW1tZXRyaWMYBCABKAhiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.GraphRewriter), global::Tensorflow.Models.ObjectDetection.Protos.GraphRewriter.Parser, new[]{ "Quantization" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.Quantization), global::Tensorflow.Models.ObjectDetection.Protos.Quantization.Parser, new[]{ "Delay", "WeightBits", "ActivationBits", "Symmetric" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Message to configure graph rewriter for the tf graph. + /// + public sealed partial class GraphRewriter : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GraphRewriter()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.GraphRewriterReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GraphRewriter() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GraphRewriter(GraphRewriter other) : this() { + quantization_ = other.quantization_ != null ? other.quantization_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GraphRewriter Clone() { + return new GraphRewriter(this); + } + + /// Field number for the "quantization" field. 
+ public const int QuantizationFieldNumber = 1; + private global::Tensorflow.Models.ObjectDetection.Protos.Quantization quantization_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Quantization Quantization { + get { return quantization_; } + set { + quantization_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as GraphRewriter); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(GraphRewriter other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Quantization, other.Quantization)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (quantization_ != null) hash ^= Quantization.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (quantization_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Quantization); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (quantization_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Quantization); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(GraphRewriter other) { + if (other == null) { + return; + } + if (other.quantization_ != null) { + if (quantization_ == null) { + quantization_ = new global::Tensorflow.Models.ObjectDetection.Protos.Quantization(); + } + Quantization.MergeFrom(other.Quantization); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (quantization_ == null) { + quantization_ = new global::Tensorflow.Models.ObjectDetection.Protos.Quantization(); + } + input.ReadMessage(quantization_); + break; + } + } + } + } + + } + + /// + /// Message for quantization options. See + /// tensorflow/contrib/quantize/python/quantize.py for details. 
+ /// + public sealed partial class Quantization : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Quantization()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.GraphRewriterReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Quantization() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Quantization(Quantization other) : this() { + delay_ = other.delay_; + weightBits_ = other.weightBits_; + activationBits_ = other.activationBits_; + symmetric_ = other.symmetric_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Quantization Clone() { + return new Quantization(this); + } + + /// Field number for the "delay" field. + public const int DelayFieldNumber = 1; + private int delay_; + /// + /// Number of steps to delay before quantization takes effect during training. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Delay { + get { return delay_; } + set { + delay_ = value; + } + } + + /// Field number for the "weight_bits" field. + public const int WeightBitsFieldNumber = 2; + private int weightBits_; + /// + /// Number of bits to use for quantizing weights. + /// Only 8 bit is supported for now. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int WeightBits { + get { return weightBits_; } + set { + weightBits_ = value; + } + } + + /// Field number for the "activation_bits" field. + public const int ActivationBitsFieldNumber = 3; + private int activationBits_; + /// + /// Number of bits to use for quantizing activations. + /// Only 8 bit is supported for now. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int ActivationBits { + get { return activationBits_; } + set { + activationBits_ = value; + } + } + + /// Field number for the "symmetric" field. + public const int SymmetricFieldNumber = 4; + private bool symmetric_; + /// + /// Whether to use symmetric weight quantization. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Symmetric { + get { return symmetric_; } + set { + symmetric_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as Quantization); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(Quantization other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Delay != other.Delay) return false; + if (WeightBits != other.WeightBits) return false; + if (ActivationBits != other.ActivationBits) return false; + if (Symmetric != other.Symmetric) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Delay != 0) hash ^= Delay.GetHashCode(); + if (WeightBits != 0) hash ^= WeightBits.GetHashCode(); + if (ActivationBits != 0) hash ^= ActivationBits.GetHashCode(); + if (Symmetric != false) hash ^= Symmetric.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Delay != 0) { + output.WriteRawTag(8); + output.WriteInt32(Delay); + } + if (WeightBits != 0) { + output.WriteRawTag(16); + output.WriteInt32(WeightBits); + } + if (ActivationBits != 0) { + output.WriteRawTag(24); + output.WriteInt32(ActivationBits); + } + if (Symmetric != false) { + output.WriteRawTag(32); + output.WriteBool(Symmetric); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Delay != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Delay); + } + if (WeightBits != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(WeightBits); + } + if (ActivationBits != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ActivationBits); + } + if (Symmetric != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(Quantization other) { + if (other == null) { + return; + } + if (other.Delay != 0) { + Delay = other.Delay; + } + if (other.WeightBits != 0) { + WeightBits = other.WeightBits; + } + if (other.ActivationBits != 0) { + ActivationBits = other.ActivationBits; + } + if (other.Symmetric != false) { + Symmetric = other.Symmetric; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Delay = input.ReadInt32(); + break; + } + case 16: { + WeightBits = input.ReadInt32(); + break; + } + case 24: { + ActivationBits = input.ReadInt32(); + break; + } + case 32: { + Symmetric = input.ReadBool(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer 
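A short sketch of the GraphRewriter/Quantization messages above, including the merge semantics implemented by the generated MergeFrom: a singular sub-message is merged field-by-field rather than replaced, and zero-valued proto3 fields are not copied. The numeric values are illustrative; only 8-bit weights/activations are supported per the field comments.

using Tensorflow.Models.ObjectDetection.Protos;

class GraphRewriterDemo
{
    static void Main()
    {
        var rewriter = new GraphRewriter
        {
            Quantization = new Quantization
            {
                Delay = 48000,        // steps before quantization takes effect (illustrative)
                WeightBits = 8,
                ActivationBits = 8,
                Symmetric = false
            }
        };

        // Merge another config: only the non-default fields of the other message win.
        var overrides = new GraphRewriter { Quantization = new Quantization { Symmetric = true } };
        rewriter.MergeFrom(overrides);

        System.Console.WriteLine(rewriter.Quantization.Delay);      // 48000 (kept)
        System.Console.WriteLine(rewriter.Quantization.Symmetric);  // True (overridden)
    }
}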
generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/GridAnchorGenerator.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/GridAnchorGenerator.cs new file mode 100644 index 00000000..76b31e74 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/GridAnchorGenerator.cs @@ -0,0 +1,386 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: object_detection/protos/grid_anchor_generator.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/grid_anchor_generator.proto + public static partial class GridAnchorGeneratorReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/grid_anchor_generator.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static GridAnchorGeneratorReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjNvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9ncmlkX2FuY2hvcl9nZW5lcmF0", + "b3IucHJvdG8SF29iamVjdF9kZXRlY3Rpb24ucHJvdG9zIrUBChNHcmlkQW5j", + "aG9yR2VuZXJhdG9yEg4KBmhlaWdodBgBIAEoBRINCgV3aWR0aBgCIAEoBRIV", + "Cg1oZWlnaHRfc3RyaWRlGAMgASgFEhQKDHdpZHRoX3N0cmlkZRgEIAEoBRIV", + "Cg1oZWlnaHRfb2Zmc2V0GAUgASgFEhQKDHdpZHRoX29mZnNldBgGIAEoBRIO", + "CgZzY2FsZXMYByADKAISFQoNYXNwZWN0X3JhdGlvcxgIIAMoAmIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.GridAnchorGenerator), global::Tensorflow.Models.ObjectDetection.Protos.GridAnchorGenerator.Parser, new[]{ "Height", "Width", "HeightStride", "WidthStride", "HeightOffset", "WidthOffset", "Scales", "AspectRatios" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration proto for GridAnchorGenerator. See + /// anchor_generators/grid_anchor_generator.py for details. 
+ /// + public sealed partial class GridAnchorGenerator : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GridAnchorGenerator()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.GridAnchorGeneratorReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GridAnchorGenerator() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GridAnchorGenerator(GridAnchorGenerator other) : this() { + height_ = other.height_; + width_ = other.width_; + heightStride_ = other.heightStride_; + widthStride_ = other.widthStride_; + heightOffset_ = other.heightOffset_; + widthOffset_ = other.widthOffset_; + scales_ = other.scales_.Clone(); + aspectRatios_ = other.aspectRatios_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GridAnchorGenerator Clone() { + return new GridAnchorGenerator(this); + } + + /// Field number for the "height" field. + public const int HeightFieldNumber = 1; + private int height_; + /// + /// Anchor height in pixels. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Height { + get { return height_; } + set { + height_ = value; + } + } + + /// Field number for the "width" field. + public const int WidthFieldNumber = 2; + private int width_; + /// + /// Anchor width in pixels. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Width { + get { return width_; } + set { + width_ = value; + } + } + + /// Field number for the "height_stride" field. + public const int HeightStrideFieldNumber = 3; + private int heightStride_; + /// + /// Anchor stride in height dimension in pixels. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int HeightStride { + get { return heightStride_; } + set { + heightStride_ = value; + } + } + + /// Field number for the "width_stride" field. + public const int WidthStrideFieldNumber = 4; + private int widthStride_; + /// + /// Anchor stride in width dimension in pixels. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int WidthStride { + get { return widthStride_; } + set { + widthStride_ = value; + } + } + + /// Field number for the "height_offset" field. + public const int HeightOffsetFieldNumber = 5; + private int heightOffset_; + /// + /// Anchor height offset in pixels. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int HeightOffset { + get { return heightOffset_; } + set { + heightOffset_ = value; + } + } + + /// Field number for the "width_offset" field. + public const int WidthOffsetFieldNumber = 6; + private int widthOffset_; + /// + /// Anchor width offset in pixels. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int WidthOffset { + get { return widthOffset_; } + set { + widthOffset_ = value; + } + } + + /// Field number for the "scales" field. 
+ public const int ScalesFieldNumber = 7; + private static readonly pb::FieldCodec _repeated_scales_codec + = pb::FieldCodec.ForFloat(58); + private readonly pbc::RepeatedField scales_ = new pbc::RepeatedField(); + /// + /// List of scales for the anchors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Scales { + get { return scales_; } + } + + /// Field number for the "aspect_ratios" field. + public const int AspectRatiosFieldNumber = 8; + private static readonly pb::FieldCodec _repeated_aspectRatios_codec + = pb::FieldCodec.ForFloat(66); + private readonly pbc::RepeatedField aspectRatios_ = new pbc::RepeatedField(); + /// + /// List of aspect ratios for the anchors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField AspectRatios { + get { return aspectRatios_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as GridAnchorGenerator); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(GridAnchorGenerator other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Height != other.Height) return false; + if (Width != other.Width) return false; + if (HeightStride != other.HeightStride) return false; + if (WidthStride != other.WidthStride) return false; + if (HeightOffset != other.HeightOffset) return false; + if (WidthOffset != other.WidthOffset) return false; + if(!scales_.Equals(other.scales_)) return false; + if(!aspectRatios_.Equals(other.aspectRatios_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Height != 0) hash ^= Height.GetHashCode(); + if (Width != 0) hash ^= Width.GetHashCode(); + if (HeightStride != 0) hash ^= HeightStride.GetHashCode(); + if (WidthStride != 0) hash ^= WidthStride.GetHashCode(); + if (HeightOffset != 0) hash ^= HeightOffset.GetHashCode(); + if (WidthOffset != 0) hash ^= WidthOffset.GetHashCode(); + hash ^= scales_.GetHashCode(); + hash ^= aspectRatios_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Height != 0) { + output.WriteRawTag(8); + output.WriteInt32(Height); + } + if (Width != 0) { + output.WriteRawTag(16); + output.WriteInt32(Width); + } + if (HeightStride != 0) { + output.WriteRawTag(24); + output.WriteInt32(HeightStride); + } + if (WidthStride != 0) { + output.WriteRawTag(32); + output.WriteInt32(WidthStride); + } + if (HeightOffset != 0) { + output.WriteRawTag(40); + output.WriteInt32(HeightOffset); + } + if (WidthOffset != 0) { + output.WriteRawTag(48); + output.WriteInt32(WidthOffset); + } + scales_.WriteTo(output, _repeated_scales_codec); + aspectRatios_.WriteTo(output, _repeated_aspectRatios_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Height != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Height); + } + if (Width != 0) { + 
size += 1 + pb::CodedOutputStream.ComputeInt32Size(Width); + } + if (HeightStride != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(HeightStride); + } + if (WidthStride != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(WidthStride); + } + if (HeightOffset != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(HeightOffset); + } + if (WidthOffset != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(WidthOffset); + } + size += scales_.CalculateSize(_repeated_scales_codec); + size += aspectRatios_.CalculateSize(_repeated_aspectRatios_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(GridAnchorGenerator other) { + if (other == null) { + return; + } + if (other.Height != 0) { + Height = other.Height; + } + if (other.Width != 0) { + Width = other.Width; + } + if (other.HeightStride != 0) { + HeightStride = other.HeightStride; + } + if (other.WidthStride != 0) { + WidthStride = other.WidthStride; + } + if (other.HeightOffset != 0) { + HeightOffset = other.HeightOffset; + } + if (other.WidthOffset != 0) { + WidthOffset = other.WidthOffset; + } + scales_.Add(other.scales_); + aspectRatios_.Add(other.aspectRatios_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Height = input.ReadInt32(); + break; + } + case 16: { + Width = input.ReadInt32(); + break; + } + case 24: { + HeightStride = input.ReadInt32(); + break; + } + case 32: { + WidthStride = input.ReadInt32(); + break; + } + case 40: { + HeightOffset = input.ReadInt32(); + break; + } + case 48: { + WidthOffset = input.ReadInt32(); + break; + } + case 58: + case 61: { + scales_.AddEntriesFrom(input, _repeated_scales_codec); + break; + } + case 66: + case 69: { + aspectRatios_.AddEntriesFrom(input, _repeated_aspectRatios_codec); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/Hyperparams.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/Hyperparams.cs new file mode 100644 index 00000000..315b1c68 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/Hyperparams.cs @@ -0,0 +1,2106 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
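For the GridAnchorGenerator message completed above: proto3 scalar fields default to zero and are only written when non-zero (the "!= 0" guards in WriteTo/CalculateSize), while scales and aspect_ratios are repeated floats. A minimal usage sketch with illustrative values, not part of the generated file:

using Tensorflow.Models.ObjectDetection.Protos;

class GridAnchorDemo
{
    static void Main()
    {
        var grid = new GridAnchorGenerator
        {
            Height = 256,                          // anchor height in pixels
            Width = 256,                           // anchor width in pixels
            HeightStride = 16,
            WidthStride = 16,
            Scales = { 0.25f, 0.5f, 1.0f, 2.0f },
            AspectRatios = { 0.5f, 1.0f, 2.0f }
        };

        System.Console.WriteLine(grid.Scales.Count);         // 4
        System.Console.WriteLine(grid.CalculateSize() > 0);  // True (non-default fields are encoded)
    }
}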
+// source: object_detection/protos/hyperparams.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/hyperparams.proto + public static partial class HyperparamsReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/hyperparams.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static HyperparamsReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CilvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9oeXBlcnBhcmFtcy5wcm90bxIX", + "b2JqZWN0X2RldGVjdGlvbi5wcm90b3Mi8wMKC0h5cGVycGFyYW1zEjMKAm9w", + "GAEgASgOMicub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuSHlwZXJwYXJhbXMu", + "T3ASOQoLcmVndWxhcml6ZXIYAiABKAsyJC5vYmplY3RfZGV0ZWN0aW9uLnBy", + "b3Rvcy5SZWd1bGFyaXplchI5Cgtpbml0aWFsaXplchgDIAEoCzIkLm9iamVj", + "dF9kZXRlY3Rpb24ucHJvdG9zLkluaXRpYWxpemVyEkMKCmFjdGl2YXRpb24Y", + "BCABKA4yLy5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5IeXBlcnBhcmFtcy5B", + "Y3RpdmF0aW9uEjgKCmJhdGNoX25vcm0YBSABKAsyIi5vYmplY3RfZGV0ZWN0", + "aW9uLnByb3Rvcy5CYXRjaE5vcm1IABI4Cgpncm91cF9ub3JtGAcgASgLMiIu", + "b2JqZWN0X2RldGVjdGlvbi5wcm90b3MuR3JvdXBOb3JtSAASHAoUcmVndWxh", + "cml6ZV9kZXB0aHdpc2UYBiABKAgiIAoCT3ASCAoETlVMTBAAEggKBENPTlYQ", + "ARIGCgJGQxACIiwKCkFjdGl2YXRpb24SCAoETk9ORRAAEggKBFJFTFUQARIK", + "CgZSRUxVXzYQAkISChBub3JtYWxpemVyX29uZW9mIqYBCgtSZWd1bGFyaXpl", + "chJACg5sMV9yZWd1bGFyaXplchgBIAEoCzImLm9iamVjdF9kZXRlY3Rpb24u", + "cHJvdG9zLkwxUmVndWxhcml6ZXJIABJACg5sMl9yZWd1bGFyaXplchgCIAEo", + "CzImLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLkwyUmVndWxhcml6ZXJIAEIT", + "ChFyZWd1bGFyaXplcl9vbmVvZiIfCg1MMVJlZ3VsYXJpemVyEg4KBndlaWdo", + "dBgBIAEoAiIfCg1MMlJlZ3VsYXJpemVyEg4KBndlaWdodBgBIAEoAiKzAgoL", + "SW5pdGlhbGl6ZXISWwocdHJ1bmNhdGVkX25vcm1hbF9pbml0aWFsaXplchgB", + "IAEoCzIzLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLlRydW5jYXRlZE5vcm1h", + "bEluaXRpYWxpemVySAASWwocdmFyaWFuY2Vfc2NhbGluZ19pbml0aWFsaXpl", + "chgCIAEoCzIzLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLlZhcmlhbmNlU2Nh", + "bGluZ0luaXRpYWxpemVySAASVQoZcmFuZG9tX25vcm1hbF9pbml0aWFsaXpl", + "chgDIAEoCzIwLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLlJhbmRvbU5vcm1h", + "bEluaXRpYWxpemVySABCEwoRaW5pdGlhbGl6ZXJfb25lb2YiOgoaVHJ1bmNh", + "dGVkTm9ybWFsSW5pdGlhbGl6ZXISDAoEbWVhbhgBIAEoAhIOCgZzdGRkZXYY", + "AiABKAIiswEKGlZhcmlhbmNlU2NhbGluZ0luaXRpYWxpemVyEg4KBmZhY3Rv", + "chgBIAEoAhIPCgd1bmlmb3JtGAIgASgIEkYKBG1vZGUYAyABKA4yOC5vYmpl", + "Y3RfZGV0ZWN0aW9uLnByb3Rvcy5WYXJpYW5jZVNjYWxpbmdJbml0aWFsaXpl", + "ci5Nb2RlIiwKBE1vZGUSCgoGRkFOX0lOEAASCwoHRkFOX09VVBABEgsKB0ZB", + "Tl9BVkcQAiI3ChdSYW5kb21Ob3JtYWxJbml0aWFsaXplchIMCgRtZWFuGAEg", + "ASgCEg4KBnN0ZGRldhgCIAEoAiJZCglCYXRjaE5vcm0SDQoFZGVjYXkYASAB", + "KAISDgoGY2VudGVyGAIgASgIEg0KBXNjYWxlGAMgASgIEg8KB2Vwc2lsb24Y", + "BCABKAISDQoFdHJhaW4YBSABKAgiCwoJR3JvdXBOb3JtYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams), global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams.Parser, new[]{ "Op", "Regularizer", "Initializer", "Activation", "BatchNorm", "GroupNorm", 
"RegularizeDepthwise" }, new[]{ "NormalizerOneof" }, new[]{ typeof(global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams.Types.Op), typeof(global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams.Types.Activation) }, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.Regularizer), global::Tensorflow.Models.ObjectDetection.Protos.Regularizer.Parser, new[]{ "L1Regularizer", "L2Regularizer" }, new[]{ "RegularizerOneof" }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.L1Regularizer), global::Tensorflow.Models.ObjectDetection.Protos.L1Regularizer.Parser, new[]{ "Weight" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.L2Regularizer), global::Tensorflow.Models.ObjectDetection.Protos.L2Regularizer.Parser, new[]{ "Weight" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.Initializer), global::Tensorflow.Models.ObjectDetection.Protos.Initializer.Parser, new[]{ "TruncatedNormalInitializer", "VarianceScalingInitializer", "RandomNormalInitializer" }, new[]{ "InitializerOneof" }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.TruncatedNormalInitializer), global::Tensorflow.Models.ObjectDetection.Protos.TruncatedNormalInitializer.Parser, new[]{ "Mean", "Stddev" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.VarianceScalingInitializer), global::Tensorflow.Models.ObjectDetection.Protos.VarianceScalingInitializer.Parser, new[]{ "Factor", "Uniform", "Mode" }, null, new[]{ typeof(global::Tensorflow.Models.ObjectDetection.Protos.VarianceScalingInitializer.Types.Mode) }, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomNormalInitializer), global::Tensorflow.Models.ObjectDetection.Protos.RandomNormalInitializer.Parser, new[]{ "Mean", "Stddev" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.BatchNorm), global::Tensorflow.Models.ObjectDetection.Protos.BatchNorm.Parser, new[]{ "Decay", "Center", "Scale", "Epsilon", "Train" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.GroupNorm), global::Tensorflow.Models.ObjectDetection.Protos.GroupNorm.Parser, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration proto for the convolution op hyperparameters to use in the + /// object detection pipeline. 
+ /// + public sealed partial class Hyperparams : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Hyperparams()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.HyperparamsReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Hyperparams() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Hyperparams(Hyperparams other) : this() { + op_ = other.op_; + regularizer_ = other.regularizer_ != null ? other.regularizer_.Clone() : null; + initializer_ = other.initializer_ != null ? other.initializer_.Clone() : null; + activation_ = other.activation_; + regularizeDepthwise_ = other.regularizeDepthwise_; + switch (other.NormalizerOneofCase) { + case NormalizerOneofOneofCase.BatchNorm: + BatchNorm = other.BatchNorm.Clone(); + break; + case NormalizerOneofOneofCase.GroupNorm: + GroupNorm = other.GroupNorm.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Hyperparams Clone() { + return new Hyperparams(this); + } + + /// Field number for the "op" field. + public const int OpFieldNumber = 1; + private global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams.Types.Op op_ = 0; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams.Types.Op Op { + get { return op_; } + set { + op_ = value; + } + } + + /// Field number for the "regularizer" field. + public const int RegularizerFieldNumber = 2; + private global::Tensorflow.Models.ObjectDetection.Protos.Regularizer regularizer_; + /// + /// Regularizer for the weights of the convolution op. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Regularizer Regularizer { + get { return regularizer_; } + set { + regularizer_ = value; + } + } + + /// Field number for the "initializer" field. + public const int InitializerFieldNumber = 3; + private global::Tensorflow.Models.ObjectDetection.Protos.Initializer initializer_; + /// + /// Initializer for the weights of the convolution op. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Initializer Initializer { + get { return initializer_; } + set { + initializer_ = value; + } + } + + /// Field number for the "activation" field. + public const int ActivationFieldNumber = 4; + private global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams.Types.Activation activation_ = 0; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams.Types.Activation Activation { + get { return activation_; } + set { + activation_ = value; + } + } + + /// Field number for the "batch_norm" field. 
+ public const int BatchNormFieldNumber = 5; + /// + /// Note that if nothing below is selected, then no normalization is applied + /// BatchNorm hyperparameters. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.BatchNorm BatchNorm { + get { return normalizerOneofCase_ == NormalizerOneofOneofCase.BatchNorm ? (global::Tensorflow.Models.ObjectDetection.Protos.BatchNorm) normalizerOneof_ : null; } + set { + normalizerOneof_ = value; + normalizerOneofCase_ = value == null ? NormalizerOneofOneofCase.None : NormalizerOneofOneofCase.BatchNorm; + } + } + + /// Field number for the "group_norm" field. + public const int GroupNormFieldNumber = 7; + /// + /// GroupNorm hyperparameters. This is only supported on a subset of models. + /// Note that the current implementation of group norm instantiated in + /// tf.contrib.group.layers.group_norm() only supports fixed_size_resizer + /// for image preprocessing. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.GroupNorm GroupNorm { + get { return normalizerOneofCase_ == NormalizerOneofOneofCase.GroupNorm ? (global::Tensorflow.Models.ObjectDetection.Protos.GroupNorm) normalizerOneof_ : null; } + set { + normalizerOneof_ = value; + normalizerOneofCase_ = value == null ? NormalizerOneofOneofCase.None : NormalizerOneofOneofCase.GroupNorm; + } + } + + /// Field number for the "regularize_depthwise" field. + public const int RegularizeDepthwiseFieldNumber = 6; + private bool regularizeDepthwise_; + /// + /// Whether depthwise convolutions should be regularized. If this parameter is + /// NOT set then the conv hyperparams will default to the parent scope. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool RegularizeDepthwise { + get { return regularizeDepthwise_; } + set { + regularizeDepthwise_ = value; + } + } + + private object normalizerOneof_; + /// Enum of possible cases for the "normalizer_oneof" oneof. 
+ public enum NormalizerOneofOneofCase { + None = 0, + BatchNorm = 5, + GroupNorm = 7, + } + private NormalizerOneofOneofCase normalizerOneofCase_ = NormalizerOneofOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NormalizerOneofOneofCase NormalizerOneofCase { + get { return normalizerOneofCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearNormalizerOneof() { + normalizerOneofCase_ = NormalizerOneofOneofCase.None; + normalizerOneof_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as Hyperparams); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(Hyperparams other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Op != other.Op) return false; + if (!object.Equals(Regularizer, other.Regularizer)) return false; + if (!object.Equals(Initializer, other.Initializer)) return false; + if (Activation != other.Activation) return false; + if (!object.Equals(BatchNorm, other.BatchNorm)) return false; + if (!object.Equals(GroupNorm, other.GroupNorm)) return false; + if (RegularizeDepthwise != other.RegularizeDepthwise) return false; + if (NormalizerOneofCase != other.NormalizerOneofCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Op != 0) hash ^= Op.GetHashCode(); + if (regularizer_ != null) hash ^= Regularizer.GetHashCode(); + if (initializer_ != null) hash ^= Initializer.GetHashCode(); + if (Activation != 0) hash ^= Activation.GetHashCode(); + if (normalizerOneofCase_ == NormalizerOneofOneofCase.BatchNorm) hash ^= BatchNorm.GetHashCode(); + if (normalizerOneofCase_ == NormalizerOneofOneofCase.GroupNorm) hash ^= GroupNorm.GetHashCode(); + if (RegularizeDepthwise != false) hash ^= RegularizeDepthwise.GetHashCode(); + hash ^= (int) normalizerOneofCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Op != 0) { + output.WriteRawTag(8); + output.WriteEnum((int) Op); + } + if (regularizer_ != null) { + output.WriteRawTag(18); + output.WriteMessage(Regularizer); + } + if (initializer_ != null) { + output.WriteRawTag(26); + output.WriteMessage(Initializer); + } + if (Activation != 0) { + output.WriteRawTag(32); + output.WriteEnum((int) Activation); + } + if (normalizerOneofCase_ == NormalizerOneofOneofCase.BatchNorm) { + output.WriteRawTag(42); + output.WriteMessage(BatchNorm); + } + if (RegularizeDepthwise != false) { + output.WriteRawTag(48); + output.WriteBool(RegularizeDepthwise); + } + if (normalizerOneofCase_ == NormalizerOneofOneofCase.GroupNorm) { + output.WriteRawTag(58); + output.WriteMessage(GroupNorm); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Op != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Op); + } + if (regularizer_ != null) { + size += 1 + 
pb::CodedOutputStream.ComputeMessageSize(Regularizer); + } + if (initializer_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Initializer); + } + if (Activation != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Activation); + } + if (normalizerOneofCase_ == NormalizerOneofOneofCase.BatchNorm) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(BatchNorm); + } + if (normalizerOneofCase_ == NormalizerOneofOneofCase.GroupNorm) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(GroupNorm); + } + if (RegularizeDepthwise != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(Hyperparams other) { + if (other == null) { + return; + } + if (other.Op != 0) { + Op = other.Op; + } + if (other.regularizer_ != null) { + if (regularizer_ == null) { + regularizer_ = new global::Tensorflow.Models.ObjectDetection.Protos.Regularizer(); + } + Regularizer.MergeFrom(other.Regularizer); + } + if (other.initializer_ != null) { + if (initializer_ == null) { + initializer_ = new global::Tensorflow.Models.ObjectDetection.Protos.Initializer(); + } + Initializer.MergeFrom(other.Initializer); + } + if (other.Activation != 0) { + Activation = other.Activation; + } + if (other.RegularizeDepthwise != false) { + RegularizeDepthwise = other.RegularizeDepthwise; + } + switch (other.NormalizerOneofCase) { + case NormalizerOneofOneofCase.BatchNorm: + if (BatchNorm == null) { + BatchNorm = new global::Tensorflow.Models.ObjectDetection.Protos.BatchNorm(); + } + BatchNorm.MergeFrom(other.BatchNorm); + break; + case NormalizerOneofOneofCase.GroupNorm: + if (GroupNorm == null) { + GroupNorm = new global::Tensorflow.Models.ObjectDetection.Protos.GroupNorm(); + } + GroupNorm.MergeFrom(other.GroupNorm); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + op_ = (global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams.Types.Op) input.ReadEnum(); + break; + } + case 18: { + if (regularizer_ == null) { + regularizer_ = new global::Tensorflow.Models.ObjectDetection.Protos.Regularizer(); + } + input.ReadMessage(regularizer_); + break; + } + case 26: { + if (initializer_ == null) { + initializer_ = new global::Tensorflow.Models.ObjectDetection.Protos.Initializer(); + } + input.ReadMessage(initializer_); + break; + } + case 32: { + activation_ = (global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams.Types.Activation) input.ReadEnum(); + break; + } + case 42: { + global::Tensorflow.Models.ObjectDetection.Protos.BatchNorm subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.BatchNorm(); + if (normalizerOneofCase_ == NormalizerOneofOneofCase.BatchNorm) { + subBuilder.MergeFrom(BatchNorm); + } + input.ReadMessage(subBuilder); + BatchNorm = subBuilder; + break; + } + case 48: { + RegularizeDepthwise = input.ReadBool(); + break; + } + case 58: { + global::Tensorflow.Models.ObjectDetection.Protos.GroupNorm subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.GroupNorm(); + if (normalizerOneofCase_ == NormalizerOneofOneofCase.GroupNorm) { + 
subBuilder.MergeFrom(GroupNorm); + } + input.ReadMessage(subBuilder); + GroupNorm = subBuilder; + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the Hyperparams message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + /// + /// Operations affected by hyperparameters. + /// + public enum Op { + /// + /// Use None + /// + [pbr::OriginalName("NULL")] Null = 0, + /// + /// Convolution, Separable Convolution, Convolution transpose. + /// + [pbr::OriginalName("CONV")] Conv = 1, + /// + /// Fully connected + /// + [pbr::OriginalName("FC")] Fc = 2, + } + + /// + /// Type of activation to apply after convolution. + /// + public enum Activation { + /// + /// Use None (no activation) + /// + [pbr::OriginalName("NONE")] None = 0, + /// + /// Use tf.nn.relu + /// + [pbr::OriginalName("RELU")] Relu = 1, + /// + /// Use tf.nn.relu6 + /// + [pbr::OriginalName("RELU_6")] Relu6 = 2, + } + + } + #endregion + + } + + /// + /// Proto with one-of field for regularizers. + /// + public sealed partial class Regularizer : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Regularizer()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.HyperparamsReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Regularizer() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Regularizer(Regularizer other) : this() { + switch (other.RegularizerOneofCase) { + case RegularizerOneofOneofCase.L1Regularizer: + L1Regularizer = other.L1Regularizer.Clone(); + break; + case RegularizerOneofOneofCase.L2Regularizer: + L2Regularizer = other.L2Regularizer.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Regularizer Clone() { + return new Regularizer(this); + } + + /// Field number for the "l1_regularizer" field. + public const int L1RegularizerFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.L1Regularizer L1Regularizer { + get { return regularizerOneofCase_ == RegularizerOneofOneofCase.L1Regularizer ? (global::Tensorflow.Models.ObjectDetection.Protos.L1Regularizer) regularizerOneof_ : null; } + set { + regularizerOneof_ = value; + regularizerOneofCase_ = value == null ? RegularizerOneofOneofCase.None : RegularizerOneofOneofCase.L1Regularizer; + } + } + + /// Field number for the "l2_regularizer" field. + public const int L2RegularizerFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.L2Regularizer L2Regularizer { + get { return regularizerOneofCase_ == RegularizerOneofOneofCase.L2Regularizer ? 
(global::Tensorflow.Models.ObjectDetection.Protos.L2Regularizer) regularizerOneof_ : null; } + set { + regularizerOneof_ = value; + regularizerOneofCase_ = value == null ? RegularizerOneofOneofCase.None : RegularizerOneofOneofCase.L2Regularizer; + } + } + + private object regularizerOneof_; + /// Enum of possible cases for the "regularizer_oneof" oneof. + public enum RegularizerOneofOneofCase { + None = 0, + L1Regularizer = 1, + L2Regularizer = 2, + } + private RegularizerOneofOneofCase regularizerOneofCase_ = RegularizerOneofOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RegularizerOneofOneofCase RegularizerOneofCase { + get { return regularizerOneofCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearRegularizerOneof() { + regularizerOneofCase_ = RegularizerOneofOneofCase.None; + regularizerOneof_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as Regularizer); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(Regularizer other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(L1Regularizer, other.L1Regularizer)) return false; + if (!object.Equals(L2Regularizer, other.L2Regularizer)) return false; + if (RegularizerOneofCase != other.RegularizerOneofCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (regularizerOneofCase_ == RegularizerOneofOneofCase.L1Regularizer) hash ^= L1Regularizer.GetHashCode(); + if (regularizerOneofCase_ == RegularizerOneofOneofCase.L2Regularizer) hash ^= L2Regularizer.GetHashCode(); + hash ^= (int) regularizerOneofCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (regularizerOneofCase_ == RegularizerOneofOneofCase.L1Regularizer) { + output.WriteRawTag(10); + output.WriteMessage(L1Regularizer); + } + if (regularizerOneofCase_ == RegularizerOneofOneofCase.L2Regularizer) { + output.WriteRawTag(18); + output.WriteMessage(L2Regularizer); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (regularizerOneofCase_ == RegularizerOneofOneofCase.L1Regularizer) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(L1Regularizer); + } + if (regularizerOneofCase_ == RegularizerOneofOneofCase.L2Regularizer) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(L2Regularizer); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(Regularizer other) { + if (other == null) { + return; + } + switch (other.RegularizerOneofCase) { + case RegularizerOneofOneofCase.L1Regularizer: + if (L1Regularizer == null) { + L1Regularizer = new global::Tensorflow.Models.ObjectDetection.Protos.L1Regularizer(); + } + L1Regularizer.MergeFrom(other.L1Regularizer); + 
break; + case RegularizerOneofOneofCase.L2Regularizer: + if (L2Regularizer == null) { + L2Regularizer = new global::Tensorflow.Models.ObjectDetection.Protos.L2Regularizer(); + } + L2Regularizer.MergeFrom(other.L2Regularizer); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.Models.ObjectDetection.Protos.L1Regularizer subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.L1Regularizer(); + if (regularizerOneofCase_ == RegularizerOneofOneofCase.L1Regularizer) { + subBuilder.MergeFrom(L1Regularizer); + } + input.ReadMessage(subBuilder); + L1Regularizer = subBuilder; + break; + } + case 18: { + global::Tensorflow.Models.ObjectDetection.Protos.L2Regularizer subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.L2Regularizer(); + if (regularizerOneofCase_ == RegularizerOneofOneofCase.L2Regularizer) { + subBuilder.MergeFrom(L2Regularizer); + } + input.ReadMessage(subBuilder); + L2Regularizer = subBuilder; + break; + } + } + } + } + + } + + /// + /// Configuration proto for L1 Regularizer. + /// See https://www.tensorflow.org/api_docs/python/tf/contrib/layers/l1_regularizer + /// + public sealed partial class L1Regularizer : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new L1Regularizer()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.HyperparamsReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public L1Regularizer() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public L1Regularizer(L1Regularizer other) : this() { + weight_ = other.weight_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public L1Regularizer Clone() { + return new L1Regularizer(this); + } + + /// Field number for the "weight" field. 
+ public const int WeightFieldNumber = 1; + private float weight_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Weight { + get { return weight_; } + set { + weight_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as L1Regularizer); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(L1Regularizer other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Weight, other.Weight)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Weight != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Weight); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Weight != 0F) { + output.WriteRawTag(13); + output.WriteFloat(Weight); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Weight != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(L1Regularizer other) { + if (other == null) { + return; + } + if (other.Weight != 0F) { + Weight = other.Weight; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + Weight = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Configuration proto for L2 Regularizer. 
+ /// See https://www.tensorflow.org/api_docs/python/tf/contrib/layers/l2_regularizer + /// + public sealed partial class L2Regularizer : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new L2Regularizer()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.HyperparamsReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public L2Regularizer() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public L2Regularizer(L2Regularizer other) : this() { + weight_ = other.weight_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public L2Regularizer Clone() { + return new L2Regularizer(this); + } + + /// Field number for the "weight" field. + public const int WeightFieldNumber = 1; + private float weight_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Weight { + get { return weight_; } + set { + weight_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as L2Regularizer); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(L2Regularizer other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Weight, other.Weight)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Weight != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Weight); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Weight != 0F) { + output.WriteRawTag(13); + output.WriteFloat(Weight); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Weight != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(L2Regularizer other) { + if (other == null) { + return; + } + if (other.Weight != 0F) { + Weight = other.Weight; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = 
input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + Weight = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Proto with one-of field for initializers. + /// + public sealed partial class Initializer : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Initializer()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.HyperparamsReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Initializer() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Initializer(Initializer other) : this() { + switch (other.InitializerOneofCase) { + case InitializerOneofOneofCase.TruncatedNormalInitializer: + TruncatedNormalInitializer = other.TruncatedNormalInitializer.Clone(); + break; + case InitializerOneofOneofCase.VarianceScalingInitializer: + VarianceScalingInitializer = other.VarianceScalingInitializer.Clone(); + break; + case InitializerOneofOneofCase.RandomNormalInitializer: + RandomNormalInitializer = other.RandomNormalInitializer.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Initializer Clone() { + return new Initializer(this); + } + + /// Field number for the "truncated_normal_initializer" field. + public const int TruncatedNormalInitializerFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.TruncatedNormalInitializer TruncatedNormalInitializer { + get { return initializerOneofCase_ == InitializerOneofOneofCase.TruncatedNormalInitializer ? (global::Tensorflow.Models.ObjectDetection.Protos.TruncatedNormalInitializer) initializerOneof_ : null; } + set { + initializerOneof_ = value; + initializerOneofCase_ = value == null ? InitializerOneofOneofCase.None : InitializerOneofOneofCase.TruncatedNormalInitializer; + } + } + + /// Field number for the "variance_scaling_initializer" field. + public const int VarianceScalingInitializerFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.VarianceScalingInitializer VarianceScalingInitializer { + get { return initializerOneofCase_ == InitializerOneofOneofCase.VarianceScalingInitializer ? (global::Tensorflow.Models.ObjectDetection.Protos.VarianceScalingInitializer) initializerOneof_ : null; } + set { + initializerOneof_ = value; + initializerOneofCase_ = value == null ? InitializerOneofOneofCase.None : InitializerOneofOneofCase.VarianceScalingInitializer; + } + } + + /// Field number for the "random_normal_initializer" field. 
+ public const int RandomNormalInitializerFieldNumber = 3; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomNormalInitializer RandomNormalInitializer { + get { return initializerOneofCase_ == InitializerOneofOneofCase.RandomNormalInitializer ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomNormalInitializer) initializerOneof_ : null; } + set { + initializerOneof_ = value; + initializerOneofCase_ = value == null ? InitializerOneofOneofCase.None : InitializerOneofOneofCase.RandomNormalInitializer; + } + } + + private object initializerOneof_; + /// Enum of possible cases for the "initializer_oneof" oneof. + public enum InitializerOneofOneofCase { + None = 0, + TruncatedNormalInitializer = 1, + VarianceScalingInitializer = 2, + RandomNormalInitializer = 3, + } + private InitializerOneofOneofCase initializerOneofCase_ = InitializerOneofOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public InitializerOneofOneofCase InitializerOneofCase { + get { return initializerOneofCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearInitializerOneof() { + initializerOneofCase_ = InitializerOneofOneofCase.None; + initializerOneof_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as Initializer); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(Initializer other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(TruncatedNormalInitializer, other.TruncatedNormalInitializer)) return false; + if (!object.Equals(VarianceScalingInitializer, other.VarianceScalingInitializer)) return false; + if (!object.Equals(RandomNormalInitializer, other.RandomNormalInitializer)) return false; + if (InitializerOneofCase != other.InitializerOneofCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (initializerOneofCase_ == InitializerOneofOneofCase.TruncatedNormalInitializer) hash ^= TruncatedNormalInitializer.GetHashCode(); + if (initializerOneofCase_ == InitializerOneofOneofCase.VarianceScalingInitializer) hash ^= VarianceScalingInitializer.GetHashCode(); + if (initializerOneofCase_ == InitializerOneofOneofCase.RandomNormalInitializer) hash ^= RandomNormalInitializer.GetHashCode(); + hash ^= (int) initializerOneofCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (initializerOneofCase_ == InitializerOneofOneofCase.TruncatedNormalInitializer) { + output.WriteRawTag(10); + output.WriteMessage(TruncatedNormalInitializer); + } + if (initializerOneofCase_ == InitializerOneofOneofCase.VarianceScalingInitializer) { + output.WriteRawTag(18); + output.WriteMessage(VarianceScalingInitializer); + } + if (initializerOneofCase_ == InitializerOneofOneofCase.RandomNormalInitializer) { + output.WriteRawTag(26); + output.WriteMessage(RandomNormalInitializer); + } + if (_unknownFields != null) { + 
_unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (initializerOneofCase_ == InitializerOneofOneofCase.TruncatedNormalInitializer) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(TruncatedNormalInitializer); + } + if (initializerOneofCase_ == InitializerOneofOneofCase.VarianceScalingInitializer) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(VarianceScalingInitializer); + } + if (initializerOneofCase_ == InitializerOneofOneofCase.RandomNormalInitializer) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RandomNormalInitializer); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(Initializer other) { + if (other == null) { + return; + } + switch (other.InitializerOneofCase) { + case InitializerOneofOneofCase.TruncatedNormalInitializer: + if (TruncatedNormalInitializer == null) { + TruncatedNormalInitializer = new global::Tensorflow.Models.ObjectDetection.Protos.TruncatedNormalInitializer(); + } + TruncatedNormalInitializer.MergeFrom(other.TruncatedNormalInitializer); + break; + case InitializerOneofOneofCase.VarianceScalingInitializer: + if (VarianceScalingInitializer == null) { + VarianceScalingInitializer = new global::Tensorflow.Models.ObjectDetection.Protos.VarianceScalingInitializer(); + } + VarianceScalingInitializer.MergeFrom(other.VarianceScalingInitializer); + break; + case InitializerOneofOneofCase.RandomNormalInitializer: + if (RandomNormalInitializer == null) { + RandomNormalInitializer = new global::Tensorflow.Models.ObjectDetection.Protos.RandomNormalInitializer(); + } + RandomNormalInitializer.MergeFrom(other.RandomNormalInitializer); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.Models.ObjectDetection.Protos.TruncatedNormalInitializer subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.TruncatedNormalInitializer(); + if (initializerOneofCase_ == InitializerOneofOneofCase.TruncatedNormalInitializer) { + subBuilder.MergeFrom(TruncatedNormalInitializer); + } + input.ReadMessage(subBuilder); + TruncatedNormalInitializer = subBuilder; + break; + } + case 18: { + global::Tensorflow.Models.ObjectDetection.Protos.VarianceScalingInitializer subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.VarianceScalingInitializer(); + if (initializerOneofCase_ == InitializerOneofOneofCase.VarianceScalingInitializer) { + subBuilder.MergeFrom(VarianceScalingInitializer); + } + input.ReadMessage(subBuilder); + VarianceScalingInitializer = subBuilder; + break; + } + case 26: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomNormalInitializer subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomNormalInitializer(); + if (initializerOneofCase_ == InitializerOneofOneofCase.RandomNormalInitializer) { + subBuilder.MergeFrom(RandomNormalInitializer); + } + input.ReadMessage(subBuilder); + RandomNormalInitializer = subBuilder; + break; + } + } + } + } + + } + + /// + /// Configuration proto for truncated normal 
initializer. See + /// https://www.tensorflow.org/api_docs/python/tf/truncated_normal_initializer + /// + public sealed partial class TruncatedNormalInitializer : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TruncatedNormalInitializer()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.HyperparamsReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TruncatedNormalInitializer() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TruncatedNormalInitializer(TruncatedNormalInitializer other) : this() { + mean_ = other.mean_; + stddev_ = other.stddev_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TruncatedNormalInitializer Clone() { + return new TruncatedNormalInitializer(this); + } + + /// Field number for the "mean" field. + public const int MeanFieldNumber = 1; + private float mean_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Mean { + get { return mean_; } + set { + mean_ = value; + } + } + + /// Field number for the "stddev" field. + public const int StddevFieldNumber = 2; + private float stddev_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Stddev { + get { return stddev_; } + set { + stddev_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as TruncatedNormalInitializer); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(TruncatedNormalInitializer other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Mean, other.Mean)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Stddev, other.Stddev)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Mean != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Mean); + if (Stddev != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Stddev); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Mean != 0F) { + output.WriteRawTag(13); + output.WriteFloat(Mean); + } + if (Stddev != 0F) { + output.WriteRawTag(21); + output.WriteFloat(Stddev); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + 
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Mean != 0F) { + size += 1 + 4; + } + if (Stddev != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(TruncatedNormalInitializer other) { + if (other == null) { + return; + } + if (other.Mean != 0F) { + Mean = other.Mean; + } + if (other.Stddev != 0F) { + Stddev = other.Stddev; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + Mean = input.ReadFloat(); + break; + } + case 21: { + Stddev = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Configuration proto for variance scaling initializer. See + /// https://www.tensorflow.org/api_docs/python/tf/contrib/layers/ + /// variance_scaling_initializer + /// + public sealed partial class VarianceScalingInitializer : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new VarianceScalingInitializer()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.HyperparamsReflection.Descriptor.MessageTypes[6]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VarianceScalingInitializer() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VarianceScalingInitializer(VarianceScalingInitializer other) : this() { + factor_ = other.factor_; + uniform_ = other.uniform_; + mode_ = other.mode_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VarianceScalingInitializer Clone() { + return new VarianceScalingInitializer(this); + } + + /// Field number for the "factor" field. + public const int FactorFieldNumber = 1; + private float factor_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Factor { + get { return factor_; } + set { + factor_ = value; + } + } + + /// Field number for the "uniform" field. + public const int UniformFieldNumber = 2; + private bool uniform_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Uniform { + get { return uniform_; } + set { + uniform_ = value; + } + } + + /// Field number for the "mode" field. 
+ public const int ModeFieldNumber = 3; + private global::Tensorflow.Models.ObjectDetection.Protos.VarianceScalingInitializer.Types.Mode mode_ = 0; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.VarianceScalingInitializer.Types.Mode Mode { + get { return mode_; } + set { + mode_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as VarianceScalingInitializer); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(VarianceScalingInitializer other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Factor, other.Factor)) return false; + if (Uniform != other.Uniform) return false; + if (Mode != other.Mode) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Factor != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Factor); + if (Uniform != false) hash ^= Uniform.GetHashCode(); + if (Mode != 0) hash ^= Mode.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Factor != 0F) { + output.WriteRawTag(13); + output.WriteFloat(Factor); + } + if (Uniform != false) { + output.WriteRawTag(16); + output.WriteBool(Uniform); + } + if (Mode != 0) { + output.WriteRawTag(24); + output.WriteEnum((int) Mode); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Factor != 0F) { + size += 1 + 4; + } + if (Uniform != false) { + size += 1 + 1; + } + if (Mode != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Mode); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(VarianceScalingInitializer other) { + if (other == null) { + return; + } + if (other.Factor != 0F) { + Factor = other.Factor; + } + if (other.Uniform != false) { + Uniform = other.Uniform; + } + if (other.Mode != 0) { + Mode = other.Mode; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + Factor = input.ReadFloat(); + break; + } + case 16: { + Uniform = input.ReadBool(); + break; + } + case 24: { + mode_ = (global::Tensorflow.Models.ObjectDetection.Protos.VarianceScalingInitializer.Types.Mode) input.ReadEnum(); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the VarianceScalingInitializer message type. 
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + public enum Mode { + [pbr::OriginalName("FAN_IN")] FanIn = 0, + [pbr::OriginalName("FAN_OUT")] FanOut = 1, + [pbr::OriginalName("FAN_AVG")] FanAvg = 2, + } + + } + #endregion + + } + + /// + /// Configuration proto for random normal initializer. See + /// https://www.tensorflow.org/api_docs/python/tf/random_normal_initializer + /// + public sealed partial class RandomNormalInitializer : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomNormalInitializer()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.HyperparamsReflection.Descriptor.MessageTypes[7]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomNormalInitializer() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomNormalInitializer(RandomNormalInitializer other) : this() { + mean_ = other.mean_; + stddev_ = other.stddev_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomNormalInitializer Clone() { + return new RandomNormalInitializer(this); + } + + /// Field number for the "mean" field. + public const int MeanFieldNumber = 1; + private float mean_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Mean { + get { return mean_; } + set { + mean_ = value; + } + } + + /// Field number for the "stddev" field. 
+ public const int StddevFieldNumber = 2; + private float stddev_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Stddev { + get { return stddev_; } + set { + stddev_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomNormalInitializer); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomNormalInitializer other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Mean, other.Mean)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Stddev, other.Stddev)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Mean != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Mean); + if (Stddev != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Stddev); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Mean != 0F) { + output.WriteRawTag(13); + output.WriteFloat(Mean); + } + if (Stddev != 0F) { + output.WriteRawTag(21); + output.WriteFloat(Stddev); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Mean != 0F) { + size += 1 + 4; + } + if (Stddev != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomNormalInitializer other) { + if (other == null) { + return; + } + if (other.Mean != 0F) { + Mean = other.Mean; + } + if (other.Stddev != 0F) { + Stddev = other.Stddev; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + Mean = input.ReadFloat(); + break; + } + case 21: { + Stddev = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Configuration proto for batch norm to apply after convolution op. 
See + /// https://www.tensorflow.org/api_docs/python/tf/contrib/layers/batch_norm + /// + public sealed partial class BatchNorm : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BatchNorm()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.HyperparamsReflection.Descriptor.MessageTypes[8]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BatchNorm() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BatchNorm(BatchNorm other) : this() { + decay_ = other.decay_; + center_ = other.center_; + scale_ = other.scale_; + epsilon_ = other.epsilon_; + train_ = other.train_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BatchNorm Clone() { + return new BatchNorm(this); + } + + /// Field number for the "decay" field. + public const int DecayFieldNumber = 1; + private float decay_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Decay { + get { return decay_; } + set { + decay_ = value; + } + } + + /// Field number for the "center" field. + public const int CenterFieldNumber = 2; + private bool center_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Center { + get { return center_; } + set { + center_ = value; + } + } + + /// Field number for the "scale" field. + public const int ScaleFieldNumber = 3; + private bool scale_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Scale { + get { return scale_; } + set { + scale_ = value; + } + } + + /// Field number for the "epsilon" field. + public const int EpsilonFieldNumber = 4; + private float epsilon_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Epsilon { + get { return epsilon_; } + set { + epsilon_ = value; + } + } + + /// Field number for the "train" field. + public const int TrainFieldNumber = 5; + private bool train_; + /// + /// Whether to train the batch norm variables. If this is set to false during + /// training, the current value of the batch_norm variables are used for + /// forward pass but they are never updated. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Train { + get { return train_; } + set { + train_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as BatchNorm); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(BatchNorm other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Decay, other.Decay)) return false; + if (Center != other.Center) return false; + if (Scale != other.Scale) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Epsilon, other.Epsilon)) return false; + if (Train != other.Train) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Decay != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Decay); + if (Center != false) hash ^= Center.GetHashCode(); + if (Scale != false) hash ^= Scale.GetHashCode(); + if (Epsilon != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Epsilon); + if (Train != false) hash ^= Train.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Decay != 0F) { + output.WriteRawTag(13); + output.WriteFloat(Decay); + } + if (Center != false) { + output.WriteRawTag(16); + output.WriteBool(Center); + } + if (Scale != false) { + output.WriteRawTag(24); + output.WriteBool(Scale); + } + if (Epsilon != 0F) { + output.WriteRawTag(37); + output.WriteFloat(Epsilon); + } + if (Train != false) { + output.WriteRawTag(40); + output.WriteBool(Train); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Decay != 0F) { + size += 1 + 4; + } + if (Center != false) { + size += 1 + 1; + } + if (Scale != false) { + size += 1 + 1; + } + if (Epsilon != 0F) { + size += 1 + 4; + } + if (Train != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(BatchNorm other) { + if (other == null) { + return; + } + if (other.Decay != 0F) { + Decay = other.Decay; + } + if (other.Center != false) { + Center = other.Center; + } + if (other.Scale != false) { + Scale = other.Scale; + } + if (other.Epsilon != 0F) { + Epsilon = other.Epsilon; + } + if (other.Train != false) { + Train = other.Train; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + Decay = input.ReadFloat(); + break; + } + 
case 16: { + Center = input.ReadBool(); + break; + } + case 24: { + Scale = input.ReadBool(); + break; + } + case 37: { + Epsilon = input.ReadFloat(); + break; + } + case 40: { + Train = input.ReadBool(); + break; + } + } + } + } + + } + + /// + /// Configuration proto for group normalization to apply after convolution op. + /// https://arxiv.org/abs/1803.08494 + /// + public sealed partial class GroupNorm : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new GroupNorm()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.HyperparamsReflection.Descriptor.MessageTypes[9]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GroupNorm() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GroupNorm(GroupNorm other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public GroupNorm Clone() { + return new GroupNorm(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as GroupNorm); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(GroupNorm other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(GroupNorm other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/ImageResizer.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/ImageResizer.cs new file mode 100644 index 00000000..8f10a4bb --- /dev/null +++ 
b/src/TensorFlowNET.Models/ObjectDetection/Protos/ImageResizer.cs @@ -0,0 +1,1255 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: object_detection/protos/image_resizer.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/image_resizer.proto + public static partial class ImageResizerReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/image_resizer.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static ImageResizerReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CitvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9pbWFnZV9yZXNpemVyLnByb3Rv", + "EhdvYmplY3RfZGV0ZWN0aW9uLnByb3RvcyLjAgoMSW1hZ2VSZXNpemVyElQK", + "GWtlZXBfYXNwZWN0X3JhdGlvX3Jlc2l6ZXIYASABKAsyLy5vYmplY3RfZGV0", + "ZWN0aW9uLnByb3Rvcy5LZWVwQXNwZWN0UmF0aW9SZXNpemVySAASSQoTZml4", + "ZWRfc2hhcGVfcmVzaXplchgCIAEoCzIqLm9iamVjdF9kZXRlY3Rpb24ucHJv", + "dG9zLkZpeGVkU2hhcGVSZXNpemVySAASRAoQaWRlbnRpdHlfcmVzaXplchgD", + "IAEoCzIoLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLklkZW50aXR5UmVzaXpl", + "ckgAElUKGWNvbmRpdGlvbmFsX3NoYXBlX3Jlc2l6ZXIYBCABKAsyMC5vYmpl", + "Y3RfZGV0ZWN0aW9uLnByb3Rvcy5Db25kaXRpb25hbFNoYXBlUmVzaXplckgA", + "QhUKE2ltYWdlX3Jlc2l6ZXJfb25lb2YiEQoPSWRlbnRpdHlSZXNpemVyIt0B", + "ChZLZWVwQXNwZWN0UmF0aW9SZXNpemVyEhUKDW1pbl9kaW1lbnNpb24YASAB", + "KAUSFQoNbWF4X2RpbWVuc2lvbhgCIAEoBRI6Cg1yZXNpemVfbWV0aG9kGAMg", + "ASgOMiMub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuUmVzaXplVHlwZRIcChRw", + "YWRfdG9fbWF4X2RpbWVuc2lvbhgEIAEoCBIcChRjb252ZXJ0X3RvX2dyYXlz", + "Y2FsZRgFIAEoCBIdChVwZXJfY2hhbm5lbF9wYWRfdmFsdWUYBiADKAIijAEK", + "EUZpeGVkU2hhcGVSZXNpemVyEg4KBmhlaWdodBgBIAEoBRINCgV3aWR0aBgC", + "IAEoBRI6Cg1yZXNpemVfbWV0aG9kGAMgASgOMiMub2JqZWN0X2RldGVjdGlv", + "bi5wcm90b3MuUmVzaXplVHlwZRIcChRjb252ZXJ0X3RvX2dyYXlzY2FsZRgE", + "IAEoCCKaAgoXQ29uZGl0aW9uYWxTaGFwZVJlc2l6ZXISUwoJY29uZGl0aW9u", + "GAEgASgOMkAub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuQ29uZGl0aW9uYWxT", + "aGFwZVJlc2l6ZXIuUmVzaXplQ29uZGl0aW9uEhYKDnNpemVfdGhyZXNob2xk", + "GAIgASgFEjoKDXJlc2l6ZV9tZXRob2QYAyABKA4yIy5vYmplY3RfZGV0ZWN0", + "aW9uLnByb3Rvcy5SZXNpemVUeXBlEhwKFGNvbnZlcnRfdG9fZ3JheXNjYWxl", + "GAQgASgIIjgKD1Jlc2l6ZUNvbmRpdGlvbhILCgdJTlZBTElEEAASCwoHR1JF", + "QVRFUhABEgsKB1NNQUxMRVIQAipHCgpSZXNpemVUeXBlEgwKCEJJTElORUFS", + "EAASFAoQTkVBUkVTVF9ORUlHSEJPUhABEgsKB0JJQ1VCSUMQAhIICgRBUkVB", + "EANiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Tensorflow.Models.ObjectDetection.Protos.ResizeType), }, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.ImageResizer), global::Tensorflow.Models.ObjectDetection.Protos.ImageResizer.Parser, new[]{ "KeepAspectRatioResizer", "FixedShapeResizer", "IdentityResizer", "ConditionalShapeResizer" }, new[]{ "ImageResizerOneof" }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.IdentityResizer), global::Tensorflow.Models.ObjectDetection.Protos.IdentityResizer.Parser, null, null, null, null), + new 
pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.KeepAspectRatioResizer), global::Tensorflow.Models.ObjectDetection.Protos.KeepAspectRatioResizer.Parser, new[]{ "MinDimension", "MaxDimension", "ResizeMethod", "PadToMaxDimension", "ConvertToGrayscale", "PerChannelPadValue" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.FixedShapeResizer), global::Tensorflow.Models.ObjectDetection.Protos.FixedShapeResizer.Parser, new[]{ "Height", "Width", "ResizeMethod", "ConvertToGrayscale" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.ConditionalShapeResizer), global::Tensorflow.Models.ObjectDetection.Protos.ConditionalShapeResizer.Parser, new[]{ "Condition", "SizeThreshold", "ResizeMethod", "ConvertToGrayscale" }, null, new[]{ typeof(global::Tensorflow.Models.ObjectDetection.Protos.ConditionalShapeResizer.Types.ResizeCondition) }, null) + })); + } + #endregion + + } + #region Enums + /// + /// Enumeration type for image resizing methods provided in TensorFlow. + /// + public enum ResizeType { + /// + /// Corresponds to tf.image.ResizeMethod.BILINEAR + /// + [pbr::OriginalName("BILINEAR")] Bilinear = 0, + /// + /// Corresponds to tf.image.ResizeMethod.NEAREST_NEIGHBOR + /// + [pbr::OriginalName("NEAREST_NEIGHBOR")] NearestNeighbor = 1, + /// + /// Corresponds to tf.image.ResizeMethod.BICUBIC + /// + [pbr::OriginalName("BICUBIC")] Bicubic = 2, + /// + /// Corresponds to tf.image.ResizeMethod.AREA + /// + [pbr::OriginalName("AREA")] Area = 3, + } + + #endregion + + #region Messages + /// + /// Configuration proto for image resizing operations. + /// See builders/image_resizer_builder.py for details. + /// + public sealed partial class ImageResizer : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ImageResizer()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.ImageResizerReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ImageResizer() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ImageResizer(ImageResizer other) : this() { + switch (other.ImageResizerOneofCase) { + case ImageResizerOneofOneofCase.KeepAspectRatioResizer: + KeepAspectRatioResizer = other.KeepAspectRatioResizer.Clone(); + break; + case ImageResizerOneofOneofCase.FixedShapeResizer: + FixedShapeResizer = other.FixedShapeResizer.Clone(); + break; + case ImageResizerOneofOneofCase.IdentityResizer: + IdentityResizer = other.IdentityResizer.Clone(); + break; + case ImageResizerOneofOneofCase.ConditionalShapeResizer: + ConditionalShapeResizer = other.ConditionalShapeResizer.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ImageResizer Clone() { + return new ImageResizer(this); + } + + /// Field number for the 
"keep_aspect_ratio_resizer" field. + public const int KeepAspectRatioResizerFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.KeepAspectRatioResizer KeepAspectRatioResizer { + get { return imageResizerOneofCase_ == ImageResizerOneofOneofCase.KeepAspectRatioResizer ? (global::Tensorflow.Models.ObjectDetection.Protos.KeepAspectRatioResizer) imageResizerOneof_ : null; } + set { + imageResizerOneof_ = value; + imageResizerOneofCase_ = value == null ? ImageResizerOneofOneofCase.None : ImageResizerOneofOneofCase.KeepAspectRatioResizer; + } + } + + /// Field number for the "fixed_shape_resizer" field. + public const int FixedShapeResizerFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.FixedShapeResizer FixedShapeResizer { + get { return imageResizerOneofCase_ == ImageResizerOneofOneofCase.FixedShapeResizer ? (global::Tensorflow.Models.ObjectDetection.Protos.FixedShapeResizer) imageResizerOneof_ : null; } + set { + imageResizerOneof_ = value; + imageResizerOneofCase_ = value == null ? ImageResizerOneofOneofCase.None : ImageResizerOneofOneofCase.FixedShapeResizer; + } + } + + /// Field number for the "identity_resizer" field. + public const int IdentityResizerFieldNumber = 3; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.IdentityResizer IdentityResizer { + get { return imageResizerOneofCase_ == ImageResizerOneofOneofCase.IdentityResizer ? (global::Tensorflow.Models.ObjectDetection.Protos.IdentityResizer) imageResizerOneof_ : null; } + set { + imageResizerOneof_ = value; + imageResizerOneofCase_ = value == null ? ImageResizerOneofOneofCase.None : ImageResizerOneofOneofCase.IdentityResizer; + } + } + + /// Field number for the "conditional_shape_resizer" field. + public const int ConditionalShapeResizerFieldNumber = 4; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ConditionalShapeResizer ConditionalShapeResizer { + get { return imageResizerOneofCase_ == ImageResizerOneofOneofCase.ConditionalShapeResizer ? (global::Tensorflow.Models.ObjectDetection.Protos.ConditionalShapeResizer) imageResizerOneof_ : null; } + set { + imageResizerOneof_ = value; + imageResizerOneofCase_ = value == null ? ImageResizerOneofOneofCase.None : ImageResizerOneofOneofCase.ConditionalShapeResizer; + } + } + + private object imageResizerOneof_; + /// Enum of possible cases for the "image_resizer_oneof" oneof. 
+ public enum ImageResizerOneofOneofCase { + None = 0, + KeepAspectRatioResizer = 1, + FixedShapeResizer = 2, + IdentityResizer = 3, + ConditionalShapeResizer = 4, + } + private ImageResizerOneofOneofCase imageResizerOneofCase_ = ImageResizerOneofOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ImageResizerOneofOneofCase ImageResizerOneofCase { + get { return imageResizerOneofCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearImageResizerOneof() { + imageResizerOneofCase_ = ImageResizerOneofOneofCase.None; + imageResizerOneof_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ImageResizer); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ImageResizer other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(KeepAspectRatioResizer, other.KeepAspectRatioResizer)) return false; + if (!object.Equals(FixedShapeResizer, other.FixedShapeResizer)) return false; + if (!object.Equals(IdentityResizer, other.IdentityResizer)) return false; + if (!object.Equals(ConditionalShapeResizer, other.ConditionalShapeResizer)) return false; + if (ImageResizerOneofCase != other.ImageResizerOneofCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (imageResizerOneofCase_ == ImageResizerOneofOneofCase.KeepAspectRatioResizer) hash ^= KeepAspectRatioResizer.GetHashCode(); + if (imageResizerOneofCase_ == ImageResizerOneofOneofCase.FixedShapeResizer) hash ^= FixedShapeResizer.GetHashCode(); + if (imageResizerOneofCase_ == ImageResizerOneofOneofCase.IdentityResizer) hash ^= IdentityResizer.GetHashCode(); + if (imageResizerOneofCase_ == ImageResizerOneofOneofCase.ConditionalShapeResizer) hash ^= ConditionalShapeResizer.GetHashCode(); + hash ^= (int) imageResizerOneofCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (imageResizerOneofCase_ == ImageResizerOneofOneofCase.KeepAspectRatioResizer) { + output.WriteRawTag(10); + output.WriteMessage(KeepAspectRatioResizer); + } + if (imageResizerOneofCase_ == ImageResizerOneofOneofCase.FixedShapeResizer) { + output.WriteRawTag(18); + output.WriteMessage(FixedShapeResizer); + } + if (imageResizerOneofCase_ == ImageResizerOneofOneofCase.IdentityResizer) { + output.WriteRawTag(26); + output.WriteMessage(IdentityResizer); + } + if (imageResizerOneofCase_ == ImageResizerOneofOneofCase.ConditionalShapeResizer) { + output.WriteRawTag(34); + output.WriteMessage(ConditionalShapeResizer); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (imageResizerOneofCase_ == ImageResizerOneofOneofCase.KeepAspectRatioResizer) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(KeepAspectRatioResizer); + } + if (imageResizerOneofCase_ == ImageResizerOneofOneofCase.FixedShapeResizer) { + size += 1 
+ pb::CodedOutputStream.ComputeMessageSize(FixedShapeResizer); + } + if (imageResizerOneofCase_ == ImageResizerOneofOneofCase.IdentityResizer) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(IdentityResizer); + } + if (imageResizerOneofCase_ == ImageResizerOneofOneofCase.ConditionalShapeResizer) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ConditionalShapeResizer); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ImageResizer other) { + if (other == null) { + return; + } + switch (other.ImageResizerOneofCase) { + case ImageResizerOneofOneofCase.KeepAspectRatioResizer: + if (KeepAspectRatioResizer == null) { + KeepAspectRatioResizer = new global::Tensorflow.Models.ObjectDetection.Protos.KeepAspectRatioResizer(); + } + KeepAspectRatioResizer.MergeFrom(other.KeepAspectRatioResizer); + break; + case ImageResizerOneofOneofCase.FixedShapeResizer: + if (FixedShapeResizer == null) { + FixedShapeResizer = new global::Tensorflow.Models.ObjectDetection.Protos.FixedShapeResizer(); + } + FixedShapeResizer.MergeFrom(other.FixedShapeResizer); + break; + case ImageResizerOneofOneofCase.IdentityResizer: + if (IdentityResizer == null) { + IdentityResizer = new global::Tensorflow.Models.ObjectDetection.Protos.IdentityResizer(); + } + IdentityResizer.MergeFrom(other.IdentityResizer); + break; + case ImageResizerOneofOneofCase.ConditionalShapeResizer: + if (ConditionalShapeResizer == null) { + ConditionalShapeResizer = new global::Tensorflow.Models.ObjectDetection.Protos.ConditionalShapeResizer(); + } + ConditionalShapeResizer.MergeFrom(other.ConditionalShapeResizer); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.Models.ObjectDetection.Protos.KeepAspectRatioResizer subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.KeepAspectRatioResizer(); + if (imageResizerOneofCase_ == ImageResizerOneofOneofCase.KeepAspectRatioResizer) { + subBuilder.MergeFrom(KeepAspectRatioResizer); + } + input.ReadMessage(subBuilder); + KeepAspectRatioResizer = subBuilder; + break; + } + case 18: { + global::Tensorflow.Models.ObjectDetection.Protos.FixedShapeResizer subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.FixedShapeResizer(); + if (imageResizerOneofCase_ == ImageResizerOneofOneofCase.FixedShapeResizer) { + subBuilder.MergeFrom(FixedShapeResizer); + } + input.ReadMessage(subBuilder); + FixedShapeResizer = subBuilder; + break; + } + case 26: { + global::Tensorflow.Models.ObjectDetection.Protos.IdentityResizer subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.IdentityResizer(); + if (imageResizerOneofCase_ == ImageResizerOneofOneofCase.IdentityResizer) { + subBuilder.MergeFrom(IdentityResizer); + } + input.ReadMessage(subBuilder); + IdentityResizer = subBuilder; + break; + } + case 34: { + global::Tensorflow.Models.ObjectDetection.Protos.ConditionalShapeResizer subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.ConditionalShapeResizer(); + if (imageResizerOneofCase_ == ImageResizerOneofOneofCase.ConditionalShapeResizer) { + 
subBuilder.MergeFrom(ConditionalShapeResizer); + } + input.ReadMessage(subBuilder); + ConditionalShapeResizer = subBuilder; + break; + } + } + } + } + + } + + public sealed partial class IdentityResizer : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new IdentityResizer()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.ImageResizerReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public IdentityResizer() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public IdentityResizer(IdentityResizer other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public IdentityResizer Clone() { + return new IdentityResizer(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as IdentityResizer); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(IdentityResizer other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(IdentityResizer other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + } + + } + + /// + /// Configuration proto for image resizer that keeps aspect ratio. 
+ /// + public sealed partial class KeepAspectRatioResizer : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new KeepAspectRatioResizer()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.ImageResizerReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public KeepAspectRatioResizer() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public KeepAspectRatioResizer(KeepAspectRatioResizer other) : this() { + minDimension_ = other.minDimension_; + maxDimension_ = other.maxDimension_; + resizeMethod_ = other.resizeMethod_; + padToMaxDimension_ = other.padToMaxDimension_; + convertToGrayscale_ = other.convertToGrayscale_; + perChannelPadValue_ = other.perChannelPadValue_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public KeepAspectRatioResizer Clone() { + return new KeepAspectRatioResizer(this); + } + + /// Field number for the "min_dimension" field. + public const int MinDimensionFieldNumber = 1; + private int minDimension_; + /// + /// Desired size of the smaller image dimension in pixels. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MinDimension { + get { return minDimension_; } + set { + minDimension_ = value; + } + } + + /// Field number for the "max_dimension" field. + public const int MaxDimensionFieldNumber = 2; + private int maxDimension_; + /// + /// Desired size of the larger image dimension in pixels. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxDimension { + get { return maxDimension_; } + set { + maxDimension_ = value; + } + } + + /// Field number for the "resize_method" field. + public const int ResizeMethodFieldNumber = 3; + private global::Tensorflow.Models.ObjectDetection.Protos.ResizeType resizeMethod_ = 0; + /// + /// Desired method when resizing image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ResizeType ResizeMethod { + get { return resizeMethod_; } + set { + resizeMethod_ = value; + } + } + + /// Field number for the "pad_to_max_dimension" field. + public const int PadToMaxDimensionFieldNumber = 4; + private bool padToMaxDimension_; + /// + /// Whether to pad the image with zeros so the output spatial size is + /// [max_dimension, max_dimension]. Note that the zeros are padded to the + /// bottom and the right of the resized image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool PadToMaxDimension { + get { return padToMaxDimension_; } + set { + padToMaxDimension_ = value; + } + } + + /// Field number for the "convert_to_grayscale" field. + public const int ConvertToGrayscaleFieldNumber = 5; + private bool convertToGrayscale_; + /// + /// Whether to also resize the image channels from 3 to 1 (RGB to grayscale). 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ConvertToGrayscale { + get { return convertToGrayscale_; } + set { + convertToGrayscale_ = value; + } + } + + /// Field number for the "per_channel_pad_value" field. + public const int PerChannelPadValueFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_perChannelPadValue_codec + = pb::FieldCodec.ForFloat(50); + private readonly pbc::RepeatedField perChannelPadValue_ = new pbc::RepeatedField(); + /// + /// Per-channel pad value. This is only used when pad_to_max_dimension is True. + /// If unspecified, a default pad value of 0 is applied to all channels. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField PerChannelPadValue { + get { return perChannelPadValue_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as KeepAspectRatioResizer); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(KeepAspectRatioResizer other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (MinDimension != other.MinDimension) return false; + if (MaxDimension != other.MaxDimension) return false; + if (ResizeMethod != other.ResizeMethod) return false; + if (PadToMaxDimension != other.PadToMaxDimension) return false; + if (ConvertToGrayscale != other.ConvertToGrayscale) return false; + if(!perChannelPadValue_.Equals(other.perChannelPadValue_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MinDimension != 0) hash ^= MinDimension.GetHashCode(); + if (MaxDimension != 0) hash ^= MaxDimension.GetHashCode(); + if (ResizeMethod != 0) hash ^= ResizeMethod.GetHashCode(); + if (PadToMaxDimension != false) hash ^= PadToMaxDimension.GetHashCode(); + if (ConvertToGrayscale != false) hash ^= ConvertToGrayscale.GetHashCode(); + hash ^= perChannelPadValue_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MinDimension != 0) { + output.WriteRawTag(8); + output.WriteInt32(MinDimension); + } + if (MaxDimension != 0) { + output.WriteRawTag(16); + output.WriteInt32(MaxDimension); + } + if (ResizeMethod != 0) { + output.WriteRawTag(24); + output.WriteEnum((int) ResizeMethod); + } + if (PadToMaxDimension != false) { + output.WriteRawTag(32); + output.WriteBool(PadToMaxDimension); + } + if (ConvertToGrayscale != false) { + output.WriteRawTag(40); + output.WriteBool(ConvertToGrayscale); + } + perChannelPadValue_.WriteTo(output, _repeated_perChannelPadValue_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MinDimension != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MinDimension); + } + if (MaxDimension != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxDimension); + } + if (ResizeMethod != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ResizeMethod); 
+ } + if (PadToMaxDimension != false) { + size += 1 + 1; + } + if (ConvertToGrayscale != false) { + size += 1 + 1; + } + size += perChannelPadValue_.CalculateSize(_repeated_perChannelPadValue_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(KeepAspectRatioResizer other) { + if (other == null) { + return; + } + if (other.MinDimension != 0) { + MinDimension = other.MinDimension; + } + if (other.MaxDimension != 0) { + MaxDimension = other.MaxDimension; + } + if (other.ResizeMethod != 0) { + ResizeMethod = other.ResizeMethod; + } + if (other.PadToMaxDimension != false) { + PadToMaxDimension = other.PadToMaxDimension; + } + if (other.ConvertToGrayscale != false) { + ConvertToGrayscale = other.ConvertToGrayscale; + } + perChannelPadValue_.Add(other.perChannelPadValue_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + MinDimension = input.ReadInt32(); + break; + } + case 16: { + MaxDimension = input.ReadInt32(); + break; + } + case 24: { + resizeMethod_ = (global::Tensorflow.Models.ObjectDetection.Protos.ResizeType) input.ReadEnum(); + break; + } + case 32: { + PadToMaxDimension = input.ReadBool(); + break; + } + case 40: { + ConvertToGrayscale = input.ReadBool(); + break; + } + case 50: + case 53: { + perChannelPadValue_.AddEntriesFrom(input, _repeated_perChannelPadValue_codec); + break; + } + } + } + } + + } + + /// + /// Configuration proto for image resizer that resizes to a fixed shape. + /// + public sealed partial class FixedShapeResizer : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FixedShapeResizer()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.ImageResizerReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FixedShapeResizer() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FixedShapeResizer(FixedShapeResizer other) : this() { + height_ = other.height_; + width_ = other.width_; + resizeMethod_ = other.resizeMethod_; + convertToGrayscale_ = other.convertToGrayscale_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FixedShapeResizer Clone() { + return new FixedShapeResizer(this); + } + + /// Field number for the "height" field. + public const int HeightFieldNumber = 1; + private int height_; + /// + /// Desired height of image in pixels. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Height { + get { return height_; } + set { + height_ = value; + } + } + + /// Field number for the "width" field. + public const int WidthFieldNumber = 2; + private int width_; + /// + /// Desired width of image in pixels. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Width { + get { return width_; } + set { + width_ = value; + } + } + + /// Field number for the "resize_method" field. + public const int ResizeMethodFieldNumber = 3; + private global::Tensorflow.Models.ObjectDetection.Protos.ResizeType resizeMethod_ = 0; + /// + /// Desired method when resizing image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ResizeType ResizeMethod { + get { return resizeMethod_; } + set { + resizeMethod_ = value; + } + } + + /// Field number for the "convert_to_grayscale" field. + public const int ConvertToGrayscaleFieldNumber = 4; + private bool convertToGrayscale_; + /// + /// Whether to also resize the image channels from 3 to 1 (RGB to grayscale). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ConvertToGrayscale { + get { return convertToGrayscale_; } + set { + convertToGrayscale_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as FixedShapeResizer); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(FixedShapeResizer other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Height != other.Height) return false; + if (Width != other.Width) return false; + if (ResizeMethod != other.ResizeMethod) return false; + if (ConvertToGrayscale != other.ConvertToGrayscale) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Height != 0) hash ^= Height.GetHashCode(); + if (Width != 0) hash ^= Width.GetHashCode(); + if (ResizeMethod != 0) hash ^= ResizeMethod.GetHashCode(); + if (ConvertToGrayscale != false) hash ^= ConvertToGrayscale.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Height != 0) { + output.WriteRawTag(8); + output.WriteInt32(Height); + } + if (Width != 0) { + output.WriteRawTag(16); + output.WriteInt32(Width); + } + if (ResizeMethod != 0) { + output.WriteRawTag(24); + output.WriteEnum((int) ResizeMethod); + } + if (ConvertToGrayscale != false) { + output.WriteRawTag(32); + output.WriteBool(ConvertToGrayscale); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Height != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Height); + } + if (Width != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Width); + } + if (ResizeMethod != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ResizeMethod); + } + if (ConvertToGrayscale != false) { 
+ size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(FixedShapeResizer other) { + if (other == null) { + return; + } + if (other.Height != 0) { + Height = other.Height; + } + if (other.Width != 0) { + Width = other.Width; + } + if (other.ResizeMethod != 0) { + ResizeMethod = other.ResizeMethod; + } + if (other.ConvertToGrayscale != false) { + ConvertToGrayscale = other.ConvertToGrayscale; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Height = input.ReadInt32(); + break; + } + case 16: { + Width = input.ReadInt32(); + break; + } + case 24: { + resizeMethod_ = (global::Tensorflow.Models.ObjectDetection.Protos.ResizeType) input.ReadEnum(); + break; + } + case 32: { + ConvertToGrayscale = input.ReadBool(); + break; + } + } + } + } + + } + + /// + /// Configuration proto for image resizer that resizes only if input image height + /// or width is greater or smaller than a certain size. + /// Aspect ratio is maintained. + /// + public sealed partial class ConditionalShapeResizer : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ConditionalShapeResizer()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.ImageResizerReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ConditionalShapeResizer() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ConditionalShapeResizer(ConditionalShapeResizer other) : this() { + condition_ = other.condition_; + sizeThreshold_ = other.sizeThreshold_; + resizeMethod_ = other.resizeMethod_; + convertToGrayscale_ = other.convertToGrayscale_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ConditionalShapeResizer Clone() { + return new ConditionalShapeResizer(this); + } + + /// Field number for the "condition" field. + public const int ConditionFieldNumber = 1; + private global::Tensorflow.Models.ObjectDetection.Protos.ConditionalShapeResizer.Types.ResizeCondition condition_ = 0; + /// + /// Condition which must be true to resize the image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ConditionalShapeResizer.Types.ResizeCondition Condition { + get { return condition_; } + set { + condition_ = value; + } + } + + /// Field number for the "size_threshold" field. + public const int SizeThresholdFieldNumber = 2; + private int sizeThreshold_; + /// + /// Threshold for the image size. 
If any image dimension is above or below this + /// (as specified by condition) the image will be resized so that it meets the + /// threshold. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int SizeThreshold { + get { return sizeThreshold_; } + set { + sizeThreshold_ = value; + } + } + + /// Field number for the "resize_method" field. + public const int ResizeMethodFieldNumber = 3; + private global::Tensorflow.Models.ObjectDetection.Protos.ResizeType resizeMethod_ = 0; + /// + /// Desired method when resizing image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ResizeType ResizeMethod { + get { return resizeMethod_; } + set { + resizeMethod_ = value; + } + } + + /// Field number for the "convert_to_grayscale" field. + public const int ConvertToGrayscaleFieldNumber = 4; + private bool convertToGrayscale_; + /// + /// Whether to also resize the image channels from 3 to 1 (RGB to grayscale). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ConvertToGrayscale { + get { return convertToGrayscale_; } + set { + convertToGrayscale_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ConditionalShapeResizer); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ConditionalShapeResizer other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Condition != other.Condition) return false; + if (SizeThreshold != other.SizeThreshold) return false; + if (ResizeMethod != other.ResizeMethod) return false; + if (ConvertToGrayscale != other.ConvertToGrayscale) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Condition != 0) hash ^= Condition.GetHashCode(); + if (SizeThreshold != 0) hash ^= SizeThreshold.GetHashCode(); + if (ResizeMethod != 0) hash ^= ResizeMethod.GetHashCode(); + if (ConvertToGrayscale != false) hash ^= ConvertToGrayscale.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Condition != 0) { + output.WriteRawTag(8); + output.WriteEnum((int) Condition); + } + if (SizeThreshold != 0) { + output.WriteRawTag(16); + output.WriteInt32(SizeThreshold); + } + if (ResizeMethod != 0) { + output.WriteRawTag(24); + output.WriteEnum((int) ResizeMethod); + } + if (ConvertToGrayscale != false) { + output.WriteRawTag(32); + output.WriteBool(ConvertToGrayscale); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Condition != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Condition); + } + if (SizeThreshold != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(SizeThreshold); + } + if (ResizeMethod != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ResizeMethod); + } + if (ConvertToGrayscale != false) { + size += 1 + 1; + } + 
if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ConditionalShapeResizer other) { + if (other == null) { + return; + } + if (other.Condition != 0) { + Condition = other.Condition; + } + if (other.SizeThreshold != 0) { + SizeThreshold = other.SizeThreshold; + } + if (other.ResizeMethod != 0) { + ResizeMethod = other.ResizeMethod; + } + if (other.ConvertToGrayscale != false) { + ConvertToGrayscale = other.ConvertToGrayscale; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + condition_ = (global::Tensorflow.Models.ObjectDetection.Protos.ConditionalShapeResizer.Types.ResizeCondition) input.ReadEnum(); + break; + } + case 16: { + SizeThreshold = input.ReadInt32(); + break; + } + case 24: { + resizeMethod_ = (global::Tensorflow.Models.ObjectDetection.Protos.ResizeType) input.ReadEnum(); + break; + } + case 32: { + ConvertToGrayscale = input.ReadBool(); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the ConditionalShapeResizer message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + /// + /// Enumeration for the condition on which to resize an image. + /// + public enum ResizeCondition { + /// + /// Default value. + /// + [pbr::OriginalName("INVALID")] Invalid = 0, + /// + /// Resizes image if a dimension is greater than specified size. + /// + [pbr::OriginalName("GREATER")] Greater = 1, + /// + /// Resizes image if a dimension is smaller than specified size. + /// + [pbr::OriginalName("SMALLER")] Smaller = 2, + } + + } + #endregion + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/InputReader.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/InputReader.cs new file mode 100644 index 00000000..90f7ae1b --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/InputReader.cs @@ -0,0 +1,1225 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: object_detection/protos/input_reader.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/input_reader.proto + public static partial class InputReaderReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/input_reader.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static InputReaderReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CipvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9pbnB1dF9yZWFkZXIucHJvdG8S", + "F29iamVjdF9kZXRlY3Rpb24ucHJvdG9zIqsGCgtJbnB1dFJlYWRlchIMCgRu", + "YW1lGBcgASgJEhYKDmxhYmVsX21hcF9wYXRoGAEgASgJEg8KB3NodWZmbGUY", + "AiABKAgSGwoTc2h1ZmZsZV9idWZmZXJfc2l6ZRgLIAEoDRIlCh1maWxlbmFt", + "ZXNfc2h1ZmZsZV9idWZmZXJfc2l6ZRgMIAEoDRISCgpudW1fZXBvY2hzGAUg", + "ASgNEh4KFnNhbXBsZV8xX29mX25fZXhhbXBsZXMYFiABKA0SEwoLbnVtX3Jl", + "YWRlcnMYBiABKA0SHAoUbnVtX3BhcmFsbGVsX2JhdGNoZXMYEyABKA0SHAoU", + "bnVtX3ByZWZldGNoX2JhdGNoZXMYFCABKAUSFgoOcXVldWVfY2FwYWNpdHkY", + "AyABKA0SGQoRbWluX2FmdGVyX2RlcXVldWUYBCABKA0SGQoRcmVhZF9ibG9j", + "a19sZW5ndGgYDyABKA0SFQoNcHJlZmV0Y2hfc2l6ZRgNIAEoDRIeChZudW1f", + "cGFyYWxsZWxfbWFwX2NhbGxzGA4gASgNEh8KF251bV9hZGRpdGlvbmFsX2No", + "YW5uZWxzGBIgASgFEhUKDW51bV9rZXlwb2ludHMYECABKA0SGwoTbWF4X251", + "bWJlcl9vZl9ib3hlcxgVIAEoBRIeChZsb2FkX211bHRpY2xhc3Nfc2NvcmVz", + "GBggASgIEhsKE2xvYWRfaW5zdGFuY2VfbWFza3MYByABKAgSPAoJbWFza190", + "eXBlGAogASgOMikub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuSW5zdGFuY2VN", + "YXNrVHlwZRIYChB1c2VfZGlzcGxheV9uYW1lGBEgASgIEk4KFnRmX3JlY29y", + "ZF9pbnB1dF9yZWFkZXIYCCABKAsyLC5vYmplY3RfZGV0ZWN0aW9uLnByb3Rv", + "cy5URlJlY29yZElucHV0UmVhZGVySAASTQoVZXh0ZXJuYWxfaW5wdXRfcmVh", + "ZGVyGAkgASgLMiwub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuRXh0ZXJuYWxJ", + "bnB1dFJlYWRlckgAQg4KDGlucHV0X3JlYWRlciIpChNURlJlY29yZElucHV0", + "UmVhZGVyEhIKCmlucHV0X3BhdGgYASADKAkiFQoTRXh0ZXJuYWxJbnB1dFJl", + "YWRlcipDChBJbnN0YW5jZU1hc2tUeXBlEgsKB0RFRkFVTFQQABITCg9OVU1F", + "UklDQUxfTUFTS1MQARINCglQTkdfTUFTS1MQAmIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Tensorflow.Models.ObjectDetection.Protos.InstanceMaskType), }, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.InputReader), global::Tensorflow.Models.ObjectDetection.Protos.InputReader.Parser, new[]{ "Name", "LabelMapPath", "Shuffle", "ShuffleBufferSize", "FilenamesShuffleBufferSize", "NumEpochs", "Sample1OfNExamples", "NumReaders", "NumParallelBatches", "NumPrefetchBatches", "QueueCapacity", "MinAfterDequeue", "ReadBlockLength", "PrefetchSize", "NumParallelMapCalls", "NumAdditionalChannels", "NumKeypoints", "MaxNumberOfBoxes", "LoadMulticlassScores", "LoadInstanceMasks", "MaskType", "UseDisplayName", "TfRecordInputReader", "ExternalInputReader" }, new[]{ "InputReader" }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.TFRecordInputReader), global::Tensorflow.Models.ObjectDetection.Protos.TFRecordInputReader.Parser, new[]{ "InputPath" }, null, null, null), + new 
pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.ExternalInputReader), global::Tensorflow.Models.ObjectDetection.Protos.ExternalInputReader.Parser, null, null, null, null) + })); + } + #endregion + + } + #region Enums + /// + /// Instance mask format. Note that PNG masks are much more space efficient. + /// + public enum InstanceMaskType { + /// + /// Default implementation, currently NUMERICAL_MASKS + /// + [pbr::OriginalName("DEFAULT")] Default = 0, + /// + /// [num_masks, H, W] float32 binary masks. + /// + [pbr::OriginalName("NUMERICAL_MASKS")] NumericalMasks = 1, + /// + /// Encoded PNG masks. + /// + [pbr::OriginalName("PNG_MASKS")] PngMasks = 2, + } + + #endregion + + #region Messages + /// + /// Next id: 25 + /// + public sealed partial class InputReader : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new InputReader()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.InputReaderReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public InputReader() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public InputReader(InputReader other) : this() { + name_ = other.name_; + labelMapPath_ = other.labelMapPath_; + shuffle_ = other.shuffle_; + shuffleBufferSize_ = other.shuffleBufferSize_; + filenamesShuffleBufferSize_ = other.filenamesShuffleBufferSize_; + numEpochs_ = other.numEpochs_; + sample1OfNExamples_ = other.sample1OfNExamples_; + numReaders_ = other.numReaders_; + numParallelBatches_ = other.numParallelBatches_; + numPrefetchBatches_ = other.numPrefetchBatches_; + queueCapacity_ = other.queueCapacity_; + minAfterDequeue_ = other.minAfterDequeue_; + readBlockLength_ = other.readBlockLength_; + prefetchSize_ = other.prefetchSize_; + numParallelMapCalls_ = other.numParallelMapCalls_; + numAdditionalChannels_ = other.numAdditionalChannels_; + numKeypoints_ = other.numKeypoints_; + maxNumberOfBoxes_ = other.maxNumberOfBoxes_; + loadMulticlassScores_ = other.loadMulticlassScores_; + loadInstanceMasks_ = other.loadInstanceMasks_; + maskType_ = other.maskType_; + useDisplayName_ = other.useDisplayName_; + switch (other.InputReaderCase) { + case InputReaderOneofCase.TfRecordInputReader: + TfRecordInputReader = other.TfRecordInputReader.Clone(); + break; + case InputReaderOneofCase.ExternalInputReader: + ExternalInputReader = other.ExternalInputReader.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public InputReader Clone() { + return new InputReader(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 23; + private string name_ = ""; + /// + /// Name of input reader. Typically used to describe the dataset that is read + /// by this input reader. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "label_map_path" field. + public const int LabelMapPathFieldNumber = 1; + private string labelMapPath_ = ""; + /// + /// Path to StringIntLabelMap pbtxt file specifying the mapping from string + /// labels to integer ids. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string LabelMapPath { + get { return labelMapPath_; } + set { + labelMapPath_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "shuffle" field. + public const int ShuffleFieldNumber = 2; + private bool shuffle_; + /// + /// Whether data should be processed in the order they are read in, or + /// shuffled randomly. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Shuffle { + get { return shuffle_; } + set { + shuffle_ = value; + } + } + + /// Field number for the "shuffle_buffer_size" field. + public const int ShuffleBufferSizeFieldNumber = 11; + private uint shuffleBufferSize_; + /// + /// Buffer size to be used when shuffling. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint ShuffleBufferSize { + get { return shuffleBufferSize_; } + set { + shuffleBufferSize_ = value; + } + } + + /// Field number for the "filenames_shuffle_buffer_size" field. + public const int FilenamesShuffleBufferSizeFieldNumber = 12; + private uint filenamesShuffleBufferSize_; + /// + /// Buffer size to be used when shuffling file names. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint FilenamesShuffleBufferSize { + get { return filenamesShuffleBufferSize_; } + set { + filenamesShuffleBufferSize_ = value; + } + } + + /// Field number for the "num_epochs" field. + public const int NumEpochsFieldNumber = 5; + private uint numEpochs_; + /// + /// The number of times a data source is read. If set to zero, the data source + /// will be reused indefinitely. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint NumEpochs { + get { return numEpochs_; } + set { + numEpochs_ = value; + } + } + + /// Field number for the "sample_1_of_n_examples" field. + public const int Sample1OfNExamplesFieldNumber = 22; + private uint sample1OfNExamples_; + /// + /// Integer representing how often an example should be sampled. To feed + /// only 1/3 of your data into your model, set `sample_1_of_n_examples` to 3. + /// This is particularly useful for evaluation, where you might not prefer to + /// evaluate all of your samples. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint Sample1OfNExamples { + get { return sample1OfNExamples_; } + set { + sample1OfNExamples_ = value; + } + } + + /// Field number for the "num_readers" field. + public const int NumReadersFieldNumber = 6; + private uint numReaders_; + /// + /// Number of file shards to read in parallel. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint NumReaders { + get { return numReaders_; } + set { + numReaders_ = value; + } + } + + /// Field number for the "num_parallel_batches" field. + public const int NumParallelBatchesFieldNumber = 19; + private uint numParallelBatches_; + /// + /// Number of batches to produce in parallel. If this is run on a 2x2 TPU set + /// this to 8. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint NumParallelBatches { + get { return numParallelBatches_; } + set { + numParallelBatches_ = value; + } + } + + /// Field number for the "num_prefetch_batches" field. + public const int NumPrefetchBatchesFieldNumber = 20; + private int numPrefetchBatches_; + /// + /// Number of batches to prefetch. Prefetch decouples input pipeline and + /// model so they can be pipelined resulting in higher throughput. Set this + /// to a small constant and increment linearly until the improvements become + /// marginal or you exceed your cpu memory budget. Setting this to -1, + /// automatically tunes this value for you. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumPrefetchBatches { + get { return numPrefetchBatches_; } + set { + numPrefetchBatches_ = value; + } + } + + /// Field number for the "queue_capacity" field. + public const int QueueCapacityFieldNumber = 3; + private uint queueCapacity_; + /// + /// Maximum number of records to keep in reader queue. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint QueueCapacity { + get { return queueCapacity_; } + set { + queueCapacity_ = value; + } + } + + /// Field number for the "min_after_dequeue" field. + public const int MinAfterDequeueFieldNumber = 4; + private uint minAfterDequeue_; + /// + /// Minimum number of records to keep in reader queue. A large value is needed + /// to generate a good random shuffle. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint MinAfterDequeue { + get { return minAfterDequeue_; } + set { + minAfterDequeue_ = value; + } + } + + /// Field number for the "read_block_length" field. + public const int ReadBlockLengthFieldNumber = 15; + private uint readBlockLength_; + /// + /// Number of records to read from each reader at once. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint ReadBlockLength { + get { return readBlockLength_; } + set { + readBlockLength_ = value; + } + } + + /// Field number for the "prefetch_size" field. + public const int PrefetchSizeFieldNumber = 13; + private uint prefetchSize_; + /// + /// Number of decoded records to prefetch before batching. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint PrefetchSize { + get { return prefetchSize_; } + set { + prefetchSize_ = value; + } + } + + /// Field number for the "num_parallel_map_calls" field. + public const int NumParallelMapCallsFieldNumber = 14; + private uint numParallelMapCalls_; + /// + /// Number of parallel decode ops to apply. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint NumParallelMapCalls { + get { return numParallelMapCalls_; } + set { + numParallelMapCalls_ = value; + } + } + + /// Field number for the "num_additional_channels" field. + public const int NumAdditionalChannelsFieldNumber = 18; + private int numAdditionalChannels_; + /// + /// If positive, TfExampleDecoder will try to decode rasters of additional + /// channels from tf.Examples. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumAdditionalChannels { + get { return numAdditionalChannels_; } + set { + numAdditionalChannels_ = value; + } + } + + /// Field number for the "num_keypoints" field. + public const int NumKeypointsFieldNumber = 16; + private uint numKeypoints_; + /// + /// Number of groundtruth keypoints per object. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint NumKeypoints { + get { return numKeypoints_; } + set { + numKeypoints_ = value; + } + } + + /// Field number for the "max_number_of_boxes" field. + public const int MaxNumberOfBoxesFieldNumber = 21; + private int maxNumberOfBoxes_; + /// + /// Maximum number of boxes to pad to during training / evaluation. + /// Set this to at least the maximum amount of boxes in the input data, + /// otherwise some groundtruth boxes may be clipped. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxNumberOfBoxes { + get { return maxNumberOfBoxes_; } + set { + maxNumberOfBoxes_ = value; + } + } + + /// Field number for the "load_multiclass_scores" field. + public const int LoadMulticlassScoresFieldNumber = 24; + private bool loadMulticlassScores_; + /// + /// Whether to load multiclass scores from the dataset. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool LoadMulticlassScores { + get { return loadMulticlassScores_; } + set { + loadMulticlassScores_ = value; + } + } + + /// Field number for the "load_instance_masks" field. + public const int LoadInstanceMasksFieldNumber = 7; + private bool loadInstanceMasks_; + /// + /// Whether to load groundtruth instance masks. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool LoadInstanceMasks { + get { return loadInstanceMasks_; } + set { + loadInstanceMasks_ = value; + } + } + + /// Field number for the "mask_type" field. + public const int MaskTypeFieldNumber = 10; + private global::Tensorflow.Models.ObjectDetection.Protos.InstanceMaskType maskType_ = 0; + /// + /// Type of instance mask. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.InstanceMaskType MaskType { + get { return maskType_; } + set { + maskType_ = value; + } + } + + /// Field number for the "use_display_name" field. + public const int UseDisplayNameFieldNumber = 17; + private bool useDisplayName_; + /// + /// Whether to use the display name when decoding examples. This is only used + /// when mapping class text strings to integers. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseDisplayName { + get { return useDisplayName_; } + set { + useDisplayName_ = value; + } + } + + /// Field number for the "tf_record_input_reader" field. + public const int TfRecordInputReaderFieldNumber = 8; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.TFRecordInputReader TfRecordInputReader { + get { return inputReaderCase_ == InputReaderOneofCase.TfRecordInputReader ? (global::Tensorflow.Models.ObjectDetection.Protos.TFRecordInputReader) inputReader_ : null; } + set { + inputReader_ = value; + inputReaderCase_ = value == null ? InputReaderOneofCase.None : InputReaderOneofCase.TfRecordInputReader; + } + } + + /// Field number for the "external_input_reader" field. + public const int ExternalInputReaderFieldNumber = 9; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ExternalInputReader ExternalInputReader { + get { return inputReaderCase_ == InputReaderOneofCase.ExternalInputReader ? (global::Tensorflow.Models.ObjectDetection.Protos.ExternalInputReader) inputReader_ : null; } + set { + inputReader_ = value; + inputReaderCase_ = value == null ? 
InputReaderOneofCase.None : InputReaderOneofCase.ExternalInputReader; + } + } + + private object inputReader_; + /// Enum of possible cases for the "input_reader" oneof. + public enum InputReaderOneofCase { + None = 0, + TfRecordInputReader = 8, + ExternalInputReader = 9, + } + private InputReaderOneofCase inputReaderCase_ = InputReaderOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public InputReaderOneofCase InputReaderCase { + get { return inputReaderCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearInputReader() { + inputReaderCase_ = InputReaderOneofCase.None; + inputReader_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as InputReader); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(InputReader other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (LabelMapPath != other.LabelMapPath) return false; + if (Shuffle != other.Shuffle) return false; + if (ShuffleBufferSize != other.ShuffleBufferSize) return false; + if (FilenamesShuffleBufferSize != other.FilenamesShuffleBufferSize) return false; + if (NumEpochs != other.NumEpochs) return false; + if (Sample1OfNExamples != other.Sample1OfNExamples) return false; + if (NumReaders != other.NumReaders) return false; + if (NumParallelBatches != other.NumParallelBatches) return false; + if (NumPrefetchBatches != other.NumPrefetchBatches) return false; + if (QueueCapacity != other.QueueCapacity) return false; + if (MinAfterDequeue != other.MinAfterDequeue) return false; + if (ReadBlockLength != other.ReadBlockLength) return false; + if (PrefetchSize != other.PrefetchSize) return false; + if (NumParallelMapCalls != other.NumParallelMapCalls) return false; + if (NumAdditionalChannels != other.NumAdditionalChannels) return false; + if (NumKeypoints != other.NumKeypoints) return false; + if (MaxNumberOfBoxes != other.MaxNumberOfBoxes) return false; + if (LoadMulticlassScores != other.LoadMulticlassScores) return false; + if (LoadInstanceMasks != other.LoadInstanceMasks) return false; + if (MaskType != other.MaskType) return false; + if (UseDisplayName != other.UseDisplayName) return false; + if (!object.Equals(TfRecordInputReader, other.TfRecordInputReader)) return false; + if (!object.Equals(ExternalInputReader, other.ExternalInputReader)) return false; + if (InputReaderCase != other.InputReaderCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (LabelMapPath.Length != 0) hash ^= LabelMapPath.GetHashCode(); + if (Shuffle != false) hash ^= Shuffle.GetHashCode(); + if (ShuffleBufferSize != 0) hash ^= ShuffleBufferSize.GetHashCode(); + if (FilenamesShuffleBufferSize != 0) hash ^= FilenamesShuffleBufferSize.GetHashCode(); + if (NumEpochs != 0) hash ^= NumEpochs.GetHashCode(); + if (Sample1OfNExamples != 0) hash ^= Sample1OfNExamples.GetHashCode(); + if (NumReaders != 0) hash ^= NumReaders.GetHashCode(); + if (NumParallelBatches != 0) hash ^= NumParallelBatches.GetHashCode(); + if (NumPrefetchBatches != 0) hash ^= NumPrefetchBatches.GetHashCode(); + if (QueueCapacity != 0) hash ^= QueueCapacity.GetHashCode(); + if (MinAfterDequeue != 0) 
hash ^= MinAfterDequeue.GetHashCode(); + if (ReadBlockLength != 0) hash ^= ReadBlockLength.GetHashCode(); + if (PrefetchSize != 0) hash ^= PrefetchSize.GetHashCode(); + if (NumParallelMapCalls != 0) hash ^= NumParallelMapCalls.GetHashCode(); + if (NumAdditionalChannels != 0) hash ^= NumAdditionalChannels.GetHashCode(); + if (NumKeypoints != 0) hash ^= NumKeypoints.GetHashCode(); + if (MaxNumberOfBoxes != 0) hash ^= MaxNumberOfBoxes.GetHashCode(); + if (LoadMulticlassScores != false) hash ^= LoadMulticlassScores.GetHashCode(); + if (LoadInstanceMasks != false) hash ^= LoadInstanceMasks.GetHashCode(); + if (MaskType != 0) hash ^= MaskType.GetHashCode(); + if (UseDisplayName != false) hash ^= UseDisplayName.GetHashCode(); + if (inputReaderCase_ == InputReaderOneofCase.TfRecordInputReader) hash ^= TfRecordInputReader.GetHashCode(); + if (inputReaderCase_ == InputReaderOneofCase.ExternalInputReader) hash ^= ExternalInputReader.GetHashCode(); + hash ^= (int) inputReaderCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (LabelMapPath.Length != 0) { + output.WriteRawTag(10); + output.WriteString(LabelMapPath); + } + if (Shuffle != false) { + output.WriteRawTag(16); + output.WriteBool(Shuffle); + } + if (QueueCapacity != 0) { + output.WriteRawTag(24); + output.WriteUInt32(QueueCapacity); + } + if (MinAfterDequeue != 0) { + output.WriteRawTag(32); + output.WriteUInt32(MinAfterDequeue); + } + if (NumEpochs != 0) { + output.WriteRawTag(40); + output.WriteUInt32(NumEpochs); + } + if (NumReaders != 0) { + output.WriteRawTag(48); + output.WriteUInt32(NumReaders); + } + if (LoadInstanceMasks != false) { + output.WriteRawTag(56); + output.WriteBool(LoadInstanceMasks); + } + if (inputReaderCase_ == InputReaderOneofCase.TfRecordInputReader) { + output.WriteRawTag(66); + output.WriteMessage(TfRecordInputReader); + } + if (inputReaderCase_ == InputReaderOneofCase.ExternalInputReader) { + output.WriteRawTag(74); + output.WriteMessage(ExternalInputReader); + } + if (MaskType != 0) { + output.WriteRawTag(80); + output.WriteEnum((int) MaskType); + } + if (ShuffleBufferSize != 0) { + output.WriteRawTag(88); + output.WriteUInt32(ShuffleBufferSize); + } + if (FilenamesShuffleBufferSize != 0) { + output.WriteRawTag(96); + output.WriteUInt32(FilenamesShuffleBufferSize); + } + if (PrefetchSize != 0) { + output.WriteRawTag(104); + output.WriteUInt32(PrefetchSize); + } + if (NumParallelMapCalls != 0) { + output.WriteRawTag(112); + output.WriteUInt32(NumParallelMapCalls); + } + if (ReadBlockLength != 0) { + output.WriteRawTag(120); + output.WriteUInt32(ReadBlockLength); + } + if (NumKeypoints != 0) { + output.WriteRawTag(128, 1); + output.WriteUInt32(NumKeypoints); + } + if (UseDisplayName != false) { + output.WriteRawTag(136, 1); + output.WriteBool(UseDisplayName); + } + if (NumAdditionalChannels != 0) { + output.WriteRawTag(144, 1); + output.WriteInt32(NumAdditionalChannels); + } + if (NumParallelBatches != 0) { + output.WriteRawTag(152, 1); + output.WriteUInt32(NumParallelBatches); + } + if (NumPrefetchBatches != 0) { + output.WriteRawTag(160, 1); + output.WriteInt32(NumPrefetchBatches); + } + if (MaxNumberOfBoxes != 0) { + output.WriteRawTag(168, 1); + output.WriteInt32(MaxNumberOfBoxes); + } + 
if (Sample1OfNExamples != 0) { + output.WriteRawTag(176, 1); + output.WriteUInt32(Sample1OfNExamples); + } + if (Name.Length != 0) { + output.WriteRawTag(186, 1); + output.WriteString(Name); + } + if (LoadMulticlassScores != false) { + output.WriteRawTag(192, 1); + output.WriteBool(LoadMulticlassScores); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (LabelMapPath.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(LabelMapPath); + } + if (Shuffle != false) { + size += 1 + 1; + } + if (ShuffleBufferSize != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(ShuffleBufferSize); + } + if (FilenamesShuffleBufferSize != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(FilenamesShuffleBufferSize); + } + if (NumEpochs != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(NumEpochs); + } + if (Sample1OfNExamples != 0) { + size += 2 + pb::CodedOutputStream.ComputeUInt32Size(Sample1OfNExamples); + } + if (NumReaders != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(NumReaders); + } + if (NumParallelBatches != 0) { + size += 2 + pb::CodedOutputStream.ComputeUInt32Size(NumParallelBatches); + } + if (NumPrefetchBatches != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(NumPrefetchBatches); + } + if (QueueCapacity != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(QueueCapacity); + } + if (MinAfterDequeue != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(MinAfterDequeue); + } + if (ReadBlockLength != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(ReadBlockLength); + } + if (PrefetchSize != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(PrefetchSize); + } + if (NumParallelMapCalls != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(NumParallelMapCalls); + } + if (NumAdditionalChannels != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(NumAdditionalChannels); + } + if (NumKeypoints != 0) { + size += 2 + pb::CodedOutputStream.ComputeUInt32Size(NumKeypoints); + } + if (MaxNumberOfBoxes != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(MaxNumberOfBoxes); + } + if (LoadMulticlassScores != false) { + size += 2 + 1; + } + if (LoadInstanceMasks != false) { + size += 1 + 1; + } + if (MaskType != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) MaskType); + } + if (UseDisplayName != false) { + size += 2 + 1; + } + if (inputReaderCase_ == InputReaderOneofCase.TfRecordInputReader) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(TfRecordInputReader); + } + if (inputReaderCase_ == InputReaderOneofCase.ExternalInputReader) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ExternalInputReader); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(InputReader other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.LabelMapPath.Length != 0) { + LabelMapPath = other.LabelMapPath; + } + if (other.Shuffle != false) { + Shuffle = other.Shuffle; + } + if (other.ShuffleBufferSize != 0) { + ShuffleBufferSize = other.ShuffleBufferSize; + } + if (other.FilenamesShuffleBufferSize != 0) { + FilenamesShuffleBufferSize = 
other.FilenamesShuffleBufferSize; + } + if (other.NumEpochs != 0) { + NumEpochs = other.NumEpochs; + } + if (other.Sample1OfNExamples != 0) { + Sample1OfNExamples = other.Sample1OfNExamples; + } + if (other.NumReaders != 0) { + NumReaders = other.NumReaders; + } + if (other.NumParallelBatches != 0) { + NumParallelBatches = other.NumParallelBatches; + } + if (other.NumPrefetchBatches != 0) { + NumPrefetchBatches = other.NumPrefetchBatches; + } + if (other.QueueCapacity != 0) { + QueueCapacity = other.QueueCapacity; + } + if (other.MinAfterDequeue != 0) { + MinAfterDequeue = other.MinAfterDequeue; + } + if (other.ReadBlockLength != 0) { + ReadBlockLength = other.ReadBlockLength; + } + if (other.PrefetchSize != 0) { + PrefetchSize = other.PrefetchSize; + } + if (other.NumParallelMapCalls != 0) { + NumParallelMapCalls = other.NumParallelMapCalls; + } + if (other.NumAdditionalChannels != 0) { + NumAdditionalChannels = other.NumAdditionalChannels; + } + if (other.NumKeypoints != 0) { + NumKeypoints = other.NumKeypoints; + } + if (other.MaxNumberOfBoxes != 0) { + MaxNumberOfBoxes = other.MaxNumberOfBoxes; + } + if (other.LoadMulticlassScores != false) { + LoadMulticlassScores = other.LoadMulticlassScores; + } + if (other.LoadInstanceMasks != false) { + LoadInstanceMasks = other.LoadInstanceMasks; + } + if (other.MaskType != 0) { + MaskType = other.MaskType; + } + if (other.UseDisplayName != false) { + UseDisplayName = other.UseDisplayName; + } + switch (other.InputReaderCase) { + case InputReaderOneofCase.TfRecordInputReader: + if (TfRecordInputReader == null) { + TfRecordInputReader = new global::Tensorflow.Models.ObjectDetection.Protos.TFRecordInputReader(); + } + TfRecordInputReader.MergeFrom(other.TfRecordInputReader); + break; + case InputReaderOneofCase.ExternalInputReader: + if (ExternalInputReader == null) { + ExternalInputReader = new global::Tensorflow.Models.ObjectDetection.Protos.ExternalInputReader(); + } + ExternalInputReader.MergeFrom(other.ExternalInputReader); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + LabelMapPath = input.ReadString(); + break; + } + case 16: { + Shuffle = input.ReadBool(); + break; + } + case 24: { + QueueCapacity = input.ReadUInt32(); + break; + } + case 32: { + MinAfterDequeue = input.ReadUInt32(); + break; + } + case 40: { + NumEpochs = input.ReadUInt32(); + break; + } + case 48: { + NumReaders = input.ReadUInt32(); + break; + } + case 56: { + LoadInstanceMasks = input.ReadBool(); + break; + } + case 66: { + global::Tensorflow.Models.ObjectDetection.Protos.TFRecordInputReader subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.TFRecordInputReader(); + if (inputReaderCase_ == InputReaderOneofCase.TfRecordInputReader) { + subBuilder.MergeFrom(TfRecordInputReader); + } + input.ReadMessage(subBuilder); + TfRecordInputReader = subBuilder; + break; + } + case 74: { + global::Tensorflow.Models.ObjectDetection.Protos.ExternalInputReader subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.ExternalInputReader(); + if (inputReaderCase_ == InputReaderOneofCase.ExternalInputReader) { + subBuilder.MergeFrom(ExternalInputReader); + } + input.ReadMessage(subBuilder); + ExternalInputReader = 
subBuilder; + break; + } + case 80: { + maskType_ = (global::Tensorflow.Models.ObjectDetection.Protos.InstanceMaskType) input.ReadEnum(); + break; + } + case 88: { + ShuffleBufferSize = input.ReadUInt32(); + break; + } + case 96: { + FilenamesShuffleBufferSize = input.ReadUInt32(); + break; + } + case 104: { + PrefetchSize = input.ReadUInt32(); + break; + } + case 112: { + NumParallelMapCalls = input.ReadUInt32(); + break; + } + case 120: { + ReadBlockLength = input.ReadUInt32(); + break; + } + case 128: { + NumKeypoints = input.ReadUInt32(); + break; + } + case 136: { + UseDisplayName = input.ReadBool(); + break; + } + case 144: { + NumAdditionalChannels = input.ReadInt32(); + break; + } + case 152: { + NumParallelBatches = input.ReadUInt32(); + break; + } + case 160: { + NumPrefetchBatches = input.ReadInt32(); + break; + } + case 168: { + MaxNumberOfBoxes = input.ReadInt32(); + break; + } + case 176: { + Sample1OfNExamples = input.ReadUInt32(); + break; + } + case 186: { + Name = input.ReadString(); + break; + } + case 192: { + LoadMulticlassScores = input.ReadBool(); + break; + } + } + } + } + + } + + /// + /// An input reader that reads TF Example protos from local TFRecord files. + /// + public sealed partial class TFRecordInputReader : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TFRecordInputReader()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.InputReaderReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TFRecordInputReader() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TFRecordInputReader(TFRecordInputReader other) : this() { + inputPath_ = other.inputPath_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TFRecordInputReader Clone() { + return new TFRecordInputReader(this); + } + + /// Field number for the "input_path" field. + public const int InputPathFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_inputPath_codec + = pb::FieldCodec.ForString(10); + private readonly pbc::RepeatedField inputPath_ = new pbc::RepeatedField(); + /// + /// Path(s) to `TFRecordFile`s. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField InputPath { + get { return inputPath_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as TFRecordInputReader); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(TFRecordInputReader other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!inputPath_.Equals(other.inputPath_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= inputPath_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + inputPath_.WriteTo(output, _repeated_inputPath_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += inputPath_.CalculateSize(_repeated_inputPath_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(TFRecordInputReader other) { + if (other == null) { + return; + } + inputPath_.Add(other.inputPath_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + inputPath_.AddEntriesFrom(input, _repeated_inputPath_codec); + break; + } + } + } + } + + } + + /// + /// An externally defined input reader. Users may define an extension to this + /// proto to interface their own input readers. 
+ /// + public sealed partial class ExternalInputReader : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ExternalInputReader()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.InputReaderReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ExternalInputReader() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ExternalInputReader(ExternalInputReader other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ExternalInputReader Clone() { + return new ExternalInputReader(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ExternalInputReader); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ExternalInputReader other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ExternalInputReader other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/KeypointBoxCoder.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/KeypointBoxCoder.cs new file mode 100644 index 00000000..bf1f76bd --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/KeypointBoxCoder.cs @@ -0,0 +1,300 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
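For readers skimming the generated input_reader file above: the `InputReader` message is an ordinary Google.Protobuf config object, configured through plain property assignment. A minimal, illustrative usage sketch follows; it is not part of the generated sources, and the dataset name and file paths are hypothetical.

```cs
// Illustrative only - not part of the generated file. Builds an input_reader
// config for TFRecord data and round-trips it through the protobuf wire format.
using System;
using Google.Protobuf;
using Tensorflow.Models.ObjectDetection.Protos;

var reader = new InputReader
{
    Name = "coco_train",                               // hypothetical dataset name
    LabelMapPath = "data/mscoco_label_map.pbtxt",      // hypothetical path
    Shuffle = true,
    NumEpochs = 0,                                     // 0 = reuse the data source indefinitely
    MaskType = InstanceMaskType.PngMasks,
    TfRecordInputReader = new TFRecordInputReader()    // selects the input_reader oneof case
};
reader.TfRecordInputReader.InputPath.Add("data/coco_train.record");

byte[] wire = reader.ToByteArray();                    // Google.Protobuf extension method
InputReader parsed = InputReader.Parser.ParseFrom(wire);
Console.WriteLine(parsed.InputReaderCase);             // -> TfRecordInputReader
```

Note that assigning `TfRecordInputReader` (or `ExternalInputReader`) flips `InputReaderCase`, so only the most recently set reader survives serialization, matching the oneof semantics of the proto.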
+// source: object_detection/protos/keypoint_box_coder.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/keypoint_box_coder.proto + public static partial class KeypointBoxCoderReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/keypoint_box_coder.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static KeypointBoxCoderReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjBvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9rZXlwb2ludF9ib3hfY29kZXIu", + "cHJvdG8SF29iamVjdF9kZXRlY3Rpb24ucHJvdG9zInYKEEtleXBvaW50Qm94", + "Q29kZXISFQoNbnVtX2tleXBvaW50cxgBIAEoBRIPCgd5X3NjYWxlGAIgASgC", + "Eg8KB3hfc2NhbGUYAyABKAISFAoMaGVpZ2h0X3NjYWxlGAQgASgCEhMKC3dp", + "ZHRoX3NjYWxlGAUgASgCYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.KeypointBoxCoder), global::Tensorflow.Models.ObjectDetection.Protos.KeypointBoxCoder.Parser, new[]{ "NumKeypoints", "YScale", "XScale", "HeightScale", "WidthScale" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration proto for KeypointBoxCoder. See + /// box_coders/keypoint_box_coder.py for details. + /// + public sealed partial class KeypointBoxCoder : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new KeypointBoxCoder()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.KeypointBoxCoderReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public KeypointBoxCoder() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public KeypointBoxCoder(KeypointBoxCoder other) : this() { + numKeypoints_ = other.numKeypoints_; + yScale_ = other.yScale_; + xScale_ = other.xScale_; + heightScale_ = other.heightScale_; + widthScale_ = other.widthScale_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public KeypointBoxCoder Clone() { + return new KeypointBoxCoder(this); + } + + /// Field number for the "num_keypoints" field. + public const int NumKeypointsFieldNumber = 1; + private int numKeypoints_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumKeypoints { + get { return numKeypoints_; } + set { + numKeypoints_ = value; + } + } + + /// Field number for the "y_scale" field. 
+ public const int YScaleFieldNumber = 2; + private float yScale_; + /// + /// Scale factor for anchor encoded box center and keypoints. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float YScale { + get { return yScale_; } + set { + yScale_ = value; + } + } + + /// Field number for the "x_scale" field. + public const int XScaleFieldNumber = 3; + private float xScale_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float XScale { + get { return xScale_; } + set { + xScale_ = value; + } + } + + /// Field number for the "height_scale" field. + public const int HeightScaleFieldNumber = 4; + private float heightScale_; + /// + /// Scale factor for anchor encoded box height. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float HeightScale { + get { return heightScale_; } + set { + heightScale_ = value; + } + } + + /// Field number for the "width_scale" field. + public const int WidthScaleFieldNumber = 5; + private float widthScale_; + /// + /// Scale factor for anchor encoded box width. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float WidthScale { + get { return widthScale_; } + set { + widthScale_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as KeypointBoxCoder); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(KeypointBoxCoder other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (NumKeypoints != other.NumKeypoints) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(YScale, other.YScale)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(XScale, other.XScale)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(HeightScale, other.HeightScale)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(WidthScale, other.WidthScale)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (NumKeypoints != 0) hash ^= NumKeypoints.GetHashCode(); + if (YScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(YScale); + if (XScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(XScale); + if (HeightScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(HeightScale); + if (WidthScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(WidthScale); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (NumKeypoints != 0) { + output.WriteRawTag(8); + output.WriteInt32(NumKeypoints); + } + if (YScale != 0F) { + output.WriteRawTag(21); + output.WriteFloat(YScale); + } + if (XScale != 0F) { + output.WriteRawTag(29); + output.WriteFloat(XScale); + } + if (HeightScale != 0F) { + output.WriteRawTag(37); + output.WriteFloat(HeightScale); + } + if 
(WidthScale != 0F) { + output.WriteRawTag(45); + output.WriteFloat(WidthScale); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (NumKeypoints != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumKeypoints); + } + if (YScale != 0F) { + size += 1 + 4; + } + if (XScale != 0F) { + size += 1 + 4; + } + if (HeightScale != 0F) { + size += 1 + 4; + } + if (WidthScale != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(KeypointBoxCoder other) { + if (other == null) { + return; + } + if (other.NumKeypoints != 0) { + NumKeypoints = other.NumKeypoints; + } + if (other.YScale != 0F) { + YScale = other.YScale; + } + if (other.XScale != 0F) { + XScale = other.XScale; + } + if (other.HeightScale != 0F) { + HeightScale = other.HeightScale; + } + if (other.WidthScale != 0F) { + WidthScale = other.WidthScale; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + NumKeypoints = input.ReadInt32(); + break; + } + case 21: { + YScale = input.ReadFloat(); + break; + } + case 29: { + XScale = input.ReadFloat(); + break; + } + case 37: { + HeightScale = input.ReadFloat(); + break; + } + case 45: { + WidthScale = input.ReadFloat(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/Losses.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/Losses.cs new file mode 100644 index 00000000..e34ed91b --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/Losses.cs @@ -0,0 +1,3009 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
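As with the other generated configs, `KeypointBoxCoder` is populated through simple property assignment. A small sketch is shown below; the keypoint count and scale values are assumptions for illustration, not library defaults, and the snippet is not part of the generated file.

```cs
// Illustrative only - not part of the generated file. Field values are
// assumptions, not the object_detection defaults.
using System;
using Tensorflow.Models.ObjectDetection.Protos;

var coder = new KeypointBoxCoder
{
    NumKeypoints = 17,    // e.g. a COCO-style keypoint count (assumption)
    YScale = 10f,
    XScale = 10f,
    HeightScale = 5f,
    WidthScale = 5f
};

// proto3 skips zero-valued fields on the wire, so only the fields set above
// contribute to the serialized size.
Console.WriteLine(coder.CalculateSize());
Console.WriteLine(coder.Clone().Equals(coder));   // -> True
```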
+// source: object_detection/protos/losses.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/losses.proto + public static partial class LossesReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/losses.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static LossesReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiRvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9sb3NzZXMucHJvdG8SF29iamVj", + "dF9kZXRlY3Rpb24ucHJvdG9zIukFCgRMb3NzEkQKEWxvY2FsaXphdGlvbl9s", + "b3NzGAEgASgLMikub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuTG9jYWxpemF0", + "aW9uTG9zcxJIChNjbGFzc2lmaWNhdGlvbl9sb3NzGAIgASgLMisub2JqZWN0", + "X2RldGVjdGlvbi5wcm90b3MuQ2xhc3NpZmljYXRpb25Mb3NzEkUKEmhhcmRf", + "ZXhhbXBsZV9taW5lchgDIAEoCzIpLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9z", + "LkhhcmRFeGFtcGxlTWluZXISHQoVY2xhc3NpZmljYXRpb25fd2VpZ2h0GAQg", + "ASgCEhsKE2xvY2FsaXphdGlvbl93ZWlnaHQYBSABKAISTQoWcmFuZG9tX2V4", + "YW1wbGVfc2FtcGxlchgGIAEoCzItLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9z", + "LlJhbmRvbUV4YW1wbGVTYW1wbGVyEkkKEWVxdWFsaXphdGlvbl9sb3NzGAcg", + "ASgLMi4ub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuTG9zcy5FcXVhbGl6YXRp", + "b25Mb3NzElAKFWV4cGVjdGVkX2xvc3Nfd2VpZ2h0cxgSIAEoDjIxLm9iamVj", + "dF9kZXRlY3Rpb24ucHJvdG9zLkxvc3MuRXhwZWN0ZWRMb3NzV2VpZ2h0cxIg", + "ChhtaW5fbnVtX25lZ2F0aXZlX3NhbXBsZXMYEyABKAISJwofZGVzaXJlZF9u", + "ZWdhdGl2ZV9zYW1wbGluZ19yYXRpbxgUIAEoAho8ChBFcXVhbGl6YXRpb25M", + "b3NzEg4KBndlaWdodBgBIAEoAhIYChBleGNsdWRlX3ByZWZpeGVzGAIgAygJ", + "IlkKE0V4cGVjdGVkTG9zc1dlaWdodHMSCAoETk9ORRAAEhUKEUVYUEVDVEVE", + "X1NBTVBMSU5HEAESIQodUkVXRUlHSFRJTkdfVU5NQVRDSEVEX0FOQ0hPUlMQ", + "AiKaAgoQTG9jYWxpemF0aW9uTG9zcxJKCgt3ZWlnaHRlZF9sMhgBIAEoCzIz", + "Lm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLldlaWdodGVkTDJMb2NhbGl6YXRp", + "b25Mb3NzSAASVwoSd2VpZ2h0ZWRfc21vb3RoX2wxGAIgASgLMjkub2JqZWN0", + "X2RldGVjdGlvbi5wcm90b3MuV2VpZ2h0ZWRTbW9vdGhMMUxvY2FsaXphdGlv", + "bkxvc3NIABJMCgx3ZWlnaHRlZF9pb3UYAyABKAsyNC5vYmplY3RfZGV0ZWN0", + "aW9uLnByb3Rvcy5XZWlnaHRlZElPVUxvY2FsaXphdGlvbkxvc3NIAEITChFs", + "b2NhbGl6YXRpb25fbG9zcyI3ChpXZWlnaHRlZEwyTG9jYWxpemF0aW9uTG9z", + "cxIZChFhbmNob3J3aXNlX291dHB1dBgBIAEoCCJMCiBXZWlnaHRlZFNtb290", + "aEwxTG9jYWxpemF0aW9uTG9zcxIZChFhbmNob3J3aXNlX291dHB1dBgBIAEo", + "CBINCgVkZWx0YRgCIAEoAiIdChtXZWlnaHRlZElPVUxvY2FsaXphdGlvbkxv", + "c3MiggQKEkNsYXNzaWZpY2F0aW9uTG9zcxJWChB3ZWlnaHRlZF9zaWdtb2lk", + "GAEgASgLMjoub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuV2VpZ2h0ZWRTaWdt", + "b2lkQ2xhc3NpZmljYXRpb25Mb3NzSAASVgoQd2VpZ2h0ZWRfc29mdG1heBgC", + "IAEoCzI6Lm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLldlaWdodGVkU29mdG1h", + "eENsYXNzaWZpY2F0aW9uTG9zc0gAEmoKF3dlaWdodGVkX2xvZ2l0c19zb2Z0", + "bWF4GAUgASgLMkcub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuV2VpZ2h0ZWRT", + "b2Z0bWF4Q2xhc3NpZmljYXRpb25BZ2FpbnN0TG9naXRzTG9zc0gAEl4KFGJv", + "b3RzdHJhcHBlZF9zaWdtb2lkGAMgASgLMj4ub2JqZWN0X2RldGVjdGlvbi5w", + "cm90b3MuQm9vdHN0cmFwcGVkU2lnbW9pZENsYXNzaWZpY2F0aW9uTG9zc0gA", + "ElkKFndlaWdodGVkX3NpZ21vaWRfZm9jYWwYBCABKAsyNy5vYmplY3RfZGV0", + "ZWN0aW9uLnByb3Rvcy5TaWdtb2lkRm9jYWxDbGFzc2lmaWNhdGlvbkxvc3NI", + "AEIVChNjbGFzc2lmaWNhdGlvbl9sb3NzIj4KIVdlaWdodGVkU2lnbW9pZENs", + "YXNzaWZpY2F0aW9uTG9zcxIZChFhbmNob3J3aXNlX291dHB1dBgBIAEoCCJZ", + 
"Ch5TaWdtb2lkRm9jYWxDbGFzc2lmaWNhdGlvbkxvc3MSGQoRYW5jaG9yd2lz", + "ZV9vdXRwdXQYASABKAgSDQoFZ2FtbWEYAiABKAISDQoFYWxwaGEYAyABKAIi", + "UwohV2VpZ2h0ZWRTb2Z0bWF4Q2xhc3NpZmljYXRpb25Mb3NzEhkKEWFuY2hv", + "cndpc2Vfb3V0cHV0GAEgASgIEhMKC2xvZ2l0X3NjYWxlGAIgASgCImAKLldl", + "aWdodGVkU29mdG1heENsYXNzaWZpY2F0aW9uQWdhaW5zdExvZ2l0c0xvc3MS", + "GQoRYW5jaG9yd2lzZV9vdXRwdXQYASABKAgSEwoLbG9naXRfc2NhbGUYAiAB", + "KAIiaQolQm9vdHN0cmFwcGVkU2lnbW9pZENsYXNzaWZpY2F0aW9uTG9zcxIN", + "CgVhbHBoYRgBIAEoAhIWCg5oYXJkX2Jvb3RzdHJhcBgCIAEoCBIZChFhbmNo", + "b3J3aXNlX291dHB1dBgDIAEoCCKMAgoQSGFyZEV4YW1wbGVNaW5lchIZChFu", + "dW1faGFyZF9leGFtcGxlcxgBIAEoBRIVCg1pb3VfdGhyZXNob2xkGAIgASgC", + "EkUKCWxvc3NfdHlwZRgDIAEoDjIyLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9z", + "LkhhcmRFeGFtcGxlTWluZXIuTG9zc1R5cGUSIgoabWF4X25lZ2F0aXZlc19w", + "ZXJfcG9zaXRpdmUYBCABKAUSHwoXbWluX25lZ2F0aXZlc19wZXJfaW1hZ2UY", + "BSABKAUiOgoITG9zc1R5cGUSCAoEQk9USBAAEhIKDkNMQVNTSUZJQ0FUSU9O", + "EAESEAoMTE9DQUxJWkFUSU9OEAIiOAoUUmFuZG9tRXhhbXBsZVNhbXBsZXIS", + "IAoYcG9zaXRpdmVfc2FtcGxlX2ZyYWN0aW9uGAEgASgCYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.Loss), global::Tensorflow.Models.ObjectDetection.Protos.Loss.Parser, new[]{ "LocalizationLoss", "ClassificationLoss", "HardExampleMiner", "ClassificationWeight", "LocalizationWeight", "RandomExampleSampler", "EqualizationLoss", "ExpectedLossWeights", "MinNumNegativeSamples", "DesiredNegativeSamplingRatio" }, null, new[]{ typeof(global::Tensorflow.Models.ObjectDetection.Protos.Loss.Types.ExpectedLossWeights) }, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.Loss.Types.EqualizationLoss), global::Tensorflow.Models.ObjectDetection.Protos.Loss.Types.EqualizationLoss.Parser, new[]{ "Weight", "ExcludePrefixes" }, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.LocalizationLoss), global::Tensorflow.Models.ObjectDetection.Protos.LocalizationLoss.Parser, new[]{ "WeightedL2", "WeightedSmoothL1", "WeightedIou" }, new[]{ "LocalizationLoss" }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.WeightedL2LocalizationLoss), global::Tensorflow.Models.ObjectDetection.Protos.WeightedL2LocalizationLoss.Parser, new[]{ "AnchorwiseOutput" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.WeightedSmoothL1LocalizationLoss), global::Tensorflow.Models.ObjectDetection.Protos.WeightedSmoothL1LocalizationLoss.Parser, new[]{ "AnchorwiseOutput", "Delta" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.WeightedIOULocalizationLoss), global::Tensorflow.Models.ObjectDetection.Protos.WeightedIOULocalizationLoss.Parser, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.ClassificationLoss), global::Tensorflow.Models.ObjectDetection.Protos.ClassificationLoss.Parser, new[]{ "WeightedSigmoid", "WeightedSoftmax", "WeightedLogitsSoftmax", "BootstrappedSigmoid", "WeightedSigmoidFocal" }, new[]{ "ClassificationLoss" }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.WeightedSigmoidClassificationLoss), 
global::Tensorflow.Models.ObjectDetection.Protos.WeightedSigmoidClassificationLoss.Parser, new[]{ "AnchorwiseOutput" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.SigmoidFocalClassificationLoss), global::Tensorflow.Models.ObjectDetection.Protos.SigmoidFocalClassificationLoss.Parser, new[]{ "AnchorwiseOutput", "Gamma", "Alpha" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.WeightedSoftmaxClassificationLoss), global::Tensorflow.Models.ObjectDetection.Protos.WeightedSoftmaxClassificationLoss.Parser, new[]{ "AnchorwiseOutput", "LogitScale" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.WeightedSoftmaxClassificationAgainstLogitsLoss), global::Tensorflow.Models.ObjectDetection.Protos.WeightedSoftmaxClassificationAgainstLogitsLoss.Parser, new[]{ "AnchorwiseOutput", "LogitScale" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.BootstrappedSigmoidClassificationLoss), global::Tensorflow.Models.ObjectDetection.Protos.BootstrappedSigmoidClassificationLoss.Parser, new[]{ "Alpha", "HardBootstrap", "AnchorwiseOutput" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.HardExampleMiner), global::Tensorflow.Models.ObjectDetection.Protos.HardExampleMiner.Parser, new[]{ "NumHardExamples", "IouThreshold", "LossType", "MaxNegativesPerPositive", "MinNegativesPerImage" }, null, new[]{ typeof(global::Tensorflow.Models.ObjectDetection.Protos.HardExampleMiner.Types.LossType) }, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomExampleSampler), global::Tensorflow.Models.ObjectDetection.Protos.RandomExampleSampler.Parser, new[]{ "PositiveSampleFraction" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Message for configuring the localization loss, classification loss and hard + /// example miner used for training object detection models. See core/losses.py + /// for details + /// + public sealed partial class Loss : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Loss()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.LossesReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Loss() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Loss(Loss other) : this() { + localizationLoss_ = other.localizationLoss_ != null ? other.localizationLoss_.Clone() : null; + classificationLoss_ = other.classificationLoss_ != null ? other.classificationLoss_.Clone() : null; + hardExampleMiner_ = other.hardExampleMiner_ != null ? 
other.hardExampleMiner_.Clone() : null; + classificationWeight_ = other.classificationWeight_; + localizationWeight_ = other.localizationWeight_; + randomExampleSampler_ = other.randomExampleSampler_ != null ? other.randomExampleSampler_.Clone() : null; + equalizationLoss_ = other.equalizationLoss_ != null ? other.equalizationLoss_.Clone() : null; + expectedLossWeights_ = other.expectedLossWeights_; + minNumNegativeSamples_ = other.minNumNegativeSamples_; + desiredNegativeSamplingRatio_ = other.desiredNegativeSamplingRatio_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Loss Clone() { + return new Loss(this); + } + + /// Field number for the "localization_loss" field. + public const int LocalizationLossFieldNumber = 1; + private global::Tensorflow.Models.ObjectDetection.Protos.LocalizationLoss localizationLoss_; + /// + /// Localization loss to use. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.LocalizationLoss LocalizationLoss { + get { return localizationLoss_; } + set { + localizationLoss_ = value; + } + } + + /// Field number for the "classification_loss" field. + public const int ClassificationLossFieldNumber = 2; + private global::Tensorflow.Models.ObjectDetection.Protos.ClassificationLoss classificationLoss_; + /// + /// Classification loss to use. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ClassificationLoss ClassificationLoss { + get { return classificationLoss_; } + set { + classificationLoss_ = value; + } + } + + /// Field number for the "hard_example_miner" field. + public const int HardExampleMinerFieldNumber = 3; + private global::Tensorflow.Models.ObjectDetection.Protos.HardExampleMiner hardExampleMiner_; + /// + /// If not left to default, applies hard example mining. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.HardExampleMiner HardExampleMiner { + get { return hardExampleMiner_; } + set { + hardExampleMiner_ = value; + } + } + + /// Field number for the "classification_weight" field. + public const int ClassificationWeightFieldNumber = 4; + private float classificationWeight_; + /// + /// Classification loss weight. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float ClassificationWeight { + get { return classificationWeight_; } + set { + classificationWeight_ = value; + } + } + + /// Field number for the "localization_weight" field. + public const int LocalizationWeightFieldNumber = 5; + private float localizationWeight_; + /// + /// Localization loss weight. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float LocalizationWeight { + get { return localizationWeight_; } + set { + localizationWeight_ = value; + } + } + + /// Field number for the "random_example_sampler" field. + public const int RandomExampleSamplerFieldNumber = 6; + private global::Tensorflow.Models.ObjectDetection.Protos.RandomExampleSampler randomExampleSampler_; + /// + /// If not left to default, applies random example sampling. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomExampleSampler RandomExampleSampler { + get { return randomExampleSampler_; } + set { + randomExampleSampler_ = value; + } + } + + /// Field number for the "equalization_loss" field. + public const int EqualizationLossFieldNumber = 7; + private global::Tensorflow.Models.ObjectDetection.Protos.Loss.Types.EqualizationLoss equalizationLoss_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Loss.Types.EqualizationLoss EqualizationLoss { + get { return equalizationLoss_; } + set { + equalizationLoss_ = value; + } + } + + /// Field number for the "expected_loss_weights" field. + public const int ExpectedLossWeightsFieldNumber = 18; + private global::Tensorflow.Models.ObjectDetection.Protos.Loss.Types.ExpectedLossWeights expectedLossWeights_ = 0; + /// + /// Method to compute expected loss weights with respect to balanced + /// positive/negative sampling scheme. If NONE, use explicit sampling. + /// TODO(birdbrain): Move under ExpectedLossWeights. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Loss.Types.ExpectedLossWeights ExpectedLossWeights { + get { return expectedLossWeights_; } + set { + expectedLossWeights_ = value; + } + } + + /// Field number for the "min_num_negative_samples" field. + public const int MinNumNegativeSamplesFieldNumber = 19; + private float minNumNegativeSamples_; + /// + /// Minimum number of effective negative samples. + /// Only applies if expected_loss_weights is not NONE. + /// TODO(birdbrain): Move under ExpectedLossWeights. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinNumNegativeSamples { + get { return minNumNegativeSamples_; } + set { + minNumNegativeSamples_ = value; + } + } + + /// Field number for the "desired_negative_sampling_ratio" field. + public const int DesiredNegativeSamplingRatioFieldNumber = 20; + private float desiredNegativeSamplingRatio_; + /// + /// Desired number of effective negative samples per positive sample. + /// Only applies if expected_loss_weights is not NONE. + /// TODO(birdbrain): Move under ExpectedLossWeights. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float DesiredNegativeSamplingRatio { + get { return desiredNegativeSamplingRatio_; } + set { + desiredNegativeSamplingRatio_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as Loss); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(Loss other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(LocalizationLoss, other.LocalizationLoss)) return false; + if (!object.Equals(ClassificationLoss, other.ClassificationLoss)) return false; + if (!object.Equals(HardExampleMiner, other.HardExampleMiner)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(ClassificationWeight, other.ClassificationWeight)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(LocalizationWeight, other.LocalizationWeight)) return false; + if (!object.Equals(RandomExampleSampler, other.RandomExampleSampler)) return false; + if (!object.Equals(EqualizationLoss, other.EqualizationLoss)) return false; + if (ExpectedLossWeights != other.ExpectedLossWeights) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinNumNegativeSamples, other.MinNumNegativeSamples)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(DesiredNegativeSamplingRatio, other.DesiredNegativeSamplingRatio)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (localizationLoss_ != null) hash ^= LocalizationLoss.GetHashCode(); + if (classificationLoss_ != null) hash ^= ClassificationLoss.GetHashCode(); + if (hardExampleMiner_ != null) hash ^= HardExampleMiner.GetHashCode(); + if (ClassificationWeight != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(ClassificationWeight); + if (LocalizationWeight != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(LocalizationWeight); + if (randomExampleSampler_ != null) hash ^= RandomExampleSampler.GetHashCode(); + if (equalizationLoss_ != null) hash ^= EqualizationLoss.GetHashCode(); + if (ExpectedLossWeights != 0) hash ^= ExpectedLossWeights.GetHashCode(); + if (MinNumNegativeSamples != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinNumNegativeSamples); + if (DesiredNegativeSamplingRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(DesiredNegativeSamplingRatio); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (localizationLoss_ != null) { + output.WriteRawTag(10); + output.WriteMessage(LocalizationLoss); + } + if (classificationLoss_ != null) { + output.WriteRawTag(18); + output.WriteMessage(ClassificationLoss); + } + if (hardExampleMiner_ != null) { + output.WriteRawTag(26); + output.WriteMessage(HardExampleMiner); + } + if (ClassificationWeight != 0F) { + 
output.WriteRawTag(37); + output.WriteFloat(ClassificationWeight); + } + if (LocalizationWeight != 0F) { + output.WriteRawTag(45); + output.WriteFloat(LocalizationWeight); + } + if (randomExampleSampler_ != null) { + output.WriteRawTag(50); + output.WriteMessage(RandomExampleSampler); + } + if (equalizationLoss_ != null) { + output.WriteRawTag(58); + output.WriteMessage(EqualizationLoss); + } + if (ExpectedLossWeights != 0) { + output.WriteRawTag(144, 1); + output.WriteEnum((int) ExpectedLossWeights); + } + if (MinNumNegativeSamples != 0F) { + output.WriteRawTag(157, 1); + output.WriteFloat(MinNumNegativeSamples); + } + if (DesiredNegativeSamplingRatio != 0F) { + output.WriteRawTag(165, 1); + output.WriteFloat(DesiredNegativeSamplingRatio); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (localizationLoss_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(LocalizationLoss); + } + if (classificationLoss_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ClassificationLoss); + } + if (hardExampleMiner_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(HardExampleMiner); + } + if (ClassificationWeight != 0F) { + size += 1 + 4; + } + if (LocalizationWeight != 0F) { + size += 1 + 4; + } + if (randomExampleSampler_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RandomExampleSampler); + } + if (equalizationLoss_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(EqualizationLoss); + } + if (ExpectedLossWeights != 0) { + size += 2 + pb::CodedOutputStream.ComputeEnumSize((int) ExpectedLossWeights); + } + if (MinNumNegativeSamples != 0F) { + size += 2 + 4; + } + if (DesiredNegativeSamplingRatio != 0F) { + size += 2 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(Loss other) { + if (other == null) { + return; + } + if (other.localizationLoss_ != null) { + if (localizationLoss_ == null) { + localizationLoss_ = new global::Tensorflow.Models.ObjectDetection.Protos.LocalizationLoss(); + } + LocalizationLoss.MergeFrom(other.LocalizationLoss); + } + if (other.classificationLoss_ != null) { + if (classificationLoss_ == null) { + classificationLoss_ = new global::Tensorflow.Models.ObjectDetection.Protos.ClassificationLoss(); + } + ClassificationLoss.MergeFrom(other.ClassificationLoss); + } + if (other.hardExampleMiner_ != null) { + if (hardExampleMiner_ == null) { + hardExampleMiner_ = new global::Tensorflow.Models.ObjectDetection.Protos.HardExampleMiner(); + } + HardExampleMiner.MergeFrom(other.HardExampleMiner); + } + if (other.ClassificationWeight != 0F) { + ClassificationWeight = other.ClassificationWeight; + } + if (other.LocalizationWeight != 0F) { + LocalizationWeight = other.LocalizationWeight; + } + if (other.randomExampleSampler_ != null) { + if (randomExampleSampler_ == null) { + randomExampleSampler_ = new global::Tensorflow.Models.ObjectDetection.Protos.RandomExampleSampler(); + } + RandomExampleSampler.MergeFrom(other.RandomExampleSampler); + } + if (other.equalizationLoss_ != null) { + if (equalizationLoss_ == null) { + equalizationLoss_ = new global::Tensorflow.Models.ObjectDetection.Protos.Loss.Types.EqualizationLoss(); + } + EqualizationLoss.MergeFrom(other.EqualizationLoss); + } + if (other.ExpectedLossWeights != 0) { + 
ExpectedLossWeights = other.ExpectedLossWeights; + } + if (other.MinNumNegativeSamples != 0F) { + MinNumNegativeSamples = other.MinNumNegativeSamples; + } + if (other.DesiredNegativeSamplingRatio != 0F) { + DesiredNegativeSamplingRatio = other.DesiredNegativeSamplingRatio; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (localizationLoss_ == null) { + localizationLoss_ = new global::Tensorflow.Models.ObjectDetection.Protos.LocalizationLoss(); + } + input.ReadMessage(localizationLoss_); + break; + } + case 18: { + if (classificationLoss_ == null) { + classificationLoss_ = new global::Tensorflow.Models.ObjectDetection.Protos.ClassificationLoss(); + } + input.ReadMessage(classificationLoss_); + break; + } + case 26: { + if (hardExampleMiner_ == null) { + hardExampleMiner_ = new global::Tensorflow.Models.ObjectDetection.Protos.HardExampleMiner(); + } + input.ReadMessage(hardExampleMiner_); + break; + } + case 37: { + ClassificationWeight = input.ReadFloat(); + break; + } + case 45: { + LocalizationWeight = input.ReadFloat(); + break; + } + case 50: { + if (randomExampleSampler_ == null) { + randomExampleSampler_ = new global::Tensorflow.Models.ObjectDetection.Protos.RandomExampleSampler(); + } + input.ReadMessage(randomExampleSampler_); + break; + } + case 58: { + if (equalizationLoss_ == null) { + equalizationLoss_ = new global::Tensorflow.Models.ObjectDetection.Protos.Loss.Types.EqualizationLoss(); + } + input.ReadMessage(equalizationLoss_); + break; + } + case 144: { + expectedLossWeights_ = (global::Tensorflow.Models.ObjectDetection.Protos.Loss.Types.ExpectedLossWeights) input.ReadEnum(); + break; + } + case 157: { + MinNumNegativeSamples = input.ReadFloat(); + break; + } + case 165: { + DesiredNegativeSamplingRatio = input.ReadFloat(); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the Loss message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + public enum ExpectedLossWeights { + [pbr::OriginalName("NONE")] None = 0, + /// + /// Use expected_classification_loss_by_expected_sampling + /// from third_party/tensorflow_models/object_detection/utils/ops.py + /// + [pbr::OriginalName("EXPECTED_SAMPLING")] ExpectedSampling = 1, + /// + /// Use expected_classification_loss_by_reweighting_unmatched_anchors + /// from third_party/tensorflow_models/object_detection/utils/ops.py + /// + [pbr::OriginalName("REWEIGHTING_UNMATCHED_ANCHORS")] ReweightingUnmatchedAnchors = 2, + } + + /// + /// Equalization loss. 
+ /// + public sealed partial class EqualizationLoss : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new EqualizationLoss()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.Loss.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EqualizationLoss() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EqualizationLoss(EqualizationLoss other) : this() { + weight_ = other.weight_; + excludePrefixes_ = other.excludePrefixes_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public EqualizationLoss Clone() { + return new EqualizationLoss(this); + } + + /// Field number for the "weight" field. + public const int WeightFieldNumber = 1; + private float weight_; + /// + /// Weight equalization loss strength. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Weight { + get { return weight_; } + set { + weight_ = value; + } + } + + /// Field number for the "exclude_prefixes" field. + public const int ExcludePrefixesFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_excludePrefixes_codec + = pb::FieldCodec.ForString(18); + private readonly pbc::RepeatedField excludePrefixes_ = new pbc::RepeatedField(); + /// + /// When computing equalization loss, ops that start with + /// equalization_exclude_prefixes will be ignored. Only used when + /// equalization_weight > 0. 
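// Editor's sketch (not part of the generated file): configuring the nested EqualizationLoss
// message declared here. "FeatureExtractor" is a hypothetical op-name prefix used only for
// illustration; lossConfig is the instance from the earlier sketch.
var equalization = new Loss.Types.EqualizationLoss { Weight = 1.0f };
equalization.ExcludePrefixes.Add("FeatureExtractor");  // ops starting with this prefix are ignored
lossConfig.EqualizationLoss = equalization;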
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField ExcludePrefixes { + get { return excludePrefixes_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as EqualizationLoss); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(EqualizationLoss other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Weight, other.Weight)) return false; + if(!excludePrefixes_.Equals(other.excludePrefixes_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Weight != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Weight); + hash ^= excludePrefixes_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Weight != 0F) { + output.WriteRawTag(13); + output.WriteFloat(Weight); + } + excludePrefixes_.WriteTo(output, _repeated_excludePrefixes_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Weight != 0F) { + size += 1 + 4; + } + size += excludePrefixes_.CalculateSize(_repeated_excludePrefixes_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(EqualizationLoss other) { + if (other == null) { + return; + } + if (other.Weight != 0F) { + Weight = other.Weight; + } + excludePrefixes_.Add(other.excludePrefixes_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + Weight = input.ReadFloat(); + break; + } + case 18: { + excludePrefixes_.AddEntriesFrom(input, _repeated_excludePrefixes_codec); + break; + } + } + } + } + + } + + } + #endregion + + } + + /// + /// Configuration for bounding box localization loss function. 
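// Editor's sketch (not part of the generated file): the Parser/WriteTo/MergeFrom members above
// follow the standard Google.Protobuf pattern, so a populated Loss instance such as `lossConfig`
// round-trips through bytes like this (ToByteArray needs `using Google.Protobuf;`):
byte[] payload = lossConfig.ToByteArray();        // serialized via WriteTo
Loss restored = Loss.Parser.ParseFrom(payload);   // deserialized via MergeFrom
bool unchanged = lossConfig.Equals(restored);     // true: generated value equality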
+ /// + public sealed partial class LocalizationLoss : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new LocalizationLoss()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.LossesReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public LocalizationLoss() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public LocalizationLoss(LocalizationLoss other) : this() { + switch (other.LocalizationLossCase) { + case LocalizationLossOneofCase.WeightedL2: + WeightedL2 = other.WeightedL2.Clone(); + break; + case LocalizationLossOneofCase.WeightedSmoothL1: + WeightedSmoothL1 = other.WeightedSmoothL1.Clone(); + break; + case LocalizationLossOneofCase.WeightedIou: + WeightedIou = other.WeightedIou.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public LocalizationLoss Clone() { + return new LocalizationLoss(this); + } + + /// Field number for the "weighted_l2" field. + public const int WeightedL2FieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.WeightedL2LocalizationLoss WeightedL2 { + get { return localizationLossCase_ == LocalizationLossOneofCase.WeightedL2 ? (global::Tensorflow.Models.ObjectDetection.Protos.WeightedL2LocalizationLoss) localizationLoss_ : null; } + set { + localizationLoss_ = value; + localizationLossCase_ = value == null ? LocalizationLossOneofCase.None : LocalizationLossOneofCase.WeightedL2; + } + } + + /// Field number for the "weighted_smooth_l1" field. + public const int WeightedSmoothL1FieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.WeightedSmoothL1LocalizationLoss WeightedSmoothL1 { + get { return localizationLossCase_ == LocalizationLossOneofCase.WeightedSmoothL1 ? (global::Tensorflow.Models.ObjectDetection.Protos.WeightedSmoothL1LocalizationLoss) localizationLoss_ : null; } + set { + localizationLoss_ = value; + localizationLossCase_ = value == null ? LocalizationLossOneofCase.None : LocalizationLossOneofCase.WeightedSmoothL1; + } + } + + /// Field number for the "weighted_iou" field. + public const int WeightedIouFieldNumber = 3; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.WeightedIOULocalizationLoss WeightedIou { + get { return localizationLossCase_ == LocalizationLossOneofCase.WeightedIou ? (global::Tensorflow.Models.ObjectDetection.Protos.WeightedIOULocalizationLoss) localizationLoss_ : null; } + set { + localizationLoss_ = value; + localizationLossCase_ = value == null ? LocalizationLossOneofCase.None : LocalizationLossOneofCase.WeightedIou; + } + } + + private object localizationLoss_; + /// Enum of possible cases for the "localization_loss" oneof. 
+ public enum LocalizationLossOneofCase { + None = 0, + WeightedL2 = 1, + WeightedSmoothL1 = 2, + WeightedIou = 3, + } + private LocalizationLossOneofCase localizationLossCase_ = LocalizationLossOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public LocalizationLossOneofCase LocalizationLossCase { + get { return localizationLossCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearLocalizationLoss() { + localizationLossCase_ = LocalizationLossOneofCase.None; + localizationLoss_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as LocalizationLoss); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(LocalizationLoss other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(WeightedL2, other.WeightedL2)) return false; + if (!object.Equals(WeightedSmoothL1, other.WeightedSmoothL1)) return false; + if (!object.Equals(WeightedIou, other.WeightedIou)) return false; + if (LocalizationLossCase != other.LocalizationLossCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (localizationLossCase_ == LocalizationLossOneofCase.WeightedL2) hash ^= WeightedL2.GetHashCode(); + if (localizationLossCase_ == LocalizationLossOneofCase.WeightedSmoothL1) hash ^= WeightedSmoothL1.GetHashCode(); + if (localizationLossCase_ == LocalizationLossOneofCase.WeightedIou) hash ^= WeightedIou.GetHashCode(); + hash ^= (int) localizationLossCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (localizationLossCase_ == LocalizationLossOneofCase.WeightedL2) { + output.WriteRawTag(10); + output.WriteMessage(WeightedL2); + } + if (localizationLossCase_ == LocalizationLossOneofCase.WeightedSmoothL1) { + output.WriteRawTag(18); + output.WriteMessage(WeightedSmoothL1); + } + if (localizationLossCase_ == LocalizationLossOneofCase.WeightedIou) { + output.WriteRawTag(26); + output.WriteMessage(WeightedIou); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (localizationLossCase_ == LocalizationLossOneofCase.WeightedL2) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(WeightedL2); + } + if (localizationLossCase_ == LocalizationLossOneofCase.WeightedSmoothL1) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(WeightedSmoothL1); + } + if (localizationLossCase_ == LocalizationLossOneofCase.WeightedIou) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(WeightedIou); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(LocalizationLoss other) { + if (other == null) { + return; + } + switch (other.LocalizationLossCase) { + case LocalizationLossOneofCase.WeightedL2: + if (WeightedL2 == null) { + WeightedL2 
= new global::Tensorflow.Models.ObjectDetection.Protos.WeightedL2LocalizationLoss(); + } + WeightedL2.MergeFrom(other.WeightedL2); + break; + case LocalizationLossOneofCase.WeightedSmoothL1: + if (WeightedSmoothL1 == null) { + WeightedSmoothL1 = new global::Tensorflow.Models.ObjectDetection.Protos.WeightedSmoothL1LocalizationLoss(); + } + WeightedSmoothL1.MergeFrom(other.WeightedSmoothL1); + break; + case LocalizationLossOneofCase.WeightedIou: + if (WeightedIou == null) { + WeightedIou = new global::Tensorflow.Models.ObjectDetection.Protos.WeightedIOULocalizationLoss(); + } + WeightedIou.MergeFrom(other.WeightedIou); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.Models.ObjectDetection.Protos.WeightedL2LocalizationLoss subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.WeightedL2LocalizationLoss(); + if (localizationLossCase_ == LocalizationLossOneofCase.WeightedL2) { + subBuilder.MergeFrom(WeightedL2); + } + input.ReadMessage(subBuilder); + WeightedL2 = subBuilder; + break; + } + case 18: { + global::Tensorflow.Models.ObjectDetection.Protos.WeightedSmoothL1LocalizationLoss subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.WeightedSmoothL1LocalizationLoss(); + if (localizationLossCase_ == LocalizationLossOneofCase.WeightedSmoothL1) { + subBuilder.MergeFrom(WeightedSmoothL1); + } + input.ReadMessage(subBuilder); + WeightedSmoothL1 = subBuilder; + break; + } + case 26: { + global::Tensorflow.Models.ObjectDetection.Protos.WeightedIOULocalizationLoss subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.WeightedIOULocalizationLoss(); + if (localizationLossCase_ == LocalizationLossOneofCase.WeightedIou) { + subBuilder.MergeFrom(WeightedIou); + } + input.ReadMessage(subBuilder); + WeightedIou = subBuilder; + break; + } + } + } + } + + } + + /// + /// L2 location loss: 0.5 * ||weight * (a - b)|| ^ 2 + /// + public sealed partial class WeightedL2LocalizationLoss : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WeightedL2LocalizationLoss()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.LossesReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedL2LocalizationLoss() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedL2LocalizationLoss(WeightedL2LocalizationLoss other) : this() { + anchorwiseOutput_ = other.anchorwiseOutput_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedL2LocalizationLoss Clone() { + return new 
WeightedL2LocalizationLoss(this); + } + + /// Field number for the "anchorwise_output" field. + public const int AnchorwiseOutputFieldNumber = 1; + private bool anchorwiseOutput_; + /// + /// DEPRECATED, do not use. + /// Output loss per anchor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool AnchorwiseOutput { + get { return anchorwiseOutput_; } + set { + anchorwiseOutput_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as WeightedL2LocalizationLoss); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(WeightedL2LocalizationLoss other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (AnchorwiseOutput != other.AnchorwiseOutput) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (AnchorwiseOutput != false) hash ^= AnchorwiseOutput.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (AnchorwiseOutput != false) { + output.WriteRawTag(8); + output.WriteBool(AnchorwiseOutput); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (AnchorwiseOutput != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(WeightedL2LocalizationLoss other) { + if (other == null) { + return; + } + if (other.AnchorwiseOutput != false) { + AnchorwiseOutput = other.AnchorwiseOutput; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + AnchorwiseOutput = input.ReadBool(); + break; + } + } + } + } + + } + + /// + /// SmoothL1 (Huber) location loss. + /// The smooth L1_loss is defined elementwise as .5 x^2 if |x| <= delta and + /// delta * (|x|-0.5*delta) otherwise, where x is the difference between + /// predictions and target. 
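// Editor's sketch (not part of the generated file): the elementwise smooth-L1/Huber loss
// described above, written out for a single residual x = prediction - target
// (requires `using System;` for Math.Abs).
static float SmoothL1(float x, float delta)
{
    float a = Math.Abs(x);
    return a <= delta
        ? 0.5f * x * x                 // quadratic region: 0.5 * x^2
        : delta * (a - 0.5f * delta);  // linear region: delta * (|x| - 0.5 * delta)
}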
+ /// + public sealed partial class WeightedSmoothL1LocalizationLoss : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WeightedSmoothL1LocalizationLoss()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.LossesReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedSmoothL1LocalizationLoss() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedSmoothL1LocalizationLoss(WeightedSmoothL1LocalizationLoss other) : this() { + anchorwiseOutput_ = other.anchorwiseOutput_; + delta_ = other.delta_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedSmoothL1LocalizationLoss Clone() { + return new WeightedSmoothL1LocalizationLoss(this); + } + + /// Field number for the "anchorwise_output" field. + public const int AnchorwiseOutputFieldNumber = 1; + private bool anchorwiseOutput_; + /// + /// DEPRECATED, do not use. + /// Output loss per anchor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool AnchorwiseOutput { + get { return anchorwiseOutput_; } + set { + anchorwiseOutput_ = value; + } + } + + /// Field number for the "delta" field. + public const int DeltaFieldNumber = 2; + private float delta_; + /// + /// Delta value for huber loss. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Delta { + get { return delta_; } + set { + delta_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as WeightedSmoothL1LocalizationLoss); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(WeightedSmoothL1LocalizationLoss other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (AnchorwiseOutput != other.AnchorwiseOutput) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Delta, other.Delta)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (AnchorwiseOutput != false) hash ^= AnchorwiseOutput.GetHashCode(); + if (Delta != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Delta); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (AnchorwiseOutput != false) { + output.WriteRawTag(8); + output.WriteBool(AnchorwiseOutput); + } + if (Delta != 0F) { + output.WriteRawTag(21); + output.WriteFloat(Delta); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (AnchorwiseOutput != false) { + size += 1 + 1; + } + if (Delta != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(WeightedSmoothL1LocalizationLoss other) { + if (other == null) { + return; + } + if (other.AnchorwiseOutput != false) { + AnchorwiseOutput = other.AnchorwiseOutput; + } + if (other.Delta != 0F) { + Delta = other.Delta; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + AnchorwiseOutput = input.ReadBool(); + break; + } + case 21: { + Delta = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Intersection over union location loss: 1 - IOU + /// + public sealed partial class WeightedIOULocalizationLoss : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WeightedIOULocalizationLoss()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.LossesReflection.Descriptor.MessageTypes[4]; } + } + + 
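// Editor's sketch (not part of the generated file): the "1 - IOU" loss named above for one
// pair of axis-aligned boxes given as [ymin, xmin, ymax, xmax] (requires `using System;`).
static float WeightedIouLossForPair(float[] a, float[] b)
{
    float ih = Math.Max(0f, Math.Min(a[2], b[2]) - Math.Max(a[0], b[0]));
    float iw = Math.Max(0f, Math.Min(a[3], b[3]) - Math.Max(a[1], b[1]));
    float intersection = ih * iw;
    float union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - intersection;
    return union > 0f ? 1f - intersection / union : 0f;  // loss = 1 - IOU
}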
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedIOULocalizationLoss() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedIOULocalizationLoss(WeightedIOULocalizationLoss other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedIOULocalizationLoss Clone() { + return new WeightedIOULocalizationLoss(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as WeightedIOULocalizationLoss); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(WeightedIOULocalizationLoss other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(WeightedIOULocalizationLoss other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + } + + } + + /// + /// Configuration for class prediction loss function. 
+ /// + public sealed partial class ClassificationLoss : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ClassificationLoss()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.LossesReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ClassificationLoss() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ClassificationLoss(ClassificationLoss other) : this() { + switch (other.ClassificationLossCase) { + case ClassificationLossOneofCase.WeightedSigmoid: + WeightedSigmoid = other.WeightedSigmoid.Clone(); + break; + case ClassificationLossOneofCase.WeightedSoftmax: + WeightedSoftmax = other.WeightedSoftmax.Clone(); + break; + case ClassificationLossOneofCase.WeightedLogitsSoftmax: + WeightedLogitsSoftmax = other.WeightedLogitsSoftmax.Clone(); + break; + case ClassificationLossOneofCase.BootstrappedSigmoid: + BootstrappedSigmoid = other.BootstrappedSigmoid.Clone(); + break; + case ClassificationLossOneofCase.WeightedSigmoidFocal: + WeightedSigmoidFocal = other.WeightedSigmoidFocal.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ClassificationLoss Clone() { + return new ClassificationLoss(this); + } + + /// Field number for the "weighted_sigmoid" field. + public const int WeightedSigmoidFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.WeightedSigmoidClassificationLoss WeightedSigmoid { + get { return classificationLossCase_ == ClassificationLossOneofCase.WeightedSigmoid ? (global::Tensorflow.Models.ObjectDetection.Protos.WeightedSigmoidClassificationLoss) classificationLoss_ : null; } + set { + classificationLoss_ = value; + classificationLossCase_ = value == null ? ClassificationLossOneofCase.None : ClassificationLossOneofCase.WeightedSigmoid; + } + } + + /// Field number for the "weighted_softmax" field. + public const int WeightedSoftmaxFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.WeightedSoftmaxClassificationLoss WeightedSoftmax { + get { return classificationLossCase_ == ClassificationLossOneofCase.WeightedSoftmax ? (global::Tensorflow.Models.ObjectDetection.Protos.WeightedSoftmaxClassificationLoss) classificationLoss_ : null; } + set { + classificationLoss_ = value; + classificationLossCase_ = value == null ? ClassificationLossOneofCase.None : ClassificationLossOneofCase.WeightedSoftmax; + } + } + + /// Field number for the "weighted_logits_softmax" field. 
+ public const int WeightedLogitsSoftmaxFieldNumber = 5; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.WeightedSoftmaxClassificationAgainstLogitsLoss WeightedLogitsSoftmax { + get { return classificationLossCase_ == ClassificationLossOneofCase.WeightedLogitsSoftmax ? (global::Tensorflow.Models.ObjectDetection.Protos.WeightedSoftmaxClassificationAgainstLogitsLoss) classificationLoss_ : null; } + set { + classificationLoss_ = value; + classificationLossCase_ = value == null ? ClassificationLossOneofCase.None : ClassificationLossOneofCase.WeightedLogitsSoftmax; + } + } + + /// Field number for the "bootstrapped_sigmoid" field. + public const int BootstrappedSigmoidFieldNumber = 3; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.BootstrappedSigmoidClassificationLoss BootstrappedSigmoid { + get { return classificationLossCase_ == ClassificationLossOneofCase.BootstrappedSigmoid ? (global::Tensorflow.Models.ObjectDetection.Protos.BootstrappedSigmoidClassificationLoss) classificationLoss_ : null; } + set { + classificationLoss_ = value; + classificationLossCase_ = value == null ? ClassificationLossOneofCase.None : ClassificationLossOneofCase.BootstrappedSigmoid; + } + } + + /// Field number for the "weighted_sigmoid_focal" field. + public const int WeightedSigmoidFocalFieldNumber = 4; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.SigmoidFocalClassificationLoss WeightedSigmoidFocal { + get { return classificationLossCase_ == ClassificationLossOneofCase.WeightedSigmoidFocal ? (global::Tensorflow.Models.ObjectDetection.Protos.SigmoidFocalClassificationLoss) classificationLoss_ : null; } + set { + classificationLoss_ = value; + classificationLossCase_ = value == null ? ClassificationLossOneofCase.None : ClassificationLossOneofCase.WeightedSigmoidFocal; + } + } + + private object classificationLoss_; + /// Enum of possible cases for the "classification_loss" oneof. 
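// Editor's sketch (not part of the generated file): classification_loss is also a oneof; here
// the sigmoid focal case is selected with gamma = 2 and alpha = 0.25, the values reported as
// working best in the cited focal-loss paper (https://arxiv.org/abs/1708.02002), where
// FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t).
var classification = new ClassificationLoss {
    WeightedSigmoidFocal = new SigmoidFocalClassificationLoss { Gamma = 2.0f, Alpha = 0.25f }
};
// classification.ClassificationLossCase == ClassificationLoss.ClassificationLossOneofCase.WeightedSigmoidFocal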
+ public enum ClassificationLossOneofCase { + None = 0, + WeightedSigmoid = 1, + WeightedSoftmax = 2, + WeightedLogitsSoftmax = 5, + BootstrappedSigmoid = 3, + WeightedSigmoidFocal = 4, + } + private ClassificationLossOneofCase classificationLossCase_ = ClassificationLossOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ClassificationLossOneofCase ClassificationLossCase { + get { return classificationLossCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearClassificationLoss() { + classificationLossCase_ = ClassificationLossOneofCase.None; + classificationLoss_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ClassificationLoss); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ClassificationLoss other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(WeightedSigmoid, other.WeightedSigmoid)) return false; + if (!object.Equals(WeightedSoftmax, other.WeightedSoftmax)) return false; + if (!object.Equals(WeightedLogitsSoftmax, other.WeightedLogitsSoftmax)) return false; + if (!object.Equals(BootstrappedSigmoid, other.BootstrappedSigmoid)) return false; + if (!object.Equals(WeightedSigmoidFocal, other.WeightedSigmoidFocal)) return false; + if (ClassificationLossCase != other.ClassificationLossCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (classificationLossCase_ == ClassificationLossOneofCase.WeightedSigmoid) hash ^= WeightedSigmoid.GetHashCode(); + if (classificationLossCase_ == ClassificationLossOneofCase.WeightedSoftmax) hash ^= WeightedSoftmax.GetHashCode(); + if (classificationLossCase_ == ClassificationLossOneofCase.WeightedLogitsSoftmax) hash ^= WeightedLogitsSoftmax.GetHashCode(); + if (classificationLossCase_ == ClassificationLossOneofCase.BootstrappedSigmoid) hash ^= BootstrappedSigmoid.GetHashCode(); + if (classificationLossCase_ == ClassificationLossOneofCase.WeightedSigmoidFocal) hash ^= WeightedSigmoidFocal.GetHashCode(); + hash ^= (int) classificationLossCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (classificationLossCase_ == ClassificationLossOneofCase.WeightedSigmoid) { + output.WriteRawTag(10); + output.WriteMessage(WeightedSigmoid); + } + if (classificationLossCase_ == ClassificationLossOneofCase.WeightedSoftmax) { + output.WriteRawTag(18); + output.WriteMessage(WeightedSoftmax); + } + if (classificationLossCase_ == ClassificationLossOneofCase.BootstrappedSigmoid) { + output.WriteRawTag(26); + output.WriteMessage(BootstrappedSigmoid); + } + if (classificationLossCase_ == ClassificationLossOneofCase.WeightedSigmoidFocal) { + output.WriteRawTag(34); + output.WriteMessage(WeightedSigmoidFocal); + } + if (classificationLossCase_ == ClassificationLossOneofCase.WeightedLogitsSoftmax) { + output.WriteRawTag(42); + output.WriteMessage(WeightedLogitsSoftmax); + } + if (_unknownFields != null) { + 
_unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (classificationLossCase_ == ClassificationLossOneofCase.WeightedSigmoid) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(WeightedSigmoid); + } + if (classificationLossCase_ == ClassificationLossOneofCase.WeightedSoftmax) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(WeightedSoftmax); + } + if (classificationLossCase_ == ClassificationLossOneofCase.WeightedLogitsSoftmax) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(WeightedLogitsSoftmax); + } + if (classificationLossCase_ == ClassificationLossOneofCase.BootstrappedSigmoid) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(BootstrappedSigmoid); + } + if (classificationLossCase_ == ClassificationLossOneofCase.WeightedSigmoidFocal) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(WeightedSigmoidFocal); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ClassificationLoss other) { + if (other == null) { + return; + } + switch (other.ClassificationLossCase) { + case ClassificationLossOneofCase.WeightedSigmoid: + if (WeightedSigmoid == null) { + WeightedSigmoid = new global::Tensorflow.Models.ObjectDetection.Protos.WeightedSigmoidClassificationLoss(); + } + WeightedSigmoid.MergeFrom(other.WeightedSigmoid); + break; + case ClassificationLossOneofCase.WeightedSoftmax: + if (WeightedSoftmax == null) { + WeightedSoftmax = new global::Tensorflow.Models.ObjectDetection.Protos.WeightedSoftmaxClassificationLoss(); + } + WeightedSoftmax.MergeFrom(other.WeightedSoftmax); + break; + case ClassificationLossOneofCase.WeightedLogitsSoftmax: + if (WeightedLogitsSoftmax == null) { + WeightedLogitsSoftmax = new global::Tensorflow.Models.ObjectDetection.Protos.WeightedSoftmaxClassificationAgainstLogitsLoss(); + } + WeightedLogitsSoftmax.MergeFrom(other.WeightedLogitsSoftmax); + break; + case ClassificationLossOneofCase.BootstrappedSigmoid: + if (BootstrappedSigmoid == null) { + BootstrappedSigmoid = new global::Tensorflow.Models.ObjectDetection.Protos.BootstrappedSigmoidClassificationLoss(); + } + BootstrappedSigmoid.MergeFrom(other.BootstrappedSigmoid); + break; + case ClassificationLossOneofCase.WeightedSigmoidFocal: + if (WeightedSigmoidFocal == null) { + WeightedSigmoidFocal = new global::Tensorflow.Models.ObjectDetection.Protos.SigmoidFocalClassificationLoss(); + } + WeightedSigmoidFocal.MergeFrom(other.WeightedSigmoidFocal); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.Models.ObjectDetection.Protos.WeightedSigmoidClassificationLoss subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.WeightedSigmoidClassificationLoss(); + if (classificationLossCase_ == ClassificationLossOneofCase.WeightedSigmoid) { + subBuilder.MergeFrom(WeightedSigmoid); + } + input.ReadMessage(subBuilder); + WeightedSigmoid = subBuilder; + break; + } + case 18: { + global::Tensorflow.Models.ObjectDetection.Protos.WeightedSoftmaxClassificationLoss subBuilder = new 
global::Tensorflow.Models.ObjectDetection.Protos.WeightedSoftmaxClassificationLoss(); + if (classificationLossCase_ == ClassificationLossOneofCase.WeightedSoftmax) { + subBuilder.MergeFrom(WeightedSoftmax); + } + input.ReadMessage(subBuilder); + WeightedSoftmax = subBuilder; + break; + } + case 26: { + global::Tensorflow.Models.ObjectDetection.Protos.BootstrappedSigmoidClassificationLoss subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.BootstrappedSigmoidClassificationLoss(); + if (classificationLossCase_ == ClassificationLossOneofCase.BootstrappedSigmoid) { + subBuilder.MergeFrom(BootstrappedSigmoid); + } + input.ReadMessage(subBuilder); + BootstrappedSigmoid = subBuilder; + break; + } + case 34: { + global::Tensorflow.Models.ObjectDetection.Protos.SigmoidFocalClassificationLoss subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.SigmoidFocalClassificationLoss(); + if (classificationLossCase_ == ClassificationLossOneofCase.WeightedSigmoidFocal) { + subBuilder.MergeFrom(WeightedSigmoidFocal); + } + input.ReadMessage(subBuilder); + WeightedSigmoidFocal = subBuilder; + break; + } + case 42: { + global::Tensorflow.Models.ObjectDetection.Protos.WeightedSoftmaxClassificationAgainstLogitsLoss subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.WeightedSoftmaxClassificationAgainstLogitsLoss(); + if (classificationLossCase_ == ClassificationLossOneofCase.WeightedLogitsSoftmax) { + subBuilder.MergeFrom(WeightedLogitsSoftmax); + } + input.ReadMessage(subBuilder); + WeightedLogitsSoftmax = subBuilder; + break; + } + } + } + } + + } + + /// + /// Classification loss using a sigmoid function over class predictions. + /// + public sealed partial class WeightedSigmoidClassificationLoss : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WeightedSigmoidClassificationLoss()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.LossesReflection.Descriptor.MessageTypes[6]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedSigmoidClassificationLoss() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedSigmoidClassificationLoss(WeightedSigmoidClassificationLoss other) : this() { + anchorwiseOutput_ = other.anchorwiseOutput_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedSigmoidClassificationLoss Clone() { + return new WeightedSigmoidClassificationLoss(this); + } + + /// Field number for the "anchorwise_output" field. + public const int AnchorwiseOutputFieldNumber = 1; + private bool anchorwiseOutput_; + /// + /// DEPRECATED, do not use. + /// Output loss per anchor. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool AnchorwiseOutput { + get { return anchorwiseOutput_; } + set { + anchorwiseOutput_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as WeightedSigmoidClassificationLoss); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(WeightedSigmoidClassificationLoss other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (AnchorwiseOutput != other.AnchorwiseOutput) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (AnchorwiseOutput != false) hash ^= AnchorwiseOutput.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (AnchorwiseOutput != false) { + output.WriteRawTag(8); + output.WriteBool(AnchorwiseOutput); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (AnchorwiseOutput != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(WeightedSigmoidClassificationLoss other) { + if (other == null) { + return; + } + if (other.AnchorwiseOutput != false) { + AnchorwiseOutput = other.AnchorwiseOutput; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + AnchorwiseOutput = input.ReadBool(); + break; + } + } + } + } + + } + + /// + /// Sigmoid Focal cross entropy loss as described in + /// https://arxiv.org/abs/1708.02002 + /// + public sealed partial class SigmoidFocalClassificationLoss : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SigmoidFocalClassificationLoss()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.LossesReflection.Descriptor.MessageTypes[7]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SigmoidFocalClassificationLoss() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public 
SigmoidFocalClassificationLoss(SigmoidFocalClassificationLoss other) : this() { + anchorwiseOutput_ = other.anchorwiseOutput_; + gamma_ = other.gamma_; + alpha_ = other.alpha_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SigmoidFocalClassificationLoss Clone() { + return new SigmoidFocalClassificationLoss(this); + } + + /// Field number for the "anchorwise_output" field. + public const int AnchorwiseOutputFieldNumber = 1; + private bool anchorwiseOutput_; + /// + /// DEPRECATED, do not use. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool AnchorwiseOutput { + get { return anchorwiseOutput_; } + set { + anchorwiseOutput_ = value; + } + } + + /// Field number for the "gamma" field. + public const int GammaFieldNumber = 2; + private float gamma_; + /// + /// modulating factor for the loss. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Gamma { + get { return gamma_; } + set { + gamma_ = value; + } + } + + /// Field number for the "alpha" field. + public const int AlphaFieldNumber = 3; + private float alpha_; + /// + /// alpha weighting factor for the loss. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Alpha { + get { return alpha_; } + set { + alpha_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SigmoidFocalClassificationLoss); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SigmoidFocalClassificationLoss other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (AnchorwiseOutput != other.AnchorwiseOutput) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Gamma, other.Gamma)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Alpha, other.Alpha)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (AnchorwiseOutput != false) hash ^= AnchorwiseOutput.GetHashCode(); + if (Gamma != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Gamma); + if (Alpha != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Alpha); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (AnchorwiseOutput != false) { + output.WriteRawTag(8); + output.WriteBool(AnchorwiseOutput); + } + if (Gamma != 0F) { + output.WriteRawTag(21); + output.WriteFloat(Gamma); + } + if (Alpha != 0F) { + output.WriteRawTag(29); + output.WriteFloat(Alpha); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (AnchorwiseOutput != false) { + size += 1 + 1; + } + if (Gamma != 0F) { + size += 1 + 4; + } + if (Alpha != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += 
_unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SigmoidFocalClassificationLoss other) { + if (other == null) { + return; + } + if (other.AnchorwiseOutput != false) { + AnchorwiseOutput = other.AnchorwiseOutput; + } + if (other.Gamma != 0F) { + Gamma = other.Gamma; + } + if (other.Alpha != 0F) { + Alpha = other.Alpha; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + AnchorwiseOutput = input.ReadBool(); + break; + } + case 21: { + Gamma = input.ReadFloat(); + break; + } + case 29: { + Alpha = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Classification loss using a softmax function over class predictions. + /// + public sealed partial class WeightedSoftmaxClassificationLoss : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WeightedSoftmaxClassificationLoss()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.LossesReflection.Descriptor.MessageTypes[8]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedSoftmaxClassificationLoss() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedSoftmaxClassificationLoss(WeightedSoftmaxClassificationLoss other) : this() { + anchorwiseOutput_ = other.anchorwiseOutput_; + logitScale_ = other.logitScale_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedSoftmaxClassificationLoss Clone() { + return new WeightedSoftmaxClassificationLoss(this); + } + + /// Field number for the "anchorwise_output" field. + public const int AnchorwiseOutputFieldNumber = 1; + private bool anchorwiseOutput_; + /// + /// DEPRECATED, do not use. + /// Output loss per anchor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool AnchorwiseOutput { + get { return anchorwiseOutput_; } + set { + anchorwiseOutput_ = value; + } + } + + /// Field number for the "logit_scale" field. + public const int LogitScaleFieldNumber = 2; + private float logitScale_; + /// + /// Scale logit (input) value before calculating softmax classification loss. + /// Typically used for softmax distillation. 
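// Editor's note (not part of the generated file): logit_scale behaves like a softmax
// temperature. Assuming the usual distillation convention (an assumption, not stated in this
// file), the logits are divided by logit_scale before the softmax cross-entropy is computed:
static float[] ScaleLogits(float[] logits, float logitScale)
{
    var scaled = new float[logits.Length];
    for (int i = 0; i < logits.Length; i++)
        scaled[i] = logits[i] / logitScale;  // temperature-style scaling prior to softmax
    return scaled;
}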
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float LogitScale { + get { return logitScale_; } + set { + logitScale_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as WeightedSoftmaxClassificationLoss); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(WeightedSoftmaxClassificationLoss other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (AnchorwiseOutput != other.AnchorwiseOutput) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(LogitScale, other.LogitScale)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (AnchorwiseOutput != false) hash ^= AnchorwiseOutput.GetHashCode(); + if (LogitScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(LogitScale); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (AnchorwiseOutput != false) { + output.WriteRawTag(8); + output.WriteBool(AnchorwiseOutput); + } + if (LogitScale != 0F) { + output.WriteRawTag(21); + output.WriteFloat(LogitScale); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (AnchorwiseOutput != false) { + size += 1 + 1; + } + if (LogitScale != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(WeightedSoftmaxClassificationLoss other) { + if (other == null) { + return; + } + if (other.AnchorwiseOutput != false) { + AnchorwiseOutput = other.AnchorwiseOutput; + } + if (other.LogitScale != 0F) { + LogitScale = other.LogitScale; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + AnchorwiseOutput = input.ReadBool(); + break; + } + case 21: { + LogitScale = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Classification loss using a softmax function over class predictions and + /// a softmax function over the groundtruth labels (assumed to be logits). 
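As a quick orientation for readers of this generated file, the sketch below shows how the focal-loss and softmax-loss messages defined above can be populated and round-tripped with the standard Google.Protobuf helpers. It is a minimal illustration, not part of the generated output, and the field values are arbitrary placeholders.

```cs
using System;
using Google.Protobuf;
using Tensorflow.Models.ObjectDetection.Protos;

// Focal loss with commonly used settings (values are illustrative only).
var focal = new SigmoidFocalClassificationLoss
{
    Gamma = 2.0f,  // modulating factor for the loss
    Alpha = 0.25f  // class-balancing weight
};

// Softmax classification loss with logit scaling.
var softmax = new WeightedSoftmaxClassificationLoss { LogitScale = 1.0f };

// Round-trip through the binary wire format.
byte[] bytes = focal.ToByteArray();
var restored = SigmoidFocalClassificationLoss.Parser.ParseFrom(bytes);

Console.WriteLine(restored);               // diagnostic JSON string from ToString()
Console.WriteLine(focal.Equals(restored)); // True: generated value equality
Console.WriteLine(softmax.CalculateSize());
```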
+ /// + public sealed partial class WeightedSoftmaxClassificationAgainstLogitsLoss : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WeightedSoftmaxClassificationAgainstLogitsLoss()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.LossesReflection.Descriptor.MessageTypes[9]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedSoftmaxClassificationAgainstLogitsLoss() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedSoftmaxClassificationAgainstLogitsLoss(WeightedSoftmaxClassificationAgainstLogitsLoss other) : this() { + anchorwiseOutput_ = other.anchorwiseOutput_; + logitScale_ = other.logitScale_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WeightedSoftmaxClassificationAgainstLogitsLoss Clone() { + return new WeightedSoftmaxClassificationAgainstLogitsLoss(this); + } + + /// Field number for the "anchorwise_output" field. + public const int AnchorwiseOutputFieldNumber = 1; + private bool anchorwiseOutput_; + /// + /// DEPRECATED, do not use. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool AnchorwiseOutput { + get { return anchorwiseOutput_; } + set { + anchorwiseOutput_ = value; + } + } + + /// Field number for the "logit_scale" field. + public const int LogitScaleFieldNumber = 2; + private float logitScale_; + /// + /// Scale and softmax groundtruth logits before calculating softmax + /// classification loss. Typically used for softmax distillation with teacher + /// annotations stored as logits. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float LogitScale { + get { return logitScale_; } + set { + logitScale_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as WeightedSoftmaxClassificationAgainstLogitsLoss); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(WeightedSoftmaxClassificationAgainstLogitsLoss other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (AnchorwiseOutput != other.AnchorwiseOutput) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(LogitScale, other.LogitScale)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (AnchorwiseOutput != false) hash ^= AnchorwiseOutput.GetHashCode(); + if (LogitScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(LogitScale); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (AnchorwiseOutput != false) { + output.WriteRawTag(8); + output.WriteBool(AnchorwiseOutput); + } + if (LogitScale != 0F) { + output.WriteRawTag(21); + output.WriteFloat(LogitScale); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (AnchorwiseOutput != false) { + size += 1 + 1; + } + if (LogitScale != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(WeightedSoftmaxClassificationAgainstLogitsLoss other) { + if (other == null) { + return; + } + if (other.AnchorwiseOutput != false) { + AnchorwiseOutput = other.AnchorwiseOutput; + } + if (other.LogitScale != 0F) { + LogitScale = other.LogitScale; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + AnchorwiseOutput = input.ReadBool(); + break; + } + case 21: { + LogitScale = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Classification loss using a sigmoid function over the class prediction with + /// the highest prediction score. 
+ /// + public sealed partial class BootstrappedSigmoidClassificationLoss : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BootstrappedSigmoidClassificationLoss()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.LossesReflection.Descriptor.MessageTypes[10]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BootstrappedSigmoidClassificationLoss() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BootstrappedSigmoidClassificationLoss(BootstrappedSigmoidClassificationLoss other) : this() { + alpha_ = other.alpha_; + hardBootstrap_ = other.hardBootstrap_; + anchorwiseOutput_ = other.anchorwiseOutput_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BootstrappedSigmoidClassificationLoss Clone() { + return new BootstrappedSigmoidClassificationLoss(this); + } + + /// Field number for the "alpha" field. + public const int AlphaFieldNumber = 1; + private float alpha_; + /// + /// Interpolation weight between 0 and 1. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Alpha { + get { return alpha_; } + set { + alpha_ = value; + } + } + + /// Field number for the "hard_bootstrap" field. + public const int HardBootstrapFieldNumber = 2; + private bool hardBootstrap_; + /// + /// Whether hard boot strapping should be used or not. If true, will only use + /// one class favored by model. Othewise, will use all predicted class + /// probabilities. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool HardBootstrap { + get { return hardBootstrap_; } + set { + hardBootstrap_ = value; + } + } + + /// Field number for the "anchorwise_output" field. + public const int AnchorwiseOutputFieldNumber = 3; + private bool anchorwiseOutput_; + /// + /// DEPRECATED, do not use. + /// Output loss per anchor. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool AnchorwiseOutput { + get { return anchorwiseOutput_; } + set { + anchorwiseOutput_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as BootstrappedSigmoidClassificationLoss); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(BootstrappedSigmoidClassificationLoss other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Alpha, other.Alpha)) return false; + if (HardBootstrap != other.HardBootstrap) return false; + if (AnchorwiseOutput != other.AnchorwiseOutput) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Alpha != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Alpha); + if (HardBootstrap != false) hash ^= HardBootstrap.GetHashCode(); + if (AnchorwiseOutput != false) hash ^= AnchorwiseOutput.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Alpha != 0F) { + output.WriteRawTag(13); + output.WriteFloat(Alpha); + } + if (HardBootstrap != false) { + output.WriteRawTag(16); + output.WriteBool(HardBootstrap); + } + if (AnchorwiseOutput != false) { + output.WriteRawTag(24); + output.WriteBool(AnchorwiseOutput); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Alpha != 0F) { + size += 1 + 4; + } + if (HardBootstrap != false) { + size += 1 + 1; + } + if (AnchorwiseOutput != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(BootstrappedSigmoidClassificationLoss other) { + if (other == null) { + return; + } + if (other.Alpha != 0F) { + Alpha = other.Alpha; + } + if (other.HardBootstrap != false) { + HardBootstrap = other.HardBootstrap; + } + if (other.AnchorwiseOutput != false) { + AnchorwiseOutput = other.AnchorwiseOutput; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + Alpha = input.ReadFloat(); + break; + } + case 16: { + HardBootstrap = input.ReadBool(); + break; + } + case 24: { + AnchorwiseOutput = input.ReadBool(); + break; + } + } + } + } + + } + + /// + /// Configuration for hard example miner. 
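For the bootstrapped sigmoid loss generated above, a typical configuration might look like the following snippet; the alpha value is an arbitrary placeholder and soft bootstrapping is selected by leaving hard_bootstrap false.

```cs
using System;
using Tensorflow.Models.ObjectDetection.Protos;

// Soft bootstrapping: interpolate between labels and all predicted probabilities.
var bootstrapped = new BootstrappedSigmoidClassificationLoss
{
    Alpha = 0.5f,          // interpolation weight in [0, 1]
    HardBootstrap = false  // true would use only the single favored class
};

Console.WriteLine(bootstrapped);
```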
+ /// + public sealed partial class HardExampleMiner : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new HardExampleMiner()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.LossesReflection.Descriptor.MessageTypes[11]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public HardExampleMiner() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public HardExampleMiner(HardExampleMiner other) : this() { + numHardExamples_ = other.numHardExamples_; + iouThreshold_ = other.iouThreshold_; + lossType_ = other.lossType_; + maxNegativesPerPositive_ = other.maxNegativesPerPositive_; + minNegativesPerImage_ = other.minNegativesPerImage_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public HardExampleMiner Clone() { + return new HardExampleMiner(this); + } + + /// Field number for the "num_hard_examples" field. + public const int NumHardExamplesFieldNumber = 1; + private int numHardExamples_; + /// + /// Maximum number of hard examples to be selected per image (prior to + /// enforcing max negative to positive ratio constraint). If set to 0, + /// all examples obtained after NMS are considered. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumHardExamples { + get { return numHardExamples_; } + set { + numHardExamples_ = value; + } + } + + /// Field number for the "iou_threshold" field. + public const int IouThresholdFieldNumber = 2; + private float iouThreshold_; + /// + /// Minimum intersection over union for an example to be discarded during NMS. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float IouThreshold { + get { return iouThreshold_; } + set { + iouThreshold_ = value; + } + } + + /// Field number for the "loss_type" field. + public const int LossTypeFieldNumber = 3; + private global::Tensorflow.Models.ObjectDetection.Protos.HardExampleMiner.Types.LossType lossType_ = 0; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.HardExampleMiner.Types.LossType LossType { + get { return lossType_; } + set { + lossType_ = value; + } + } + + /// Field number for the "max_negatives_per_positive" field. + public const int MaxNegativesPerPositiveFieldNumber = 4; + private int maxNegativesPerPositive_; + /// + /// Maximum number of negatives to retain for each positive anchor. If + /// num_negatives_per_positive is 0 no prespecified negative:positive ratio is + /// enforced. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxNegativesPerPositive { + get { return maxNegativesPerPositive_; } + set { + maxNegativesPerPositive_ = value; + } + } + + /// Field number for the "min_negatives_per_image" field. + public const int MinNegativesPerImageFieldNumber = 5; + private int minNegativesPerImage_; + /// + /// Minimum number of negative anchors to sample for a given image. 
Setting + /// this to a positive number samples negatives in an image without any + /// positive anchors and thus not bias the model towards having at least one + /// detection per image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MinNegativesPerImage { + get { return minNegativesPerImage_; } + set { + minNegativesPerImage_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as HardExampleMiner); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(HardExampleMiner other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (NumHardExamples != other.NumHardExamples) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(IouThreshold, other.IouThreshold)) return false; + if (LossType != other.LossType) return false; + if (MaxNegativesPerPositive != other.MaxNegativesPerPositive) return false; + if (MinNegativesPerImage != other.MinNegativesPerImage) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (NumHardExamples != 0) hash ^= NumHardExamples.GetHashCode(); + if (IouThreshold != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(IouThreshold); + if (LossType != 0) hash ^= LossType.GetHashCode(); + if (MaxNegativesPerPositive != 0) hash ^= MaxNegativesPerPositive.GetHashCode(); + if (MinNegativesPerImage != 0) hash ^= MinNegativesPerImage.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (NumHardExamples != 0) { + output.WriteRawTag(8); + output.WriteInt32(NumHardExamples); + } + if (IouThreshold != 0F) { + output.WriteRawTag(21); + output.WriteFloat(IouThreshold); + } + if (LossType != 0) { + output.WriteRawTag(24); + output.WriteEnum((int) LossType); + } + if (MaxNegativesPerPositive != 0) { + output.WriteRawTag(32); + output.WriteInt32(MaxNegativesPerPositive); + } + if (MinNegativesPerImage != 0) { + output.WriteRawTag(40); + output.WriteInt32(MinNegativesPerImage); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (NumHardExamples != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumHardExamples); + } + if (IouThreshold != 0F) { + size += 1 + 4; + } + if (LossType != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) LossType); + } + if (MaxNegativesPerPositive != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxNegativesPerPositive); + } + if (MinNegativesPerImage != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MinNegativesPerImage); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(HardExampleMiner other) { + if (other == null) { + return; + } + if (other.NumHardExamples != 0) { + 
NumHardExamples = other.NumHardExamples; + } + if (other.IouThreshold != 0F) { + IouThreshold = other.IouThreshold; + } + if (other.LossType != 0) { + LossType = other.LossType; + } + if (other.MaxNegativesPerPositive != 0) { + MaxNegativesPerPositive = other.MaxNegativesPerPositive; + } + if (other.MinNegativesPerImage != 0) { + MinNegativesPerImage = other.MinNegativesPerImage; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + NumHardExamples = input.ReadInt32(); + break; + } + case 21: { + IouThreshold = input.ReadFloat(); + break; + } + case 24: { + lossType_ = (global::Tensorflow.Models.ObjectDetection.Protos.HardExampleMiner.Types.LossType) input.ReadEnum(); + break; + } + case 32: { + MaxNegativesPerPositive = input.ReadInt32(); + break; + } + case 40: { + MinNegativesPerImage = input.ReadInt32(); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the HardExampleMiner message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + /// + /// Whether to use classification losses ('cls', default), localization losses + /// ('loc') or both losses ('both'). In the case of 'both', cls_loss_weight and + /// loc_loss_weight are used to compute weighted sum of the two losses. + /// + public enum LossType { + [pbr::OriginalName("BOTH")] Both = 0, + [pbr::OriginalName("CLASSIFICATION")] Classification = 1, + [pbr::OriginalName("LOCALIZATION")] Localization = 2, + } + + } + #endregion + + } + + /// + /// Configuration for random example sampler. + /// + public sealed partial class RandomExampleSampler : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomExampleSampler()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.LossesReflection.Descriptor.MessageTypes[12]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomExampleSampler() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomExampleSampler(RandomExampleSampler other) : this() { + positiveSampleFraction_ = other.positiveSampleFraction_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomExampleSampler Clone() { + return new RandomExampleSampler(this); + } + + /// Field number for the "positive_sample_fraction" field. + public const int PositiveSampleFractionFieldNumber = 1; + private float positiveSampleFraction_; + /// + /// The desired fraction of positive samples in batch when applying random + /// example sampling. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float PositiveSampleFraction { + get { return positiveSampleFraction_; } + set { + positiveSampleFraction_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomExampleSampler); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomExampleSampler other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(PositiveSampleFraction, other.PositiveSampleFraction)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (PositiveSampleFraction != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(PositiveSampleFraction); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (PositiveSampleFraction != 0F) { + output.WriteRawTag(13); + output.WriteFloat(PositiveSampleFraction); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (PositiveSampleFraction != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomExampleSampler other) { + if (other == null) { + return; + } + if (other.PositiveSampleFraction != 0F) { + PositiveSampleFraction = other.PositiveSampleFraction; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + PositiveSampleFraction = input.ReadFloat(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/Matcher.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/Matcher.cs new file mode 100644 index 00000000..72aa0188 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/Matcher.cs @@ -0,0 +1,257 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
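Losses.cs closes with the two sampling-related messages above. The sketch below shows one way they might be configured; the numbers are chosen purely for illustration.

```cs
using System;
using Tensorflow.Models.ObjectDetection.Protos;

// Keep at most 64 hard examples per image, mined on classification loss,
// capped at 3 negatives per positive and at least 3 negatives per image.
var miner = new HardExampleMiner
{
    NumHardExamples = 64,
    IouThreshold = 0.7f,
    LossType = HardExampleMiner.Types.LossType.Classification,
    MaxNegativesPerPositive = 3,
    MinNegativesPerImage = 3
};

// Random example sampling targeting 25% positives per batch.
var sampler = new RandomExampleSampler { PositiveSampleFraction = 0.25f };

Console.WriteLine(miner);
Console.WriteLine(sampler);
```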
+// source: object_detection/protos/matcher.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/matcher.proto + public static partial class MatcherReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/matcher.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static MatcherReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiVvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9tYXRjaGVyLnByb3RvEhdvYmpl", + "Y3RfZGV0ZWN0aW9uLnByb3Rvcxosb2JqZWN0X2RldGVjdGlvbi9wcm90b3Mv", + "YXJnbWF4X21hdGNoZXIucHJvdG8aL29iamVjdF9kZXRlY3Rpb24vcHJvdG9z", + "L2JpcGFydGl0ZV9tYXRjaGVyLnByb3RvIqQBCgdNYXRjaGVyEkAKDmFyZ21h", + "eF9tYXRjaGVyGAEgASgLMiYub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuQXJn", + "TWF4TWF0Y2hlckgAEkYKEWJpcGFydGl0ZV9tYXRjaGVyGAIgASgLMikub2Jq", + "ZWN0X2RldGVjdGlvbi5wcm90b3MuQmlwYXJ0aXRlTWF0Y2hlckgAQg8KDW1h", + "dGNoZXJfb25lb2ZiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.Models.ObjectDetection.Protos.ArgmaxMatcherReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.BipartiteMatcherReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.Matcher), global::Tensorflow.Models.ObjectDetection.Protos.Matcher.Parser, new[]{ "ArgmaxMatcher", "BipartiteMatcher" }, new[]{ "MatcherOneof" }, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration proto for the matcher to be used in the object detection + /// pipeline. See core/matcher.py for details. 
+ /// + public sealed partial class Matcher : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Matcher()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.MatcherReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Matcher() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Matcher(Matcher other) : this() { + switch (other.MatcherOneofCase) { + case MatcherOneofOneofCase.ArgmaxMatcher: + ArgmaxMatcher = other.ArgmaxMatcher.Clone(); + break; + case MatcherOneofOneofCase.BipartiteMatcher: + BipartiteMatcher = other.BipartiteMatcher.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Matcher Clone() { + return new Matcher(this); + } + + /// Field number for the "argmax_matcher" field. + public const int ArgmaxMatcherFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ArgMaxMatcher ArgmaxMatcher { + get { return matcherOneofCase_ == MatcherOneofOneofCase.ArgmaxMatcher ? (global::Tensorflow.Models.ObjectDetection.Protos.ArgMaxMatcher) matcherOneof_ : null; } + set { + matcherOneof_ = value; + matcherOneofCase_ = value == null ? MatcherOneofOneofCase.None : MatcherOneofOneofCase.ArgmaxMatcher; + } + } + + /// Field number for the "bipartite_matcher" field. + public const int BipartiteMatcherFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.BipartiteMatcher BipartiteMatcher { + get { return matcherOneofCase_ == MatcherOneofOneofCase.BipartiteMatcher ? (global::Tensorflow.Models.ObjectDetection.Protos.BipartiteMatcher) matcherOneof_ : null; } + set { + matcherOneof_ = value; + matcherOneofCase_ = value == null ? MatcherOneofOneofCase.None : MatcherOneofOneofCase.BipartiteMatcher; + } + } + + private object matcherOneof_; + /// Enum of possible cases for the "matcher_oneof" oneof. 
+ public enum MatcherOneofOneofCase { + None = 0, + ArgmaxMatcher = 1, + BipartiteMatcher = 2, + } + private MatcherOneofOneofCase matcherOneofCase_ = MatcherOneofOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MatcherOneofOneofCase MatcherOneofCase { + get { return matcherOneofCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearMatcherOneof() { + matcherOneofCase_ = MatcherOneofOneofCase.None; + matcherOneof_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as Matcher); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(Matcher other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(ArgmaxMatcher, other.ArgmaxMatcher)) return false; + if (!object.Equals(BipartiteMatcher, other.BipartiteMatcher)) return false; + if (MatcherOneofCase != other.MatcherOneofCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (matcherOneofCase_ == MatcherOneofOneofCase.ArgmaxMatcher) hash ^= ArgmaxMatcher.GetHashCode(); + if (matcherOneofCase_ == MatcherOneofOneofCase.BipartiteMatcher) hash ^= BipartiteMatcher.GetHashCode(); + hash ^= (int) matcherOneofCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (matcherOneofCase_ == MatcherOneofOneofCase.ArgmaxMatcher) { + output.WriteRawTag(10); + output.WriteMessage(ArgmaxMatcher); + } + if (matcherOneofCase_ == MatcherOneofOneofCase.BipartiteMatcher) { + output.WriteRawTag(18); + output.WriteMessage(BipartiteMatcher); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (matcherOneofCase_ == MatcherOneofOneofCase.ArgmaxMatcher) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ArgmaxMatcher); + } + if (matcherOneofCase_ == MatcherOneofOneofCase.BipartiteMatcher) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(BipartiteMatcher); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(Matcher other) { + if (other == null) { + return; + } + switch (other.MatcherOneofCase) { + case MatcherOneofOneofCase.ArgmaxMatcher: + if (ArgmaxMatcher == null) { + ArgmaxMatcher = new global::Tensorflow.Models.ObjectDetection.Protos.ArgMaxMatcher(); + } + ArgmaxMatcher.MergeFrom(other.ArgmaxMatcher); + break; + case MatcherOneofOneofCase.BipartiteMatcher: + if (BipartiteMatcher == null) { + BipartiteMatcher = new global::Tensorflow.Models.ObjectDetection.Protos.BipartiteMatcher(); + } + BipartiteMatcher.MergeFrom(other.BipartiteMatcher); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void 
MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.Models.ObjectDetection.Protos.ArgMaxMatcher subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.ArgMaxMatcher(); + if (matcherOneofCase_ == MatcherOneofOneofCase.ArgmaxMatcher) { + subBuilder.MergeFrom(ArgmaxMatcher); + } + input.ReadMessage(subBuilder); + ArgmaxMatcher = subBuilder; + break; + } + case 18: { + global::Tensorflow.Models.ObjectDetection.Protos.BipartiteMatcher subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.BipartiteMatcher(); + if (matcherOneofCase_ == MatcherOneofOneofCase.BipartiteMatcher) { + subBuilder.MergeFrom(BipartiteMatcher); + } + input.ReadMessage(subBuilder); + BipartiteMatcher = subBuilder; + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/MeanStddevBoxCoder.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/MeanStddevBoxCoder.cs new file mode 100644 index 00000000..a5cf55a7 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/MeanStddevBoxCoder.cs @@ -0,0 +1,180 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: object_detection/protos/mean_stddev_box_coder.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/mean_stddev_box_coder.proto + public static partial class MeanStddevBoxCoderReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/mean_stddev_box_coder.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static MeanStddevBoxCoderReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjNvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9tZWFuX3N0ZGRldl9ib3hfY29k", + "ZXIucHJvdG8SF29iamVjdF9kZXRlY3Rpb24ucHJvdG9zIiQKEk1lYW5TdGRk", + "ZXZCb3hDb2RlchIOCgZzdGRkZXYYASABKAJiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.MeanStddevBoxCoder), global::Tensorflow.Models.ObjectDetection.Protos.MeanStddevBoxCoder.Parser, new[]{ "Stddev" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration proto for MeanStddevBoxCoder. See + /// box_coders/mean_stddev_box_coder.py for details. 
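The Matcher message wraps a protobuf oneof, so assigning one matcher branch clears the other. A small illustration follows; ArgMaxMatcher and BipartiteMatcher are defined in their own generated files, and default-constructed instances are used here only to show the oneof mechanics.

```cs
using System;
using Tensorflow.Models.ObjectDetection.Protos;

var matcher = new Matcher { ArgmaxMatcher = new ArgMaxMatcher() };
Console.WriteLine(matcher.MatcherOneofCase);      // ArgmaxMatcher

// Assigning the other branch replaces the previous value.
matcher.BipartiteMatcher = new BipartiteMatcher();
Console.WriteLine(matcher.MatcherOneofCase);      // BipartiteMatcher
Console.WriteLine(matcher.ArgmaxMatcher == null); // True

matcher.ClearMatcherOneof();
Console.WriteLine(matcher.MatcherOneofCase);      // None
```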
+ /// + public sealed partial class MeanStddevBoxCoder : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MeanStddevBoxCoder()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.MeanStddevBoxCoderReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MeanStddevBoxCoder() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MeanStddevBoxCoder(MeanStddevBoxCoder other) : this() { + stddev_ = other.stddev_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MeanStddevBoxCoder Clone() { + return new MeanStddevBoxCoder(this); + } + + /// Field number for the "stddev" field. + public const int StddevFieldNumber = 1; + private float stddev_; + /// + /// The standard deviation used to encode and decode boxes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Stddev { + get { return stddev_; } + set { + stddev_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as MeanStddevBoxCoder); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(MeanStddevBoxCoder other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Stddev, other.Stddev)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Stddev != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Stddev); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Stddev != 0F) { + output.WriteRawTag(13); + output.WriteFloat(Stddev); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Stddev != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(MeanStddevBoxCoder other) { + if (other == null) { + return; + } + if (other.Stddev != 0F) { + Stddev = other.Stddev; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void 
MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + Stddev = input.ReadFloat(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/Model.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/Model.cs new file mode 100644 index 00000000..11001ac6 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/Model.cs @@ -0,0 +1,255 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: object_detection/protos/model.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/model.proto + public static partial class ModelReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/model.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static ModelReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiNvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9tb2RlbC5wcm90bxIXb2JqZWN0", + "X2RldGVjdGlvbi5wcm90b3MaKW9iamVjdF9kZXRlY3Rpb24vcHJvdG9zL2Zh", + "c3Rlcl9yY25uLnByb3RvGiFvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9zc2Qu", + "cHJvdG8iggEKDkRldGVjdGlvbk1vZGVsEjoKC2Zhc3Rlcl9yY25uGAEgASgL", + "MiMub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuRmFzdGVyUmNubkgAEisKA3Nz", + "ZBgCIAEoCzIcLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLlNzZEgAQgcKBW1v", + "ZGVsYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnnReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.SsdReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.DetectionModel), global::Tensorflow.Models.ObjectDetection.Protos.DetectionModel.Parser, new[]{ "FasterRcnn", "Ssd" }, new[]{ "Model" }, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Top level configuration for DetectionModels. 
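MeanStddevBoxCoder carries a single stddev field. The snippet below sketches a serialize/parse round trip with the standard Google.Protobuf helpers; the 0.01 value is only an example.

```cs
using System;
using Google.Protobuf;
using Tensorflow.Models.ObjectDetection.Protos;

var coder = new MeanStddevBoxCoder { Stddev = 0.01f };

byte[] wire = coder.ToByteArray();
var parsed = MeanStddevBoxCoder.Parser.ParseFrom(wire);

Console.WriteLine(parsed.Stddev);        // 0.01
Console.WriteLine(coder.Equals(parsed)); // True
```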
+ /// + public sealed partial class DetectionModel : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DetectionModel()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.ModelReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public DetectionModel() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public DetectionModel(DetectionModel other) : this() { + switch (other.ModelCase) { + case ModelOneofCase.FasterRcnn: + FasterRcnn = other.FasterRcnn.Clone(); + break; + case ModelOneofCase.Ssd: + Ssd = other.Ssd.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public DetectionModel Clone() { + return new DetectionModel(this); + } + + /// Field number for the "faster_rcnn" field. + public const int FasterRcnnFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnn FasterRcnn { + get { return modelCase_ == ModelOneofCase.FasterRcnn ? (global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnn) model_ : null; } + set { + model_ = value; + modelCase_ = value == null ? ModelOneofCase.None : ModelOneofCase.FasterRcnn; + } + } + + /// Field number for the "ssd" field. + public const int SsdFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Ssd Ssd { + get { return modelCase_ == ModelOneofCase.Ssd ? (global::Tensorflow.Models.ObjectDetection.Protos.Ssd) model_ : null; } + set { + model_ = value; + modelCase_ = value == null ? ModelOneofCase.None : ModelOneofCase.Ssd; + } + } + + private object model_; + /// Enum of possible cases for the "model" oneof. 
+ public enum ModelOneofCase { + None = 0, + FasterRcnn = 1, + Ssd = 2, + } + private ModelOneofCase modelCase_ = ModelOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ModelOneofCase ModelCase { + get { return modelCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearModel() { + modelCase_ = ModelOneofCase.None; + model_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as DetectionModel); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(DetectionModel other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(FasterRcnn, other.FasterRcnn)) return false; + if (!object.Equals(Ssd, other.Ssd)) return false; + if (ModelCase != other.ModelCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (modelCase_ == ModelOneofCase.FasterRcnn) hash ^= FasterRcnn.GetHashCode(); + if (modelCase_ == ModelOneofCase.Ssd) hash ^= Ssd.GetHashCode(); + hash ^= (int) modelCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (modelCase_ == ModelOneofCase.FasterRcnn) { + output.WriteRawTag(10); + output.WriteMessage(FasterRcnn); + } + if (modelCase_ == ModelOneofCase.Ssd) { + output.WriteRawTag(18); + output.WriteMessage(Ssd); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (modelCase_ == ModelOneofCase.FasterRcnn) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(FasterRcnn); + } + if (modelCase_ == ModelOneofCase.Ssd) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Ssd); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(DetectionModel other) { + if (other == null) { + return; + } + switch (other.ModelCase) { + case ModelOneofCase.FasterRcnn: + if (FasterRcnn == null) { + FasterRcnn = new global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnn(); + } + FasterRcnn.MergeFrom(other.FasterRcnn); + break; + case ModelOneofCase.Ssd: + if (Ssd == null) { + Ssd = new global::Tensorflow.Models.ObjectDetection.Protos.Ssd(); + } + Ssd.MergeFrom(other.Ssd); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnn subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.FasterRcnn(); + if (modelCase_ == ModelOneofCase.FasterRcnn) { + subBuilder.MergeFrom(FasterRcnn); + } 
+ input.ReadMessage(subBuilder); + FasterRcnn = subBuilder; + break; + } + case 18: { + global::Tensorflow.Models.ObjectDetection.Protos.Ssd subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.Ssd(); + if (modelCase_ == ModelOneofCase.Ssd) { + subBuilder.MergeFrom(Ssd); + } + input.ReadMessage(subBuilder); + Ssd = subBuilder; + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/MultiscaleAnchorGenerator.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/MultiscaleAnchorGenerator.cs new file mode 100644 index 00000000..b4d577fc --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/MultiscaleAnchorGenerator.cs @@ -0,0 +1,332 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: object_detection/protos/multiscale_anchor_generator.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/multiscale_anchor_generator.proto + public static partial class MultiscaleAnchorGeneratorReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/multiscale_anchor_generator.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static MultiscaleAnchorGeneratorReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjlvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9tdWx0aXNjYWxlX2FuY2hvcl9n", + "ZW5lcmF0b3IucHJvdG8SF29iamVjdF9kZXRlY3Rpb24ucHJvdG9zIqgBChlN", + "dWx0aXNjYWxlQW5jaG9yR2VuZXJhdG9yEhEKCW1pbl9sZXZlbBgBIAEoBRIR", + "CgltYXhfbGV2ZWwYAiABKAUSFAoMYW5jaG9yX3NjYWxlGAMgASgCEhUKDWFz", + "cGVjdF9yYXRpb3MYBCADKAISGQoRc2NhbGVzX3Blcl9vY3RhdmUYBSABKAUS", + "HQoVbm9ybWFsaXplX2Nvb3JkaW5hdGVzGAYgASgIYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.MultiscaleAnchorGenerator), global::Tensorflow.Models.ObjectDetection.Protos.MultiscaleAnchorGenerator.Parser, new[]{ "MinLevel", "MaxLevel", "AnchorScale", "AspectRatios", "ScalesPerOctave", "NormalizeCoordinates" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration proto for RetinaNet anchor generator described in + /// https://arxiv.org/abs/1708.02002. See + /// anchor_generators/multiscale_grid_anchor_generator.py for details. 
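DetectionModel mirrors the Matcher pattern: a oneof selecting either a Faster R-CNN or an SSD configuration. A minimal sketch is shown below; FasterRcnn and Ssd live in separate generated files, and default instances are used here purely to demonstrate the oneof behavior.

```cs
using System;
using Tensorflow.Models.ObjectDetection.Protos;

var model = new DetectionModel { Ssd = new Ssd() };
Console.WriteLine(model.ModelCase);   // Ssd

model.FasterRcnn = new FasterRcnn();  // switches the oneof to the other branch
Console.WriteLine(model.ModelCase);   // FasterRcnn
Console.WriteLine(model.Ssd == null); // True
```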
+ /// + public sealed partial class MultiscaleAnchorGenerator : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MultiscaleAnchorGenerator()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.MultiscaleAnchorGeneratorReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MultiscaleAnchorGenerator() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MultiscaleAnchorGenerator(MultiscaleAnchorGenerator other) : this() { + minLevel_ = other.minLevel_; + maxLevel_ = other.maxLevel_; + anchorScale_ = other.anchorScale_; + aspectRatios_ = other.aspectRatios_.Clone(); + scalesPerOctave_ = other.scalesPerOctave_; + normalizeCoordinates_ = other.normalizeCoordinates_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MultiscaleAnchorGenerator Clone() { + return new MultiscaleAnchorGenerator(this); + } + + /// Field number for the "min_level" field. + public const int MinLevelFieldNumber = 1; + private int minLevel_; + /// + /// minimum level in feature pyramid + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MinLevel { + get { return minLevel_; } + set { + minLevel_ = value; + } + } + + /// Field number for the "max_level" field. + public const int MaxLevelFieldNumber = 2; + private int maxLevel_; + /// + /// maximum level in feature pyramid + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxLevel { + get { return maxLevel_; } + set { + maxLevel_ = value; + } + } + + /// Field number for the "anchor_scale" field. + public const int AnchorScaleFieldNumber = 3; + private float anchorScale_; + /// + /// Scale of anchor to feature stride + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float AnchorScale { + get { return anchorScale_; } + set { + anchorScale_ = value; + } + } + + /// Field number for the "aspect_ratios" field. + public const int AspectRatiosFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_aspectRatios_codec + = pb::FieldCodec.ForFloat(34); + private readonly pbc::RepeatedField aspectRatios_ = new pbc::RepeatedField(); + /// + /// Aspect ratios for anchors at each grid point. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField AspectRatios { + get { return aspectRatios_; } + } + + /// Field number for the "scales_per_octave" field. + public const int ScalesPerOctaveFieldNumber = 5; + private int scalesPerOctave_; + /// + /// Number of intermediate scale each scale octave + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int ScalesPerOctave { + get { return scalesPerOctave_; } + set { + scalesPerOctave_ = value; + } + } + + /// Field number for the "normalize_coordinates" field. 
+ public const int NormalizeCoordinatesFieldNumber = 6; + private bool normalizeCoordinates_; + /// + /// Whether to produce anchors in normalized coordinates. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool NormalizeCoordinates { + get { return normalizeCoordinates_; } + set { + normalizeCoordinates_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as MultiscaleAnchorGenerator); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(MultiscaleAnchorGenerator other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (MinLevel != other.MinLevel) return false; + if (MaxLevel != other.MaxLevel) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(AnchorScale, other.AnchorScale)) return false; + if(!aspectRatios_.Equals(other.aspectRatios_)) return false; + if (ScalesPerOctave != other.ScalesPerOctave) return false; + if (NormalizeCoordinates != other.NormalizeCoordinates) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MinLevel != 0) hash ^= MinLevel.GetHashCode(); + if (MaxLevel != 0) hash ^= MaxLevel.GetHashCode(); + if (AnchorScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(AnchorScale); + hash ^= aspectRatios_.GetHashCode(); + if (ScalesPerOctave != 0) hash ^= ScalesPerOctave.GetHashCode(); + if (NormalizeCoordinates != false) hash ^= NormalizeCoordinates.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MinLevel != 0) { + output.WriteRawTag(8); + output.WriteInt32(MinLevel); + } + if (MaxLevel != 0) { + output.WriteRawTag(16); + output.WriteInt32(MaxLevel); + } + if (AnchorScale != 0F) { + output.WriteRawTag(29); + output.WriteFloat(AnchorScale); + } + aspectRatios_.WriteTo(output, _repeated_aspectRatios_codec); + if (ScalesPerOctave != 0) { + output.WriteRawTag(40); + output.WriteInt32(ScalesPerOctave); + } + if (NormalizeCoordinates != false) { + output.WriteRawTag(48); + output.WriteBool(NormalizeCoordinates); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MinLevel != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MinLevel); + } + if (MaxLevel != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxLevel); + } + if (AnchorScale != 0F) { + size += 1 + 4; + } + size += aspectRatios_.CalculateSize(_repeated_aspectRatios_codec); + if (ScalesPerOctave != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ScalesPerOctave); + } + if (NormalizeCoordinates != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(MultiscaleAnchorGenerator other) { + if (other == null) { + return; + } + 
if (other.MinLevel != 0) { + MinLevel = other.MinLevel; + } + if (other.MaxLevel != 0) { + MaxLevel = other.MaxLevel; + } + if (other.AnchorScale != 0F) { + AnchorScale = other.AnchorScale; + } + aspectRatios_.Add(other.aspectRatios_); + if (other.ScalesPerOctave != 0) { + ScalesPerOctave = other.ScalesPerOctave; + } + if (other.NormalizeCoordinates != false) { + NormalizeCoordinates = other.NormalizeCoordinates; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + MinLevel = input.ReadInt32(); + break; + } + case 16: { + MaxLevel = input.ReadInt32(); + break; + } + case 29: { + AnchorScale = input.ReadFloat(); + break; + } + case 34: + case 37: { + aspectRatios_.AddEntriesFrom(input, _repeated_aspectRatios_codec); + break; + } + case 40: { + ScalesPerOctave = input.ReadInt32(); + break; + } + case 48: { + NormalizeCoordinates = input.ReadBool(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/Optimizer.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/Optimizer.cs new file mode 100644 index 00000000..c71adfa8 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/Optimizer.cs @@ -0,0 +1,2231 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: object_detection/protos/optimizer.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/optimizer.proto + public static partial class OptimizerReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/optimizer.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static OptimizerReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CidvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9vcHRpbWl6ZXIucHJvdG8SF29i", + "amVjdF9kZXRlY3Rpb24ucHJvdG9zIqcCCglPcHRpbWl6ZXISRwoScm1zX3By", + "b3Bfb3B0aW1pemVyGAEgASgLMikub2JqZWN0X2RldGVjdGlvbi5wcm90b3Mu", + "Uk1TUHJvcE9wdGltaXplckgAEkgKEm1vbWVudHVtX29wdGltaXplchgCIAEo", + "CzIqLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLk1vbWVudHVtT3B0aW1pemVy", + "SAASQAoOYWRhbV9vcHRpbWl6ZXIYAyABKAsyJi5vYmplY3RfZGV0ZWN0aW9u", + "LnByb3Rvcy5BZGFtT3B0aW1pemVySAASGgoSdXNlX21vdmluZ19hdmVyYWdl", + "GAQgASgIEhwKFG1vdmluZ19hdmVyYWdlX2RlY2F5GAUgASgCQgsKCW9wdGlt", + "aXplciKSAQoQUk1TUHJvcE9wdGltaXplchI8Cg1sZWFybmluZ19yYXRlGAEg", + "ASgLMiUub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuTGVhcm5pbmdSYXRlEiAK", + "GG1vbWVudHVtX29wdGltaXplcl92YWx1ZRgCIAEoAhINCgVkZWNheRgDIAEo", + "AhIPCgdlcHNpbG9uGAQgASgCInMKEU1vbWVudHVtT3B0aW1pemVyEjwKDWxl", + "YXJuaW5nX3JhdGUYASABKAsyJS5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5M", + "ZWFybmluZ1JhdGUSIAoYbW9tZW50dW1fb3B0aW1pemVyX3ZhbHVlGAIgASgC", + "Ik0KDUFkYW1PcHRpbWl6ZXISPAoNbGVhcm5pbmdfcmF0ZRgBIAEoCzIlLm9i", + "amVjdF9kZXRlY3Rpb24ucHJvdG9zLkxlYXJuaW5nUmF0ZSKAAwoMTGVhcm5p", + 
"bmdSYXRlEk8KFmNvbnN0YW50X2xlYXJuaW5nX3JhdGUYASABKAsyLS5vYmpl", + "Y3RfZGV0ZWN0aW9uLnByb3Rvcy5Db25zdGFudExlYXJuaW5nUmF0ZUgAEmAK", + "H2V4cG9uZW50aWFsX2RlY2F5X2xlYXJuaW5nX3JhdGUYAiABKAsyNS5vYmpl", + "Y3RfZGV0ZWN0aW9uLnByb3Rvcy5FeHBvbmVudGlhbERlY2F5TGVhcm5pbmdS", + "YXRlSAASVAoZbWFudWFsX3N0ZXBfbGVhcm5pbmdfcmF0ZRgDIAEoCzIvLm9i", + "amVjdF9kZXRlY3Rpb24ucHJvdG9zLk1hbnVhbFN0ZXBMZWFybmluZ1JhdGVI", + "ABJWChpjb3NpbmVfZGVjYXlfbGVhcm5pbmdfcmF0ZRgEIAEoCzIwLm9iamVj", + "dF9kZXRlY3Rpb24ucHJvdG9zLkNvc2luZURlY2F5TGVhcm5pbmdSYXRlSABC", + "DwoNbGVhcm5pbmdfcmF0ZSItChRDb25zdGFudExlYXJuaW5nUmF0ZRIVCg1s", + "ZWFybmluZ19yYXRlGAEgASgCIsoBChxFeHBvbmVudGlhbERlY2F5TGVhcm5p", + "bmdSYXRlEh0KFWluaXRpYWxfbGVhcm5pbmdfcmF0ZRgBIAEoAhITCgtkZWNh", + "eV9zdGVwcxgCIAEoDRIUCgxkZWNheV9mYWN0b3IYAyABKAISEQoJc3RhaXJj", + "YXNlGAQgASgIEhwKFGJ1cm5pbl9sZWFybmluZ19yYXRlGAUgASgCEhQKDGJ1", + "cm5pbl9zdGVwcxgGIAEoDRIZChFtaW5fbGVhcm5pbmdfcmF0ZRgHIAEoAiLc", + "AQoWTWFudWFsU3RlcExlYXJuaW5nUmF0ZRIdChVpbml0aWFsX2xlYXJuaW5n", + "X3JhdGUYASABKAISVgoIc2NoZWR1bGUYAiADKAsyRC5vYmplY3RfZGV0ZWN0", + "aW9uLnByb3Rvcy5NYW51YWxTdGVwTGVhcm5pbmdSYXRlLkxlYXJuaW5nUmF0", + "ZVNjaGVkdWxlEg4KBndhcm11cBgDIAEoCBo7ChRMZWFybmluZ1JhdGVTY2hl", + "ZHVsZRIMCgRzdGVwGAEgASgNEhUKDWxlYXJuaW5nX3JhdGUYAiABKAIinAEK", + "F0Nvc2luZURlY2F5TGVhcm5pbmdSYXRlEhoKEmxlYXJuaW5nX3JhdGVfYmFz", + "ZRgBIAEoAhITCgt0b3RhbF9zdGVwcxgCIAEoDRIcChR3YXJtdXBfbGVhcm5p", + "bmdfcmF0ZRgDIAEoAhIUCgx3YXJtdXBfc3RlcHMYBCABKA0SHAoUaG9sZF9i", + "YXNlX3JhdGVfc3RlcHMYBSABKA1iBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.Optimizer), global::Tensorflow.Models.ObjectDetection.Protos.Optimizer.Parser, new[]{ "RmsPropOptimizer", "MomentumOptimizer", "AdamOptimizer", "UseMovingAverage", "MovingAverageDecay" }, new[]{ "Optimizer" }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RMSPropOptimizer), global::Tensorflow.Models.ObjectDetection.Protos.RMSPropOptimizer.Parser, new[]{ "LearningRate", "MomentumOptimizerValue", "Decay", "Epsilon" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.MomentumOptimizer), global::Tensorflow.Models.ObjectDetection.Protos.MomentumOptimizer.Parser, new[]{ "LearningRate", "MomentumOptimizerValue" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.AdamOptimizer), global::Tensorflow.Models.ObjectDetection.Protos.AdamOptimizer.Parser, new[]{ "LearningRate" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.LearningRate), global::Tensorflow.Models.ObjectDetection.Protos.LearningRate.Parser, new[]{ "ConstantLearningRate", "ExponentialDecayLearningRate", "ManualStepLearningRate", "CosineDecayLearningRate" }, new[]{ "LearningRate" }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.ConstantLearningRate), global::Tensorflow.Models.ObjectDetection.Protos.ConstantLearningRate.Parser, new[]{ "LearningRate" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.ExponentialDecayLearningRate), global::Tensorflow.Models.ObjectDetection.Protos.ExponentialDecayLearningRate.Parser, new[]{ "InitialLearningRate", "DecaySteps", "DecayFactor", 
"Staircase", "BurninLearningRate", "BurninSteps", "MinLearningRate" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.ManualStepLearningRate), global::Tensorflow.Models.ObjectDetection.Protos.ManualStepLearningRate.Parser, new[]{ "InitialLearningRate", "Schedule", "Warmup" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.ManualStepLearningRate.Types.LearningRateSchedule), global::Tensorflow.Models.ObjectDetection.Protos.ManualStepLearningRate.Types.LearningRateSchedule.Parser, new[]{ "Step", "LearningRate" }, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.CosineDecayLearningRate), global::Tensorflow.Models.ObjectDetection.Protos.CosineDecayLearningRate.Parser, new[]{ "LearningRateBase", "TotalSteps", "WarmupLearningRate", "WarmupSteps", "HoldBaseRateSteps" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Top level optimizer message. + /// + public sealed partial class Optimizer : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Optimizer()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.OptimizerReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Optimizer() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Optimizer(Optimizer other) : this() { + useMovingAverage_ = other.useMovingAverage_; + movingAverageDecay_ = other.movingAverageDecay_; + switch (other.OptimizerCase) { + case OptimizerOneofCase.RmsPropOptimizer: + RmsPropOptimizer = other.RmsPropOptimizer.Clone(); + break; + case OptimizerOneofCase.MomentumOptimizer: + MomentumOptimizer = other.MomentumOptimizer.Clone(); + break; + case OptimizerOneofCase.AdamOptimizer: + AdamOptimizer = other.AdamOptimizer.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Optimizer Clone() { + return new Optimizer(this); + } + + /// Field number for the "rms_prop_optimizer" field. + public const int RmsPropOptimizerFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RMSPropOptimizer RmsPropOptimizer { + get { return optimizerCase_ == OptimizerOneofCase.RmsPropOptimizer ? (global::Tensorflow.Models.ObjectDetection.Protos.RMSPropOptimizer) optimizer_ : null; } + set { + optimizer_ = value; + optimizerCase_ = value == null ? OptimizerOneofCase.None : OptimizerOneofCase.RmsPropOptimizer; + } + } + + /// Field number for the "momentum_optimizer" field. 
+ public const int MomentumOptimizerFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.MomentumOptimizer MomentumOptimizer { + get { return optimizerCase_ == OptimizerOneofCase.MomentumOptimizer ? (global::Tensorflow.Models.ObjectDetection.Protos.MomentumOptimizer) optimizer_ : null; } + set { + optimizer_ = value; + optimizerCase_ = value == null ? OptimizerOneofCase.None : OptimizerOneofCase.MomentumOptimizer; + } + } + + /// Field number for the "adam_optimizer" field. + public const int AdamOptimizerFieldNumber = 3; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.AdamOptimizer AdamOptimizer { + get { return optimizerCase_ == OptimizerOneofCase.AdamOptimizer ? (global::Tensorflow.Models.ObjectDetection.Protos.AdamOptimizer) optimizer_ : null; } + set { + optimizer_ = value; + optimizerCase_ = value == null ? OptimizerOneofCase.None : OptimizerOneofCase.AdamOptimizer; + } + } + + /// Field number for the "use_moving_average" field. + public const int UseMovingAverageFieldNumber = 4; + private bool useMovingAverage_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseMovingAverage { + get { return useMovingAverage_; } + set { + useMovingAverage_ = value; + } + } + + /// Field number for the "moving_average_decay" field. + public const int MovingAverageDecayFieldNumber = 5; + private float movingAverageDecay_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MovingAverageDecay { + get { return movingAverageDecay_; } + set { + movingAverageDecay_ = value; + } + } + + private object optimizer_; + /// Enum of possible cases for the "optimizer" oneof. 
+ public enum OptimizerOneofCase { + None = 0, + RmsPropOptimizer = 1, + MomentumOptimizer = 2, + AdamOptimizer = 3, + } + private OptimizerOneofCase optimizerCase_ = OptimizerOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public OptimizerOneofCase OptimizerCase { + get { return optimizerCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearOptimizer() { + optimizerCase_ = OptimizerOneofCase.None; + optimizer_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as Optimizer); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(Optimizer other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(RmsPropOptimizer, other.RmsPropOptimizer)) return false; + if (!object.Equals(MomentumOptimizer, other.MomentumOptimizer)) return false; + if (!object.Equals(AdamOptimizer, other.AdamOptimizer)) return false; + if (UseMovingAverage != other.UseMovingAverage) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MovingAverageDecay, other.MovingAverageDecay)) return false; + if (OptimizerCase != other.OptimizerCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (optimizerCase_ == OptimizerOneofCase.RmsPropOptimizer) hash ^= RmsPropOptimizer.GetHashCode(); + if (optimizerCase_ == OptimizerOneofCase.MomentumOptimizer) hash ^= MomentumOptimizer.GetHashCode(); + if (optimizerCase_ == OptimizerOneofCase.AdamOptimizer) hash ^= AdamOptimizer.GetHashCode(); + if (UseMovingAverage != false) hash ^= UseMovingAverage.GetHashCode(); + if (MovingAverageDecay != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MovingAverageDecay); + hash ^= (int) optimizerCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (optimizerCase_ == OptimizerOneofCase.RmsPropOptimizer) { + output.WriteRawTag(10); + output.WriteMessage(RmsPropOptimizer); + } + if (optimizerCase_ == OptimizerOneofCase.MomentumOptimizer) { + output.WriteRawTag(18); + output.WriteMessage(MomentumOptimizer); + } + if (optimizerCase_ == OptimizerOneofCase.AdamOptimizer) { + output.WriteRawTag(26); + output.WriteMessage(AdamOptimizer); + } + if (UseMovingAverage != false) { + output.WriteRawTag(32); + output.WriteBool(UseMovingAverage); + } + if (MovingAverageDecay != 0F) { + output.WriteRawTag(45); + output.WriteFloat(MovingAverageDecay); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (optimizerCase_ == OptimizerOneofCase.RmsPropOptimizer) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RmsPropOptimizer); + } + if (optimizerCase_ == OptimizerOneofCase.MomentumOptimizer) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(MomentumOptimizer); + } + if (optimizerCase_ == 
OptimizerOneofCase.AdamOptimizer) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(AdamOptimizer); + } + if (UseMovingAverage != false) { + size += 1 + 1; + } + if (MovingAverageDecay != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(Optimizer other) { + if (other == null) { + return; + } + if (other.UseMovingAverage != false) { + UseMovingAverage = other.UseMovingAverage; + } + if (other.MovingAverageDecay != 0F) { + MovingAverageDecay = other.MovingAverageDecay; + } + switch (other.OptimizerCase) { + case OptimizerOneofCase.RmsPropOptimizer: + if (RmsPropOptimizer == null) { + RmsPropOptimizer = new global::Tensorflow.Models.ObjectDetection.Protos.RMSPropOptimizer(); + } + RmsPropOptimizer.MergeFrom(other.RmsPropOptimizer); + break; + case OptimizerOneofCase.MomentumOptimizer: + if (MomentumOptimizer == null) { + MomentumOptimizer = new global::Tensorflow.Models.ObjectDetection.Protos.MomentumOptimizer(); + } + MomentumOptimizer.MergeFrom(other.MomentumOptimizer); + break; + case OptimizerOneofCase.AdamOptimizer: + if (AdamOptimizer == null) { + AdamOptimizer = new global::Tensorflow.Models.ObjectDetection.Protos.AdamOptimizer(); + } + AdamOptimizer.MergeFrom(other.AdamOptimizer); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.Models.ObjectDetection.Protos.RMSPropOptimizer subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RMSPropOptimizer(); + if (optimizerCase_ == OptimizerOneofCase.RmsPropOptimizer) { + subBuilder.MergeFrom(RmsPropOptimizer); + } + input.ReadMessage(subBuilder); + RmsPropOptimizer = subBuilder; + break; + } + case 18: { + global::Tensorflow.Models.ObjectDetection.Protos.MomentumOptimizer subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.MomentumOptimizer(); + if (optimizerCase_ == OptimizerOneofCase.MomentumOptimizer) { + subBuilder.MergeFrom(MomentumOptimizer); + } + input.ReadMessage(subBuilder); + MomentumOptimizer = subBuilder; + break; + } + case 26: { + global::Tensorflow.Models.ObjectDetection.Protos.AdamOptimizer subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.AdamOptimizer(); + if (optimizerCase_ == OptimizerOneofCase.AdamOptimizer) { + subBuilder.MergeFrom(AdamOptimizer); + } + input.ReadMessage(subBuilder); + AdamOptimizer = subBuilder; + break; + } + case 32: { + UseMovingAverage = input.ReadBool(); + break; + } + case 45: { + MovingAverageDecay = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Configuration message for the RMSPropOptimizer + /// See: https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer + /// + public sealed partial class RMSPropOptimizer : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RMSPropOptimizer()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static 
pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.OptimizerReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RMSPropOptimizer() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RMSPropOptimizer(RMSPropOptimizer other) : this() { + learningRate_ = other.learningRate_ != null ? other.learningRate_.Clone() : null; + momentumOptimizerValue_ = other.momentumOptimizerValue_; + decay_ = other.decay_; + epsilon_ = other.epsilon_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RMSPropOptimizer Clone() { + return new RMSPropOptimizer(this); + } + + /// Field number for the "learning_rate" field. + public const int LearningRateFieldNumber = 1; + private global::Tensorflow.Models.ObjectDetection.Protos.LearningRate learningRate_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.LearningRate LearningRate { + get { return learningRate_; } + set { + learningRate_ = value; + } + } + + /// Field number for the "momentum_optimizer_value" field. + public const int MomentumOptimizerValueFieldNumber = 2; + private float momentumOptimizerValue_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MomentumOptimizerValue { + get { return momentumOptimizerValue_; } + set { + momentumOptimizerValue_ = value; + } + } + + /// Field number for the "decay" field. + public const int DecayFieldNumber = 3; + private float decay_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Decay { + get { return decay_; } + set { + decay_ = value; + } + } + + /// Field number for the "epsilon" field. 
+ public const int EpsilonFieldNumber = 4; + private float epsilon_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Epsilon { + get { return epsilon_; } + set { + epsilon_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RMSPropOptimizer); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RMSPropOptimizer other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(LearningRate, other.LearningRate)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MomentumOptimizerValue, other.MomentumOptimizerValue)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Decay, other.Decay)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Epsilon, other.Epsilon)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (learningRate_ != null) hash ^= LearningRate.GetHashCode(); + if (MomentumOptimizerValue != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MomentumOptimizerValue); + if (Decay != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Decay); + if (Epsilon != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Epsilon); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (learningRate_ != null) { + output.WriteRawTag(10); + output.WriteMessage(LearningRate); + } + if (MomentumOptimizerValue != 0F) { + output.WriteRawTag(21); + output.WriteFloat(MomentumOptimizerValue); + } + if (Decay != 0F) { + output.WriteRawTag(29); + output.WriteFloat(Decay); + } + if (Epsilon != 0F) { + output.WriteRawTag(37); + output.WriteFloat(Epsilon); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (learningRate_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(LearningRate); + } + if (MomentumOptimizerValue != 0F) { + size += 1 + 4; + } + if (Decay != 0F) { + size += 1 + 4; + } + if (Epsilon != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RMSPropOptimizer other) { + if (other == null) { + return; + } + if (other.learningRate_ != null) { + if (learningRate_ == null) { + learningRate_ = new global::Tensorflow.Models.ObjectDetection.Protos.LearningRate(); + } + LearningRate.MergeFrom(other.LearningRate); + } + if (other.MomentumOptimizerValue != 0F) { + MomentumOptimizerValue = other.MomentumOptimizerValue; + } + if (other.Decay != 0F) { + Decay = other.Decay; + } + if (other.Epsilon != 0F) { + Epsilon = other.Epsilon; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, 
other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (learningRate_ == null) { + learningRate_ = new global::Tensorflow.Models.ObjectDetection.Protos.LearningRate(); + } + input.ReadMessage(learningRate_); + break; + } + case 21: { + MomentumOptimizerValue = input.ReadFloat(); + break; + } + case 29: { + Decay = input.ReadFloat(); + break; + } + case 37: { + Epsilon = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Configuration message for the MomentumOptimizer + /// See: https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer + /// + public sealed partial class MomentumOptimizer : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MomentumOptimizer()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.OptimizerReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MomentumOptimizer() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MomentumOptimizer(MomentumOptimizer other) : this() { + learningRate_ = other.learningRate_ != null ? other.learningRate_.Clone() : null; + momentumOptimizerValue_ = other.momentumOptimizerValue_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MomentumOptimizer Clone() { + return new MomentumOptimizer(this); + } + + /// Field number for the "learning_rate" field. + public const int LearningRateFieldNumber = 1; + private global::Tensorflow.Models.ObjectDetection.Protos.LearningRate learningRate_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.LearningRate LearningRate { + get { return learningRate_; } + set { + learningRate_ = value; + } + } + + /// Field number for the "momentum_optimizer_value" field. 
+ public const int MomentumOptimizerValueFieldNumber = 2; + private float momentumOptimizerValue_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MomentumOptimizerValue { + get { return momentumOptimizerValue_; } + set { + momentumOptimizerValue_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as MomentumOptimizer); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(MomentumOptimizer other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(LearningRate, other.LearningRate)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MomentumOptimizerValue, other.MomentumOptimizerValue)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (learningRate_ != null) hash ^= LearningRate.GetHashCode(); + if (MomentumOptimizerValue != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MomentumOptimizerValue); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (learningRate_ != null) { + output.WriteRawTag(10); + output.WriteMessage(LearningRate); + } + if (MomentumOptimizerValue != 0F) { + output.WriteRawTag(21); + output.WriteFloat(MomentumOptimizerValue); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (learningRate_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(LearningRate); + } + if (MomentumOptimizerValue != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(MomentumOptimizer other) { + if (other == null) { + return; + } + if (other.learningRate_ != null) { + if (learningRate_ == null) { + learningRate_ = new global::Tensorflow.Models.ObjectDetection.Protos.LearningRate(); + } + LearningRate.MergeFrom(other.LearningRate); + } + if (other.MomentumOptimizerValue != 0F) { + MomentumOptimizerValue = other.MomentumOptimizerValue; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (learningRate_ == null) { + learningRate_ = new global::Tensorflow.Models.ObjectDetection.Protos.LearningRate(); + } + input.ReadMessage(learningRate_); + break; + } + case 21: { + MomentumOptimizerValue = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Configuration message for the AdamOptimizer + /// See: https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer + /// + 
public sealed partial class AdamOptimizer : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AdamOptimizer()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.OptimizerReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AdamOptimizer() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AdamOptimizer(AdamOptimizer other) : this() { + learningRate_ = other.learningRate_ != null ? other.learningRate_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AdamOptimizer Clone() { + return new AdamOptimizer(this); + } + + /// Field number for the "learning_rate" field. + public const int LearningRateFieldNumber = 1; + private global::Tensorflow.Models.ObjectDetection.Protos.LearningRate learningRate_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.LearningRate LearningRate { + get { return learningRate_; } + set { + learningRate_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as AdamOptimizer); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(AdamOptimizer other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(LearningRate, other.LearningRate)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (learningRate_ != null) hash ^= LearningRate.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (learningRate_ != null) { + output.WriteRawTag(10); + output.WriteMessage(LearningRate); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (learningRate_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(LearningRate); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(AdamOptimizer other) { + if (other == null) { + return; + } + if (other.learningRate_ != null) { + if (learningRate_ == null) { + learningRate_ = new global::Tensorflow.Models.ObjectDetection.Protos.LearningRate(); + } + 
LearningRate.MergeFrom(other.LearningRate); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (learningRate_ == null) { + learningRate_ = new global::Tensorflow.Models.ObjectDetection.Protos.LearningRate(); + } + input.ReadMessage(learningRate_); + break; + } + } + } + } + + } + + /// + /// Configuration message for optimizer learning rate. + /// + public sealed partial class LearningRate : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new LearningRate()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.OptimizerReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public LearningRate() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public LearningRate(LearningRate other) : this() { + switch (other.LearningRateCase) { + case LearningRateOneofCase.ConstantLearningRate: + ConstantLearningRate = other.ConstantLearningRate.Clone(); + break; + case LearningRateOneofCase.ExponentialDecayLearningRate: + ExponentialDecayLearningRate = other.ExponentialDecayLearningRate.Clone(); + break; + case LearningRateOneofCase.ManualStepLearningRate: + ManualStepLearningRate = other.ManualStepLearningRate.Clone(); + break; + case LearningRateOneofCase.CosineDecayLearningRate: + CosineDecayLearningRate = other.CosineDecayLearningRate.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public LearningRate Clone() { + return new LearningRate(this); + } + + /// Field number for the "constant_learning_rate" field. + public const int ConstantLearningRateFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ConstantLearningRate ConstantLearningRate { + get { return learningRateCase_ == LearningRateOneofCase.ConstantLearningRate ? (global::Tensorflow.Models.ObjectDetection.Protos.ConstantLearningRate) learningRate_ : null; } + set { + learningRate_ = value; + learningRateCase_ = value == null ? LearningRateOneofCase.None : LearningRateOneofCase.ConstantLearningRate; + } + } + + /// Field number for the "exponential_decay_learning_rate" field. + public const int ExponentialDecayLearningRateFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ExponentialDecayLearningRate ExponentialDecayLearningRate { + get { return learningRateCase_ == LearningRateOneofCase.ExponentialDecayLearningRate ? 
(global::Tensorflow.Models.ObjectDetection.Protos.ExponentialDecayLearningRate) learningRate_ : null; } + set { + learningRate_ = value; + learningRateCase_ = value == null ? LearningRateOneofCase.None : LearningRateOneofCase.ExponentialDecayLearningRate; + } + } + + /// Field number for the "manual_step_learning_rate" field. + public const int ManualStepLearningRateFieldNumber = 3; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ManualStepLearningRate ManualStepLearningRate { + get { return learningRateCase_ == LearningRateOneofCase.ManualStepLearningRate ? (global::Tensorflow.Models.ObjectDetection.Protos.ManualStepLearningRate) learningRate_ : null; } + set { + learningRate_ = value; + learningRateCase_ = value == null ? LearningRateOneofCase.None : LearningRateOneofCase.ManualStepLearningRate; + } + } + + /// Field number for the "cosine_decay_learning_rate" field. + public const int CosineDecayLearningRateFieldNumber = 4; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.CosineDecayLearningRate CosineDecayLearningRate { + get { return learningRateCase_ == LearningRateOneofCase.CosineDecayLearningRate ? (global::Tensorflow.Models.ObjectDetection.Protos.CosineDecayLearningRate) learningRate_ : null; } + set { + learningRate_ = value; + learningRateCase_ = value == null ? LearningRateOneofCase.None : LearningRateOneofCase.CosineDecayLearningRate; + } + } + + private object learningRate_; + /// Enum of possible cases for the "learning_rate" oneof. + public enum LearningRateOneofCase { + None = 0, + ConstantLearningRate = 1, + ExponentialDecayLearningRate = 2, + ManualStepLearningRate = 3, + CosineDecayLearningRate = 4, + } + private LearningRateOneofCase learningRateCase_ = LearningRateOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public LearningRateOneofCase LearningRateCase { + get { return learningRateCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearLearningRate() { + learningRateCase_ = LearningRateOneofCase.None; + learningRate_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as LearningRate); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(LearningRate other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(ConstantLearningRate, other.ConstantLearningRate)) return false; + if (!object.Equals(ExponentialDecayLearningRate, other.ExponentialDecayLearningRate)) return false; + if (!object.Equals(ManualStepLearningRate, other.ManualStepLearningRate)) return false; + if (!object.Equals(CosineDecayLearningRate, other.CosineDecayLearningRate)) return false; + if (LearningRateCase != other.LearningRateCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (learningRateCase_ == LearningRateOneofCase.ConstantLearningRate) hash ^= ConstantLearningRate.GetHashCode(); + if (learningRateCase_ == LearningRateOneofCase.ExponentialDecayLearningRate) hash ^= ExponentialDecayLearningRate.GetHashCode(); + if (learningRateCase_ == LearningRateOneofCase.ManualStepLearningRate) hash ^= ManualStepLearningRate.GetHashCode(); + if 
(learningRateCase_ == LearningRateOneofCase.CosineDecayLearningRate) hash ^= CosineDecayLearningRate.GetHashCode(); + hash ^= (int) learningRateCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (learningRateCase_ == LearningRateOneofCase.ConstantLearningRate) { + output.WriteRawTag(10); + output.WriteMessage(ConstantLearningRate); + } + if (learningRateCase_ == LearningRateOneofCase.ExponentialDecayLearningRate) { + output.WriteRawTag(18); + output.WriteMessage(ExponentialDecayLearningRate); + } + if (learningRateCase_ == LearningRateOneofCase.ManualStepLearningRate) { + output.WriteRawTag(26); + output.WriteMessage(ManualStepLearningRate); + } + if (learningRateCase_ == LearningRateOneofCase.CosineDecayLearningRate) { + output.WriteRawTag(34); + output.WriteMessage(CosineDecayLearningRate); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (learningRateCase_ == LearningRateOneofCase.ConstantLearningRate) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ConstantLearningRate); + } + if (learningRateCase_ == LearningRateOneofCase.ExponentialDecayLearningRate) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ExponentialDecayLearningRate); + } + if (learningRateCase_ == LearningRateOneofCase.ManualStepLearningRate) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ManualStepLearningRate); + } + if (learningRateCase_ == LearningRateOneofCase.CosineDecayLearningRate) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(CosineDecayLearningRate); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(LearningRate other) { + if (other == null) { + return; + } + switch (other.LearningRateCase) { + case LearningRateOneofCase.ConstantLearningRate: + if (ConstantLearningRate == null) { + ConstantLearningRate = new global::Tensorflow.Models.ObjectDetection.Protos.ConstantLearningRate(); + } + ConstantLearningRate.MergeFrom(other.ConstantLearningRate); + break; + case LearningRateOneofCase.ExponentialDecayLearningRate: + if (ExponentialDecayLearningRate == null) { + ExponentialDecayLearningRate = new global::Tensorflow.Models.ObjectDetection.Protos.ExponentialDecayLearningRate(); + } + ExponentialDecayLearningRate.MergeFrom(other.ExponentialDecayLearningRate); + break; + case LearningRateOneofCase.ManualStepLearningRate: + if (ManualStepLearningRate == null) { + ManualStepLearningRate = new global::Tensorflow.Models.ObjectDetection.Protos.ManualStepLearningRate(); + } + ManualStepLearningRate.MergeFrom(other.ManualStepLearningRate); + break; + case LearningRateOneofCase.CosineDecayLearningRate: + if (CosineDecayLearningRate == null) { + CosineDecayLearningRate = new global::Tensorflow.Models.ObjectDetection.Protos.CosineDecayLearningRate(); + } + CosineDecayLearningRate.MergeFrom(other.CosineDecayLearningRate); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public 
void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.Models.ObjectDetection.Protos.ConstantLearningRate subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.ConstantLearningRate(); + if (learningRateCase_ == LearningRateOneofCase.ConstantLearningRate) { + subBuilder.MergeFrom(ConstantLearningRate); + } + input.ReadMessage(subBuilder); + ConstantLearningRate = subBuilder; + break; + } + case 18: { + global::Tensorflow.Models.ObjectDetection.Protos.ExponentialDecayLearningRate subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.ExponentialDecayLearningRate(); + if (learningRateCase_ == LearningRateOneofCase.ExponentialDecayLearningRate) { + subBuilder.MergeFrom(ExponentialDecayLearningRate); + } + input.ReadMessage(subBuilder); + ExponentialDecayLearningRate = subBuilder; + break; + } + case 26: { + global::Tensorflow.Models.ObjectDetection.Protos.ManualStepLearningRate subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.ManualStepLearningRate(); + if (learningRateCase_ == LearningRateOneofCase.ManualStepLearningRate) { + subBuilder.MergeFrom(ManualStepLearningRate); + } + input.ReadMessage(subBuilder); + ManualStepLearningRate = subBuilder; + break; + } + case 34: { + global::Tensorflow.Models.ObjectDetection.Protos.CosineDecayLearningRate subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.CosineDecayLearningRate(); + if (learningRateCase_ == LearningRateOneofCase.CosineDecayLearningRate) { + subBuilder.MergeFrom(CosineDecayLearningRate); + } + input.ReadMessage(subBuilder); + CosineDecayLearningRate = subBuilder; + break; + } + } + } + } + + } + + /// + /// Configuration message for a constant learning rate. + /// + public sealed partial class ConstantLearningRate : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ConstantLearningRate()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.OptimizerReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ConstantLearningRate() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ConstantLearningRate(ConstantLearningRate other) : this() { + learningRate_ = other.learningRate_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ConstantLearningRate Clone() { + return new ConstantLearningRate(this); + } + + /// Field number for the "learning_rate" field. 
+ public const int LearningRateFieldNumber = 1; + private float learningRate_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float LearningRate { + get { return learningRate_; } + set { + learningRate_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ConstantLearningRate); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ConstantLearningRate other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(LearningRate, other.LearningRate)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (LearningRate != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(LearningRate); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (LearningRate != 0F) { + output.WriteRawTag(13); + output.WriteFloat(LearningRate); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (LearningRate != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ConstantLearningRate other) { + if (other == null) { + return; + } + if (other.LearningRate != 0F) { + LearningRate = other.LearningRate; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + LearningRate = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Configuration message for an exponentially decaying learning rate. 
+ /// See https://www.tensorflow.org/versions/master/api_docs/python/train/ \ + /// decaying_the_learning_rate#exponential_decay + /// + public sealed partial class ExponentialDecayLearningRate : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ExponentialDecayLearningRate()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.OptimizerReflection.Descriptor.MessageTypes[6]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ExponentialDecayLearningRate() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ExponentialDecayLearningRate(ExponentialDecayLearningRate other) : this() { + initialLearningRate_ = other.initialLearningRate_; + decaySteps_ = other.decaySteps_; + decayFactor_ = other.decayFactor_; + staircase_ = other.staircase_; + burninLearningRate_ = other.burninLearningRate_; + burninSteps_ = other.burninSteps_; + minLearningRate_ = other.minLearningRate_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ExponentialDecayLearningRate Clone() { + return new ExponentialDecayLearningRate(this); + } + + /// Field number for the "initial_learning_rate" field. + public const int InitialLearningRateFieldNumber = 1; + private float initialLearningRate_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float InitialLearningRate { + get { return initialLearningRate_; } + set { + initialLearningRate_ = value; + } + } + + /// Field number for the "decay_steps" field. + public const int DecayStepsFieldNumber = 2; + private uint decaySteps_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint DecaySteps { + get { return decaySteps_; } + set { + decaySteps_ = value; + } + } + + /// Field number for the "decay_factor" field. + public const int DecayFactorFieldNumber = 3; + private float decayFactor_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float DecayFactor { + get { return decayFactor_; } + set { + decayFactor_ = value; + } + } + + /// Field number for the "staircase" field. + public const int StaircaseFieldNumber = 4; + private bool staircase_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Staircase { + get { return staircase_; } + set { + staircase_ = value; + } + } + + /// Field number for the "burnin_learning_rate" field. + public const int BurninLearningRateFieldNumber = 5; + private float burninLearningRate_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float BurninLearningRate { + get { return burninLearningRate_; } + set { + burninLearningRate_ = value; + } + } + + /// Field number for the "burnin_steps" field. 
+ public const int BurninStepsFieldNumber = 6; + private uint burninSteps_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint BurninSteps { + get { return burninSteps_; } + set { + burninSteps_ = value; + } + } + + /// Field number for the "min_learning_rate" field. + public const int MinLearningRateFieldNumber = 7; + private float minLearningRate_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinLearningRate { + get { return minLearningRate_; } + set { + minLearningRate_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ExponentialDecayLearningRate); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ExponentialDecayLearningRate other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(InitialLearningRate, other.InitialLearningRate)) return false; + if (DecaySteps != other.DecaySteps) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(DecayFactor, other.DecayFactor)) return false; + if (Staircase != other.Staircase) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(BurninLearningRate, other.BurninLearningRate)) return false; + if (BurninSteps != other.BurninSteps) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinLearningRate, other.MinLearningRate)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (InitialLearningRate != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(InitialLearningRate); + if (DecaySteps != 0) hash ^= DecaySteps.GetHashCode(); + if (DecayFactor != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(DecayFactor); + if (Staircase != false) hash ^= Staircase.GetHashCode(); + if (BurninLearningRate != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(BurninLearningRate); + if (BurninSteps != 0) hash ^= BurninSteps.GetHashCode(); + if (MinLearningRate != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinLearningRate); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (InitialLearningRate != 0F) { + output.WriteRawTag(13); + output.WriteFloat(InitialLearningRate); + } + if (DecaySteps != 0) { + output.WriteRawTag(16); + output.WriteUInt32(DecaySteps); + } + if (DecayFactor != 0F) { + output.WriteRawTag(29); + output.WriteFloat(DecayFactor); + } + if (Staircase != false) { + output.WriteRawTag(32); + output.WriteBool(Staircase); + } + if (BurninLearningRate != 0F) { + output.WriteRawTag(45); + output.WriteFloat(BurninLearningRate); + } + if (BurninSteps != 0) { + output.WriteRawTag(48); + output.WriteUInt32(BurninSteps); + } + if (MinLearningRate != 0F) { + output.WriteRawTag(61); + output.WriteFloat(MinLearningRate); + } + if 
(_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (InitialLearningRate != 0F) { + size += 1 + 4; + } + if (DecaySteps != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(DecaySteps); + } + if (DecayFactor != 0F) { + size += 1 + 4; + } + if (Staircase != false) { + size += 1 + 1; + } + if (BurninLearningRate != 0F) { + size += 1 + 4; + } + if (BurninSteps != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(BurninSteps); + } + if (MinLearningRate != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ExponentialDecayLearningRate other) { + if (other == null) { + return; + } + if (other.InitialLearningRate != 0F) { + InitialLearningRate = other.InitialLearningRate; + } + if (other.DecaySteps != 0) { + DecaySteps = other.DecaySteps; + } + if (other.DecayFactor != 0F) { + DecayFactor = other.DecayFactor; + } + if (other.Staircase != false) { + Staircase = other.Staircase; + } + if (other.BurninLearningRate != 0F) { + BurninLearningRate = other.BurninLearningRate; + } + if (other.BurninSteps != 0) { + BurninSteps = other.BurninSteps; + } + if (other.MinLearningRate != 0F) { + MinLearningRate = other.MinLearningRate; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + InitialLearningRate = input.ReadFloat(); + break; + } + case 16: { + DecaySteps = input.ReadUInt32(); + break; + } + case 29: { + DecayFactor = input.ReadFloat(); + break; + } + case 32: { + Staircase = input.ReadBool(); + break; + } + case 45: { + BurninLearningRate = input.ReadFloat(); + break; + } + case 48: { + BurninSteps = input.ReadUInt32(); + break; + } + case 61: { + MinLearningRate = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Configuration message for a manually defined learning rate schedule. 
+ /// + public sealed partial class ManualStepLearningRate : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ManualStepLearningRate()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.OptimizerReflection.Descriptor.MessageTypes[7]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ManualStepLearningRate() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ManualStepLearningRate(ManualStepLearningRate other) : this() { + initialLearningRate_ = other.initialLearningRate_; + schedule_ = other.schedule_.Clone(); + warmup_ = other.warmup_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ManualStepLearningRate Clone() { + return new ManualStepLearningRate(this); + } + + /// Field number for the "initial_learning_rate" field. + public const int InitialLearningRateFieldNumber = 1; + private float initialLearningRate_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float InitialLearningRate { + get { return initialLearningRate_; } + set { + initialLearningRate_ = value; + } + } + + /// Field number for the "schedule" field. + public const int ScheduleFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_schedule_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.Models.ObjectDetection.Protos.ManualStepLearningRate.Types.LearningRateSchedule.Parser); + private readonly pbc::RepeatedField schedule_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Schedule { + get { return schedule_; } + } + + /// Field number for the "warmup" field. + public const int WarmupFieldNumber = 3; + private bool warmup_; + /// + /// Whether to linearly interpolate learning rates for steps in + /// [0, schedule[0].step]. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Warmup { + get { return warmup_; } + set { + warmup_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ManualStepLearningRate); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ManualStepLearningRate other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(InitialLearningRate, other.InitialLearningRate)) return false; + if(!schedule_.Equals(other.schedule_)) return false; + if (Warmup != other.Warmup) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (InitialLearningRate != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(InitialLearningRate); + hash ^= schedule_.GetHashCode(); + if (Warmup != false) hash ^= Warmup.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (InitialLearningRate != 0F) { + output.WriteRawTag(13); + output.WriteFloat(InitialLearningRate); + } + schedule_.WriteTo(output, _repeated_schedule_codec); + if (Warmup != false) { + output.WriteRawTag(24); + output.WriteBool(Warmup); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (InitialLearningRate != 0F) { + size += 1 + 4; + } + size += schedule_.CalculateSize(_repeated_schedule_codec); + if (Warmup != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ManualStepLearningRate other) { + if (other == null) { + return; + } + if (other.InitialLearningRate != 0F) { + InitialLearningRate = other.InitialLearningRate; + } + schedule_.Add(other.schedule_); + if (other.Warmup != false) { + Warmup = other.Warmup; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + InitialLearningRate = input.ReadFloat(); + break; + } + case 18: { + schedule_.AddEntriesFrom(input, _repeated_schedule_codec); + break; + } + case 24: { + Warmup = input.ReadBool(); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the ManualStepLearningRate message type. 
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + public sealed partial class LearningRateSchedule : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new LearningRateSchedule()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.ManualStepLearningRate.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public LearningRateSchedule() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public LearningRateSchedule(LearningRateSchedule other) : this() { + step_ = other.step_; + learningRate_ = other.learningRate_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public LearningRateSchedule Clone() { + return new LearningRateSchedule(this); + } + + /// Field number for the "step" field. + public const int StepFieldNumber = 1; + private uint step_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint Step { + get { return step_; } + set { + step_ = value; + } + } + + /// Field number for the "learning_rate" field. + public const int LearningRateFieldNumber = 2; + private float learningRate_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float LearningRate { + get { return learningRate_; } + set { + learningRate_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as LearningRateSchedule); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(LearningRateSchedule other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Step != other.Step) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(LearningRate, other.LearningRate)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Step != 0) hash ^= Step.GetHashCode(); + if (LearningRate != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(LearningRate); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Step != 0) { + output.WriteRawTag(8); + output.WriteUInt32(Step); + } + if (LearningRate != 0F) { + output.WriteRawTag(21); + output.WriteFloat(LearningRate); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if 
(Step != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(Step); + } + if (LearningRate != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(LearningRateSchedule other) { + if (other == null) { + return; + } + if (other.Step != 0) { + Step = other.Step; + } + if (other.LearningRate != 0F) { + LearningRate = other.LearningRate; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Step = input.ReadUInt32(); + break; + } + case 21: { + LearningRate = input.ReadFloat(); + break; + } + } + } + } + + } + + } + #endregion + + } + + /// + /// Configuration message for a cosine decaying learning rate as defined in + /// object_detection/utils/learning_schedules.py + /// + public sealed partial class CosineDecayLearningRate : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CosineDecayLearningRate()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.OptimizerReflection.Descriptor.MessageTypes[8]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CosineDecayLearningRate() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CosineDecayLearningRate(CosineDecayLearningRate other) : this() { + learningRateBase_ = other.learningRateBase_; + totalSteps_ = other.totalSteps_; + warmupLearningRate_ = other.warmupLearningRate_; + warmupSteps_ = other.warmupSteps_; + holdBaseRateSteps_ = other.holdBaseRateSteps_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CosineDecayLearningRate Clone() { + return new CosineDecayLearningRate(this); + } + + /// Field number for the "learning_rate_base" field. + public const int LearningRateBaseFieldNumber = 1; + private float learningRateBase_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float LearningRateBase { + get { return learningRateBase_; } + set { + learningRateBase_ = value; + } + } + + /// Field number for the "total_steps" field. + public const int TotalStepsFieldNumber = 2; + private uint totalSteps_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint TotalSteps { + get { return totalSteps_; } + set { + totalSteps_ = value; + } + } + + /// Field number for the "warmup_learning_rate" field. 
+ public const int WarmupLearningRateFieldNumber = 3; + private float warmupLearningRate_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float WarmupLearningRate { + get { return warmupLearningRate_; } + set { + warmupLearningRate_ = value; + } + } + + /// Field number for the "warmup_steps" field. + public const int WarmupStepsFieldNumber = 4; + private uint warmupSteps_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint WarmupSteps { + get { return warmupSteps_; } + set { + warmupSteps_ = value; + } + } + + /// Field number for the "hold_base_rate_steps" field. + public const int HoldBaseRateStepsFieldNumber = 5; + private uint holdBaseRateSteps_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint HoldBaseRateSteps { + get { return holdBaseRateSteps_; } + set { + holdBaseRateSteps_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as CosineDecayLearningRate); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(CosineDecayLearningRate other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(LearningRateBase, other.LearningRateBase)) return false; + if (TotalSteps != other.TotalSteps) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(WarmupLearningRate, other.WarmupLearningRate)) return false; + if (WarmupSteps != other.WarmupSteps) return false; + if (HoldBaseRateSteps != other.HoldBaseRateSteps) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (LearningRateBase != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(LearningRateBase); + if (TotalSteps != 0) hash ^= TotalSteps.GetHashCode(); + if (WarmupLearningRate != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(WarmupLearningRate); + if (WarmupSteps != 0) hash ^= WarmupSteps.GetHashCode(); + if (HoldBaseRateSteps != 0) hash ^= HoldBaseRateSteps.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (LearningRateBase != 0F) { + output.WriteRawTag(13); + output.WriteFloat(LearningRateBase); + } + if (TotalSteps != 0) { + output.WriteRawTag(16); + output.WriteUInt32(TotalSteps); + } + if (WarmupLearningRate != 0F) { + output.WriteRawTag(29); + output.WriteFloat(WarmupLearningRate); + } + if (WarmupSteps != 0) { + output.WriteRawTag(32); + output.WriteUInt32(WarmupSteps); + } + if (HoldBaseRateSteps != 0) { + output.WriteRawTag(40); + output.WriteUInt32(HoldBaseRateSteps); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (LearningRateBase != 0F) { + size += 1 + 4; + } + if (TotalSteps != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(TotalSteps); + } + if 
(WarmupLearningRate != 0F) { + size += 1 + 4; + } + if (WarmupSteps != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(WarmupSteps); + } + if (HoldBaseRateSteps != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(HoldBaseRateSteps); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(CosineDecayLearningRate other) { + if (other == null) { + return; + } + if (other.LearningRateBase != 0F) { + LearningRateBase = other.LearningRateBase; + } + if (other.TotalSteps != 0) { + TotalSteps = other.TotalSteps; + } + if (other.WarmupLearningRate != 0F) { + WarmupLearningRate = other.WarmupLearningRate; + } + if (other.WarmupSteps != 0) { + WarmupSteps = other.WarmupSteps; + } + if (other.HoldBaseRateSteps != 0) { + HoldBaseRateSteps = other.HoldBaseRateSteps; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + LearningRateBase = input.ReadFloat(); + break; + } + case 16: { + TotalSteps = input.ReadUInt32(); + break; + } + case 29: { + WarmupLearningRate = input.ReadFloat(); + break; + } + case 32: { + WarmupSteps = input.ReadUInt32(); + break; + } + case 40: { + HoldBaseRateSteps = input.ReadUInt32(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/Pipeline.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/Pipeline.cs new file mode 100644 index 00000000..d5bef7cb --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/Pipeline.cs @@ -0,0 +1,352 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: object_detection/protos/pipeline.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/pipeline.proto + public static partial class PipelineReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/pipeline.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static PipelineReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiZvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9waXBlbGluZS5wcm90bxIXb2Jq", + "ZWN0X2RldGVjdGlvbi5wcm90b3MaIm9iamVjdF9kZXRlY3Rpb24vcHJvdG9z", + "L2V2YWwucHJvdG8aLG9iamVjdF9kZXRlY3Rpb24vcHJvdG9zL2dyYXBoX3Jl", + "d3JpdGVyLnByb3RvGipvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9pbnB1dF9y", + "ZWFkZXIucHJvdG8aI29iamVjdF9kZXRlY3Rpb24vcHJvdG9zL21vZGVsLnBy", + "b3RvGiNvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy90cmFpbi5wcm90byKKAwoX", + "VHJhaW5FdmFsUGlwZWxpbmVDb25maWcSNgoFbW9kZWwYASABKAsyJy5vYmpl", + "Y3RfZGV0ZWN0aW9uLnByb3Rvcy5EZXRlY3Rpb25Nb2RlbBI6Cgx0cmFpbl9j", + "b25maWcYAiABKAsyJC5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5UcmFpbkNv", + "bmZpZxJAChJ0cmFpbl9pbnB1dF9yZWFkZXIYAyABKAsyJC5vYmplY3RfZGV0", + "ZWN0aW9uLnByb3Rvcy5JbnB1dFJlYWRlchI4CgtldmFsX2NvbmZpZxgEIAEo", + "CzIjLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLkV2YWxDb25maWcSPwoRZXZh", + "bF9pbnB1dF9yZWFkZXIYBSADKAsyJC5vYmplY3RfZGV0ZWN0aW9uLnByb3Rv", + "cy5JbnB1dFJlYWRlchI+Cg5ncmFwaF9yZXdyaXRlchgGIAEoCzImLm9iamVj", + "dF9kZXRlY3Rpb24ucHJvdG9zLkdyYXBoUmV3cml0ZXJiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.Models.ObjectDetection.Protos.EvalReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.GraphRewriterReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.InputReaderReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.ModelReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.TrainReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.TrainEvalPipelineConfig), global::Tensorflow.Models.ObjectDetection.Protos.TrainEvalPipelineConfig.Parser, new[]{ "Model", "TrainConfig", "TrainInputReader", "EvalConfig", "EvalInputReader", "GraphRewriter" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Convenience message for configuring a training and eval pipeline. Allows all + /// of the pipeline parameters to be configured from one file. 
+ /// Next id: 7 + /// + public sealed partial class TrainEvalPipelineConfig : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TrainEvalPipelineConfig()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PipelineReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainEvalPipelineConfig() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainEvalPipelineConfig(TrainEvalPipelineConfig other) : this() { + model_ = other.model_ != null ? other.model_.Clone() : null; + trainConfig_ = other.trainConfig_ != null ? other.trainConfig_.Clone() : null; + trainInputReader_ = other.trainInputReader_ != null ? other.trainInputReader_.Clone() : null; + evalConfig_ = other.evalConfig_ != null ? other.evalConfig_.Clone() : null; + evalInputReader_ = other.evalInputReader_.Clone(); + graphRewriter_ = other.graphRewriter_ != null ? other.graphRewriter_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainEvalPipelineConfig Clone() { + return new TrainEvalPipelineConfig(this); + } + + /// Field number for the "model" field. + public const int ModelFieldNumber = 1; + private global::Tensorflow.Models.ObjectDetection.Protos.DetectionModel model_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.DetectionModel Model { + get { return model_; } + set { + model_ = value; + } + } + + /// Field number for the "train_config" field. + public const int TrainConfigFieldNumber = 2; + private global::Tensorflow.Models.ObjectDetection.Protos.TrainConfig trainConfig_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.TrainConfig TrainConfig { + get { return trainConfig_; } + set { + trainConfig_ = value; + } + } + + /// Field number for the "train_input_reader" field. + public const int TrainInputReaderFieldNumber = 3; + private global::Tensorflow.Models.ObjectDetection.Protos.InputReader trainInputReader_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.InputReader TrainInputReader { + get { return trainInputReader_; } + set { + trainInputReader_ = value; + } + } + + /// Field number for the "eval_config" field. + public const int EvalConfigFieldNumber = 4; + private global::Tensorflow.Models.ObjectDetection.Protos.EvalConfig evalConfig_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.EvalConfig EvalConfig { + get { return evalConfig_; } + set { + evalConfig_ = value; + } + } + + /// Field number for the "eval_input_reader" field. 
+ public const int EvalInputReaderFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_evalInputReader_codec + = pb::FieldCodec.ForMessage(42, global::Tensorflow.Models.ObjectDetection.Protos.InputReader.Parser); + private readonly pbc::RepeatedField evalInputReader_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField EvalInputReader { + get { return evalInputReader_; } + } + + /// Field number for the "graph_rewriter" field. + public const int GraphRewriterFieldNumber = 6; + private global::Tensorflow.Models.ObjectDetection.Protos.GraphRewriter graphRewriter_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.GraphRewriter GraphRewriter { + get { return graphRewriter_; } + set { + graphRewriter_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as TrainEvalPipelineConfig); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(TrainEvalPipelineConfig other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(Model, other.Model)) return false; + if (!object.Equals(TrainConfig, other.TrainConfig)) return false; + if (!object.Equals(TrainInputReader, other.TrainInputReader)) return false; + if (!object.Equals(EvalConfig, other.EvalConfig)) return false; + if(!evalInputReader_.Equals(other.evalInputReader_)) return false; + if (!object.Equals(GraphRewriter, other.GraphRewriter)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (model_ != null) hash ^= Model.GetHashCode(); + if (trainConfig_ != null) hash ^= TrainConfig.GetHashCode(); + if (trainInputReader_ != null) hash ^= TrainInputReader.GetHashCode(); + if (evalConfig_ != null) hash ^= EvalConfig.GetHashCode(); + hash ^= evalInputReader_.GetHashCode(); + if (graphRewriter_ != null) hash ^= GraphRewriter.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (model_ != null) { + output.WriteRawTag(10); + output.WriteMessage(Model); + } + if (trainConfig_ != null) { + output.WriteRawTag(18); + output.WriteMessage(TrainConfig); + } + if (trainInputReader_ != null) { + output.WriteRawTag(26); + output.WriteMessage(TrainInputReader); + } + if (evalConfig_ != null) { + output.WriteRawTag(34); + output.WriteMessage(EvalConfig); + } + evalInputReader_.WriteTo(output, _repeated_evalInputReader_codec); + if (graphRewriter_ != null) { + output.WriteRawTag(50); + output.WriteMessage(GraphRewriter); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (model_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Model); + } + if (trainConfig_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(TrainConfig); + } + if (trainInputReader_ != null) { + size += 1 + 
pb::CodedOutputStream.ComputeMessageSize(TrainInputReader); + } + if (evalConfig_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(EvalConfig); + } + size += evalInputReader_.CalculateSize(_repeated_evalInputReader_codec); + if (graphRewriter_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(GraphRewriter); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(TrainEvalPipelineConfig other) { + if (other == null) { + return; + } + if (other.model_ != null) { + if (model_ == null) { + model_ = new global::Tensorflow.Models.ObjectDetection.Protos.DetectionModel(); + } + Model.MergeFrom(other.Model); + } + if (other.trainConfig_ != null) { + if (trainConfig_ == null) { + trainConfig_ = new global::Tensorflow.Models.ObjectDetection.Protos.TrainConfig(); + } + TrainConfig.MergeFrom(other.TrainConfig); + } + if (other.trainInputReader_ != null) { + if (trainInputReader_ == null) { + trainInputReader_ = new global::Tensorflow.Models.ObjectDetection.Protos.InputReader(); + } + TrainInputReader.MergeFrom(other.TrainInputReader); + } + if (other.evalConfig_ != null) { + if (evalConfig_ == null) { + evalConfig_ = new global::Tensorflow.Models.ObjectDetection.Protos.EvalConfig(); + } + EvalConfig.MergeFrom(other.EvalConfig); + } + evalInputReader_.Add(other.evalInputReader_); + if (other.graphRewriter_ != null) { + if (graphRewriter_ == null) { + graphRewriter_ = new global::Tensorflow.Models.ObjectDetection.Protos.GraphRewriter(); + } + GraphRewriter.MergeFrom(other.GraphRewriter); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + if (model_ == null) { + model_ = new global::Tensorflow.Models.ObjectDetection.Protos.DetectionModel(); + } + input.ReadMessage(model_); + break; + } + case 18: { + if (trainConfig_ == null) { + trainConfig_ = new global::Tensorflow.Models.ObjectDetection.Protos.TrainConfig(); + } + input.ReadMessage(trainConfig_); + break; + } + case 26: { + if (trainInputReader_ == null) { + trainInputReader_ = new global::Tensorflow.Models.ObjectDetection.Protos.InputReader(); + } + input.ReadMessage(trainInputReader_); + break; + } + case 34: { + if (evalConfig_ == null) { + evalConfig_ = new global::Tensorflow.Models.ObjectDetection.Protos.EvalConfig(); + } + input.ReadMessage(evalConfig_); + break; + } + case 42: { + evalInputReader_.AddEntriesFrom(input, _repeated_evalInputReader_codec); + break; + } + case 50: { + if (graphRewriter_ == null) { + graphRewriter_ = new global::Tensorflow.Models.ObjectDetection.Protos.GraphRewriter(); + } + input.ReadMessage(graphRewriter_); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/PostProcessing.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/PostProcessing.cs new file mode 100644 index 00000000..708364dd --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/PostProcessing.cs @@ -0,0 +1,685 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: object_detection/protos/post_processing.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/post_processing.proto + public static partial class PostProcessingReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/post_processing.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static PostProcessingReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Ci1vYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9wb3N0X3Byb2Nlc3NpbmcucHJv", + "dG8SF29iamVjdF9kZXRlY3Rpb24ucHJvdG9zGilvYmplY3RfZGV0ZWN0aW9u", + "L3Byb3Rvcy9jYWxpYnJhdGlvbi5wcm90byL+AQoWQmF0Y2hOb25NYXhTdXBw", + "cmVzc2lvbhIXCg9zY29yZV90aHJlc2hvbGQYASABKAISFQoNaW91X3RocmVz", + "aG9sZBgCIAEoAhIgChhtYXhfZGV0ZWN0aW9uc19wZXJfY2xhc3MYAyABKAUS", + "HAoUbWF4X3RvdGFsX2RldGVjdGlvbnMYBSABKAUSGQoRdXNlX3N0YXRpY19z", + "aGFwZXMYBiABKAgSHgoWdXNlX2NsYXNzX2Fnbm9zdGljX25tcxgHIAEoCBIh", + "ChltYXhfY2xhc3Nlc19wZXJfZGV0ZWN0aW9uGAggASgFEhYKDnNvZnRfbm1z", + "X3NpZ21hGAkgASgCIswCCg5Qb3N0UHJvY2Vzc2luZxJSChliYXRjaF9ub25f", + "bWF4X3N1cHByZXNzaW9uGAEgASgLMi8ub2JqZWN0X2RldGVjdGlvbi5wcm90", + "b3MuQmF0Y2hOb25NYXhTdXBwcmVzc2lvbhJPCg9zY29yZV9jb252ZXJ0ZXIY", + "AiABKA4yNi5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5Qb3N0UHJvY2Vzc2lu", + "Zy5TY29yZUNvbnZlcnRlchITCgtsb2dpdF9zY2FsZRgDIAEoAhJGChJjYWxp", + "YnJhdGlvbl9jb25maWcYBCABKAsyKi5vYmplY3RfZGV0ZWN0aW9uLnByb3Rv", + "cy5DYWxpYnJhdGlvbkNvbmZpZyI4Cg5TY29yZUNvbnZlcnRlchIMCghJREVO", + "VElUWRAAEgsKB1NJR01PSUQQARILCgdTT0ZUTUFYEAJiBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.Models.ObjectDetection.Protos.CalibrationReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.BatchNonMaxSuppression), global::Tensorflow.Models.ObjectDetection.Protos.BatchNonMaxSuppression.Parser, new[]{ "ScoreThreshold", "IouThreshold", "MaxDetectionsPerClass", "MaxTotalDetections", "UseStaticShapes", "UseClassAgnosticNms", "MaxClassesPerDetection", "SoftNmsSigma" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.PostProcessing), global::Tensorflow.Models.ObjectDetection.Protos.PostProcessing.Parser, new[]{ "BatchNonMaxSuppression", "ScoreConverter", "LogitScale", "CalibrationConfig" }, null, new[]{ typeof(global::Tensorflow.Models.ObjectDetection.Protos.PostProcessing.Types.ScoreConverter) }, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration proto for non-max-suppression operation on a batch of + /// detections. 
+ /// + public sealed partial class BatchNonMaxSuppression : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new BatchNonMaxSuppression()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PostProcessingReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BatchNonMaxSuppression() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BatchNonMaxSuppression(BatchNonMaxSuppression other) : this() { + scoreThreshold_ = other.scoreThreshold_; + iouThreshold_ = other.iouThreshold_; + maxDetectionsPerClass_ = other.maxDetectionsPerClass_; + maxTotalDetections_ = other.maxTotalDetections_; + useStaticShapes_ = other.useStaticShapes_; + useClassAgnosticNms_ = other.useClassAgnosticNms_; + maxClassesPerDetection_ = other.maxClassesPerDetection_; + softNmsSigma_ = other.softNmsSigma_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public BatchNonMaxSuppression Clone() { + return new BatchNonMaxSuppression(this); + } + + /// Field number for the "score_threshold" field. + public const int ScoreThresholdFieldNumber = 1; + private float scoreThreshold_; + /// + /// Scalar threshold for score (low scoring boxes are removed). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float ScoreThreshold { + get { return scoreThreshold_; } + set { + scoreThreshold_ = value; + } + } + + /// Field number for the "iou_threshold" field. + public const int IouThresholdFieldNumber = 2; + private float iouThreshold_; + /// + /// Scalar threshold for IOU (boxes that have high IOU overlap + /// with previously selected boxes are removed). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float IouThreshold { + get { return iouThreshold_; } + set { + iouThreshold_ = value; + } + } + + /// Field number for the "max_detections_per_class" field. + public const int MaxDetectionsPerClassFieldNumber = 3; + private int maxDetectionsPerClass_; + /// + /// Maximum number of detections to retain per class. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxDetectionsPerClass { + get { return maxDetectionsPerClass_; } + set { + maxDetectionsPerClass_ = value; + } + } + + /// Field number for the "max_total_detections" field. + public const int MaxTotalDetectionsFieldNumber = 5; + private int maxTotalDetections_; + /// + /// Maximum number of detections to retain across all classes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxTotalDetections { + get { return maxTotalDetections_; } + set { + maxTotalDetections_ = value; + } + } + + /// Field number for the "use_static_shapes" field. + public const int UseStaticShapesFieldNumber = 6; + private bool useStaticShapes_; + /// + /// Whether to use the implementation of NMS that guarantees static shapes. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseStaticShapes { + get { return useStaticShapes_; } + set { + useStaticShapes_ = value; + } + } + + /// Field number for the "use_class_agnostic_nms" field. + public const int UseClassAgnosticNmsFieldNumber = 7; + private bool useClassAgnosticNms_; + /// + /// Whether to use class agnostic NMS. + /// Class-agnostic NMS function implements a class-agnostic version + /// of Non Maximal Suppression where if max_classes_per_detection=k, + /// 1) we keep the top-k scores for each detection and + /// 2) during NMS, each detection only uses the highest class score for sorting. + /// 3) Compared to regular NMS, the worst runtime of this version is O(N^2) + /// instead of O(KN^2) where N is the number of detections and K the number of + /// classes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseClassAgnosticNms { + get { return useClassAgnosticNms_; } + set { + useClassAgnosticNms_ = value; + } + } + + /// Field number for the "max_classes_per_detection" field. + public const int MaxClassesPerDetectionFieldNumber = 8; + private int maxClassesPerDetection_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxClassesPerDetection { + get { return maxClassesPerDetection_; } + set { + maxClassesPerDetection_ = value; + } + } + + /// Field number for the "soft_nms_sigma" field. + public const int SoftNmsSigmaFieldNumber = 9; + private float softNmsSigma_; + /// + /// Soft NMS sigma parameter; Bodla et al, https://arxiv.org/abs/1704.04503) + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float SoftNmsSigma { + get { return softNmsSigma_; } + set { + softNmsSigma_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as BatchNonMaxSuppression); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(BatchNonMaxSuppression other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(ScoreThreshold, other.ScoreThreshold)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(IouThreshold, other.IouThreshold)) return false; + if (MaxDetectionsPerClass != other.MaxDetectionsPerClass) return false; + if (MaxTotalDetections != other.MaxTotalDetections) return false; + if (UseStaticShapes != other.UseStaticShapes) return false; + if (UseClassAgnosticNms != other.UseClassAgnosticNms) return false; + if (MaxClassesPerDetection != other.MaxClassesPerDetection) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(SoftNmsSigma, other.SoftNmsSigma)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (ScoreThreshold != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(ScoreThreshold); + if (IouThreshold != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(IouThreshold); + if (MaxDetectionsPerClass != 0) hash ^= MaxDetectionsPerClass.GetHashCode(); + if (MaxTotalDetections != 0) hash ^= MaxTotalDetections.GetHashCode(); + if (UseStaticShapes != false) hash ^= UseStaticShapes.GetHashCode(); + if (UseClassAgnosticNms != 
false) hash ^= UseClassAgnosticNms.GetHashCode(); + if (MaxClassesPerDetection != 0) hash ^= MaxClassesPerDetection.GetHashCode(); + if (SoftNmsSigma != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(SoftNmsSigma); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (ScoreThreshold != 0F) { + output.WriteRawTag(13); + output.WriteFloat(ScoreThreshold); + } + if (IouThreshold != 0F) { + output.WriteRawTag(21); + output.WriteFloat(IouThreshold); + } + if (MaxDetectionsPerClass != 0) { + output.WriteRawTag(24); + output.WriteInt32(MaxDetectionsPerClass); + } + if (MaxTotalDetections != 0) { + output.WriteRawTag(40); + output.WriteInt32(MaxTotalDetections); + } + if (UseStaticShapes != false) { + output.WriteRawTag(48); + output.WriteBool(UseStaticShapes); + } + if (UseClassAgnosticNms != false) { + output.WriteRawTag(56); + output.WriteBool(UseClassAgnosticNms); + } + if (MaxClassesPerDetection != 0) { + output.WriteRawTag(64); + output.WriteInt32(MaxClassesPerDetection); + } + if (SoftNmsSigma != 0F) { + output.WriteRawTag(77); + output.WriteFloat(SoftNmsSigma); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (ScoreThreshold != 0F) { + size += 1 + 4; + } + if (IouThreshold != 0F) { + size += 1 + 4; + } + if (MaxDetectionsPerClass != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxDetectionsPerClass); + } + if (MaxTotalDetections != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxTotalDetections); + } + if (UseStaticShapes != false) { + size += 1 + 1; + } + if (UseClassAgnosticNms != false) { + size += 1 + 1; + } + if (MaxClassesPerDetection != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxClassesPerDetection); + } + if (SoftNmsSigma != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(BatchNonMaxSuppression other) { + if (other == null) { + return; + } + if (other.ScoreThreshold != 0F) { + ScoreThreshold = other.ScoreThreshold; + } + if (other.IouThreshold != 0F) { + IouThreshold = other.IouThreshold; + } + if (other.MaxDetectionsPerClass != 0) { + MaxDetectionsPerClass = other.MaxDetectionsPerClass; + } + if (other.MaxTotalDetections != 0) { + MaxTotalDetections = other.MaxTotalDetections; + } + if (other.UseStaticShapes != false) { + UseStaticShapes = other.UseStaticShapes; + } + if (other.UseClassAgnosticNms != false) { + UseClassAgnosticNms = other.UseClassAgnosticNms; + } + if (other.MaxClassesPerDetection != 0) { + MaxClassesPerDetection = other.MaxClassesPerDetection; + } + if (other.SoftNmsSigma != 0F) { + SoftNmsSigma = other.SoftNmsSigma; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = 
pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+            break;
+          case 13: {
+            ScoreThreshold = input.ReadFloat();
+            break;
+          }
+          case 21: {
+            IouThreshold = input.ReadFloat();
+            break;
+          }
+          case 24: {
+            MaxDetectionsPerClass = input.ReadInt32();
+            break;
+          }
+          case 40: {
+            MaxTotalDetections = input.ReadInt32();
+            break;
+          }
+          case 48: {
+            UseStaticShapes = input.ReadBool();
+            break;
+          }
+          case 56: {
+            UseClassAgnosticNms = input.ReadBool();
+            break;
+          }
+          case 64: {
+            MaxClassesPerDetection = input.ReadInt32();
+            break;
+          }
+          case 77: {
+            SoftNmsSigma = input.ReadFloat();
+            break;
+          }
+        }
+      }
+    }
+
+  }
+
+  /// <summary>
+  /// Configuration proto for post-processing predicted boxes and
+  /// scores.
+  /// </summary>
+  public sealed partial class PostProcessing : pb::IMessage<PostProcessing> {
+    private static readonly pb::MessageParser<PostProcessing> _parser = new pb::MessageParser<PostProcessing>(() => new PostProcessing());
+    private pb::UnknownFieldSet _unknownFields;
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public static pb::MessageParser<PostProcessing> Parser { get { return _parser; } }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public static pbr::MessageDescriptor Descriptor {
+      get { return global::Tensorflow.Models.ObjectDetection.Protos.PostProcessingReflection.Descriptor.MessageTypes[1]; }
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    pbr::MessageDescriptor pb::IMessage.Descriptor {
+      get { return Descriptor; }
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public PostProcessing() {
+      OnConstruction();
+    }
+
+    partial void OnConstruction();
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public PostProcessing(PostProcessing other) : this() {
+      batchNonMaxSuppression_ = other.batchNonMaxSuppression_ != null ? other.batchNonMaxSuppression_.Clone() : null;
+      scoreConverter_ = other.scoreConverter_;
+      logitScale_ = other.logitScale_;
+      calibrationConfig_ = other.calibrationConfig_ != null ? other.calibrationConfig_.Clone() : null;
+      _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public PostProcessing Clone() {
+      return new PostProcessing(this);
+    }
+
+    /// <summary>Field number for the "batch_non_max_suppression" field.</summary>
+    public const int BatchNonMaxSuppressionFieldNumber = 1;
+    private global::Tensorflow.Models.ObjectDetection.Protos.BatchNonMaxSuppression batchNonMaxSuppression_;
+    /// <summary>
+    /// Non max suppression parameters.
+    /// </summary>
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public global::Tensorflow.Models.ObjectDetection.Protos.BatchNonMaxSuppression BatchNonMaxSuppression {
+      get { return batchNonMaxSuppression_; }
+      set {
+        batchNonMaxSuppression_ = value;
+      }
+    }
+
+    /// <summary>Field number for the "score_converter" field.</summary>
+    public const int ScoreConverterFieldNumber = 2;
+    private global::Tensorflow.Models.ObjectDetection.Protos.PostProcessing.Types.ScoreConverter scoreConverter_ = 0;
+    /// <summary>
+    /// Score converter to use.
+    /// </summary>
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public global::Tensorflow.Models.ObjectDetection.Protos.PostProcessing.Types.ScoreConverter ScoreConverter {
+      get { return scoreConverter_; }
+      set {
+        scoreConverter_ = value;
+      }
+    }
+
+    /// <summary>Field number for the "logit_scale" field.</summary>
+    public const int LogitScaleFieldNumber = 3;
+    private float logitScale_;
+    /// <summary>
+    /// Scale logit (input) value before conversion in post-processing step.
+    /// Typically used for softmax distillation, though can be used to scale for
+    /// other reasons.
+    /// </summary>
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public float LogitScale {
+      get { return logitScale_; }
+      set {
+        logitScale_ = value;
+      }
+    }
+
+    /// <summary>Field number for the "calibration_config" field.</summary>
+    public const int CalibrationConfigFieldNumber = 4;
+    private global::Tensorflow.Models.ObjectDetection.Protos.CalibrationConfig calibrationConfig_;
+    /// <summary>
+    /// Calibrate score outputs. Calibration is applied after score converter
+    /// and before non max suppression.
+    /// </summary>
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public global::Tensorflow.Models.ObjectDetection.Protos.CalibrationConfig CalibrationConfig {
+      get { return calibrationConfig_; }
+      set {
+        calibrationConfig_ = value;
+      }
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public override bool Equals(object other) {
+      return Equals(other as PostProcessing);
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public bool Equals(PostProcessing other) {
+      if (ReferenceEquals(other, null)) {
+        return false;
+      }
+      if (ReferenceEquals(other, this)) {
+        return true;
+      }
+      if (!object.Equals(BatchNonMaxSuppression, other.BatchNonMaxSuppression)) return false;
+      if (ScoreConverter != other.ScoreConverter) return false;
+      if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(LogitScale, other.LogitScale)) return false;
+      if (!object.Equals(CalibrationConfig, other.CalibrationConfig)) return false;
+      return Equals(_unknownFields, other._unknownFields);
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public override int GetHashCode() {
+      int hash = 1;
+      if (batchNonMaxSuppression_ != null) hash ^= BatchNonMaxSuppression.GetHashCode();
+      if (ScoreConverter != 0) hash ^= ScoreConverter.GetHashCode();
+      if (LogitScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(LogitScale);
+      if (calibrationConfig_ != null) hash ^= CalibrationConfig.GetHashCode();
+      if (_unknownFields != null) {
+        hash ^= _unknownFields.GetHashCode();
+      }
+      return hash;
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public override string ToString() {
+      return pb::JsonFormatter.ToDiagnosticString(this);
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public void WriteTo(pb::CodedOutputStream output) {
+      if (batchNonMaxSuppression_ != null) {
+        output.WriteRawTag(10);
+        output.WriteMessage(BatchNonMaxSuppression);
+      }
+      if (ScoreConverter != 0) {
+        output.WriteRawTag(16);
+        output.WriteEnum((int) ScoreConverter);
+      }
+      if (LogitScale != 0F) {
+        output.WriteRawTag(29);
+        output.WriteFloat(LogitScale);
+      }
+      if (calibrationConfig_ != null) {
+        output.WriteRawTag(34);
+        output.WriteMessage(CalibrationConfig);
+      }
+      if (_unknownFields != null) {
+        _unknownFields.WriteTo(output);
+      }
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public int CalculateSize() {
+      int size = 0;
+      if (batchNonMaxSuppression_ != null) {
+        size += 1 + pb::CodedOutputStream.ComputeMessageSize(BatchNonMaxSuppression);
+      }
+      if (ScoreConverter != 0) {
+        size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) ScoreConverter);
+      }
+      if (LogitScale != 0F) {
+        size += 1 + 4;
+      }
+      if (calibrationConfig_ != null) {
+        size += 1 + pb::CodedOutputStream.ComputeMessageSize(CalibrationConfig);
+      }
+      if (_unknownFields != null) {
+        size += _unknownFields.CalculateSize();
+      }
+      return size;
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public void MergeFrom(PostProcessing other) {
+      if (other == null) {
+        return;
+      }
+      if (other.batchNonMaxSuppression_ != null) {
+        if (batchNonMaxSuppression_ == null) {
+          batchNonMaxSuppression_ = new global::Tensorflow.Models.ObjectDetection.Protos.BatchNonMaxSuppression();
+        }
+        BatchNonMaxSuppression.MergeFrom(other.BatchNonMaxSuppression);
+      }
+      if (other.ScoreConverter != 0) {
+        ScoreConverter = other.ScoreConverter;
+      }
+      if (other.LogitScale != 0F) {
+        LogitScale = other.LogitScale;
+      }
+      if (other.calibrationConfig_ != null) {
+        if (calibrationConfig_ == null) {
+          calibrationConfig_ = new global::Tensorflow.Models.ObjectDetection.Protos.CalibrationConfig();
+        }
+        CalibrationConfig.MergeFrom(other.CalibrationConfig);
+      }
+      _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public void MergeFrom(pb::CodedInputStream input) {
+      uint tag;
+      while ((tag = input.ReadTag()) != 0) {
+        switch(tag) {
+          default:
+            _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+            break;
+          case 10: {
+            if (batchNonMaxSuppression_ == null) {
+              batchNonMaxSuppression_ = new global::Tensorflow.Models.ObjectDetection.Protos.BatchNonMaxSuppression();
+            }
+            input.ReadMessage(batchNonMaxSuppression_);
+            break;
+          }
+          case 16: {
+            scoreConverter_ = (global::Tensorflow.Models.ObjectDetection.Protos.PostProcessing.Types.ScoreConverter) input.ReadEnum();
+            break;
+          }
+          case 29: {
+            LogitScale = input.ReadFloat();
+            break;
+          }
+          case 34: {
+            if (calibrationConfig_ == null) {
+              calibrationConfig_ = new global::Tensorflow.Models.ObjectDetection.Protos.CalibrationConfig();
+            }
+            input.ReadMessage(calibrationConfig_);
+            break;
+          }
+        }
+      }
+    }
+
+    #region Nested types
+    /// <summary>Container for nested types declared in the PostProcessing message type.</summary>
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public static partial class Types {
+      /// <summary>
+      /// Enum to specify how to convert the detection scores.
+      /// </summary>
+      public enum ScoreConverter {
+        /// <summary>
+        /// Input scores equals output scores.
+        /// </summary>
+        [pbr::OriginalName("IDENTITY")] Identity = 0,
+        /// <summary>
+        /// Applies a sigmoid on input scores.
+        /// </summary>
+        [pbr::OriginalName("SIGMOID")] Sigmoid = 1,
+        /// <summary>
+        /// Applies a softmax on input scores
+        /// </summary>
+        [pbr::OriginalName("SOFTMAX")] Softmax = 2,
+      }
+
+    }
+    #endregion
+
+  }
+
+  #endregion
+
+}
+
+#endregion Designer generated code
diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/Preprocessor.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/Preprocessor.cs
new file mode 100644
index 00000000..7412fb52
--- /dev/null
+++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/Preprocessor.cs
@@ -0,0 +1,8697 @@
+//
+// Generated by the protocol buffer compiler.  DO NOT EDIT!
+// source: object_detection/protos/preprocessor.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/preprocessor.proto + public static partial class PreprocessorReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/preprocessor.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static PreprocessorReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CipvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9wcmVwcm9jZXNzb3IucHJvdG8S", + "F29iamVjdF9kZXRlY3Rpb24ucHJvdG9zItgUChFQcmVwcm9jZXNzaW5nU3Rl", + "cBJCCg9ub3JtYWxpemVfaW1hZ2UYASABKAsyJy5vYmplY3RfZGV0ZWN0aW9u", + "LnByb3Rvcy5Ob3JtYWxpemVJbWFnZUgAEk8KFnJhbmRvbV9ob3Jpem9udGFs", + "X2ZsaXAYAiABKAsyLS5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5SYW5kb21I", + "b3Jpem9udGFsRmxpcEgAElIKGHJhbmRvbV9waXhlbF92YWx1ZV9zY2FsZRgD", + "IAEoCzIuLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLlJhbmRvbVBpeGVsVmFs", + "dWVTY2FsZUgAEkcKEnJhbmRvbV9pbWFnZV9zY2FsZRgEIAEoCzIpLm9iamVj", + "dF9kZXRlY3Rpb24ucHJvdG9zLlJhbmRvbUltYWdlU2NhbGVIABJGChJyYW5k", + "b21fcmdiX3RvX2dyYXkYBSABKAsyKC5vYmplY3RfZGV0ZWN0aW9uLnByb3Rv", + "cy5SYW5kb21SR0J0b0dyYXlIABJTChhyYW5kb21fYWRqdXN0X2JyaWdodG5l", + "c3MYBiABKAsyLy5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5SYW5kb21BZGp1", + "c3RCcmlnaHRuZXNzSAASTwoWcmFuZG9tX2FkanVzdF9jb250cmFzdBgHIAEo", + "CzItLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLlJhbmRvbUFkanVzdENvbnRy", + "YXN0SAASRQoRcmFuZG9tX2FkanVzdF9odWUYCCABKAsyKC5vYmplY3RfZGV0", + "ZWN0aW9uLnByb3Rvcy5SYW5kb21BZGp1c3RIdWVIABJTChhyYW5kb21fYWRq", + "dXN0X3NhdHVyYXRpb24YCSABKAsyLy5vYmplY3RfZGV0ZWN0aW9uLnByb3Rv", + "cy5SYW5kb21BZGp1c3RTYXR1cmF0aW9uSAASSwoUcmFuZG9tX2Rpc3RvcnRf", + "Y29sb3IYCiABKAsyKy5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5SYW5kb21E", + "aXN0b3J0Q29sb3JIABJJChNyYW5kb21faml0dGVyX2JveGVzGAsgASgLMiou", + "b2JqZWN0X2RldGVjdGlvbi5wcm90b3MuUmFuZG9tSml0dGVyQm94ZXNIABJF", + "ChFyYW5kb21fY3JvcF9pbWFnZRgMIAEoCzIoLm9iamVjdF9kZXRlY3Rpb24u", + "cHJvdG9zLlJhbmRvbUNyb3BJbWFnZUgAEkMKEHJhbmRvbV9wYWRfaW1hZ2UY", + "DSABKAsyJy5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5SYW5kb21QYWRJbWFn", + "ZUgAEkwKFXJhbmRvbV9jcm9wX3BhZF9pbWFnZRgOIAEoCzIrLm9iamVjdF9k", + "ZXRlY3Rpb24ucHJvdG9zLlJhbmRvbUNyb3BQYWRJbWFnZUgAElcKG3JhbmRv", + "bV9jcm9wX3RvX2FzcGVjdF9yYXRpbxgPIAEoCzIwLm9iamVjdF9kZXRlY3Rp", + "b24ucHJvdG9zLlJhbmRvbUNyb3BUb0FzcGVjdFJhdGlvSAASSwoUcmFuZG9t", + "X2JsYWNrX3BhdGNoZXMYECABKAsyKy5vYmplY3RfZGV0ZWN0aW9uLnByb3Rv", + "cy5SYW5kb21CbGFja1BhdGNoZXNIABJLChRyYW5kb21fcmVzaXplX21ldGhv", + "ZBgRIAEoCzIrLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLlJhbmRvbVJlc2l6", + "ZU1ldGhvZEgAEmEKIHNjYWxlX2JveGVzX3RvX3BpeGVsX2Nvb3JkaW5hdGVz", + "GBIgASgLMjUub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuU2NhbGVCb3hlc1Rv", + "UGl4ZWxDb29yZGluYXRlc0gAEjwKDHJlc2l6ZV9pbWFnZRgTIAEoCzIkLm9i", + "amVjdF9kZXRlY3Rpb24ucHJvdG9zLlJlc2l6ZUltYWdlSAASTQoVc3VidHJh", + "Y3RfY2hhbm5lbF9tZWFuGBQgASgLMiwub2JqZWN0X2RldGVjdGlvbi5wcm90", + "b3MuU3VidHJhY3RDaGFubmVsTWVhbkgAEkEKD3NzZF9yYW5kb21fY3JvcBgV", + "IAEoCzImLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLlNTRFJhbmRvbUNyb3BI", + "ABJIChNzc2RfcmFuZG9tX2Nyb3BfcGFkGBYgASgLMikub2JqZWN0X2RldGVj", + "dGlvbi5wcm90b3MuU1NEUmFuZG9tQ3JvcFBhZEgAEmQKInNzZF9yYW5kb21f", + 
"Y3JvcF9maXhlZF9hc3BlY3RfcmF0aW8YFyABKAsyNi5vYmplY3RfZGV0ZWN0", + "aW9uLnByb3Rvcy5TU0RSYW5kb21Dcm9wRml4ZWRBc3BlY3RSYXRpb0gAEmsK", + "JnNzZF9yYW5kb21fY3JvcF9wYWRfZml4ZWRfYXNwZWN0X3JhdGlvGBggASgL", + "Mjkub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuU1NEUmFuZG9tQ3JvcFBhZEZp", + "eGVkQXNwZWN0UmF0aW9IABJLChRyYW5kb21fdmVydGljYWxfZmxpcBgZIAEo", + "CzIrLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLlJhbmRvbVZlcnRpY2FsRmxp", + "cEgAEkYKEXJhbmRvbV9yb3RhdGlvbjkwGBogASgLMikub2JqZWN0X2RldGVj", + "dGlvbi5wcm90b3MuUmFuZG9tUm90YXRpb245MEgAEjkKC3JnYl90b19ncmF5", + "GBsgASgLMiIub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuUkdCdG9HcmF5SAAS", + "XwofY29udmVydF9jbGFzc19sb2dpdHNfdG9fc29mdG1heBgcIAEoCzI0Lm9i", + "amVjdF9kZXRlY3Rpb24ucHJvdG9zLkNvbnZlcnRDbGFzc0xvZ2l0c1RvU29m", + "dG1heEgAElQKGXJhbmRvbV9hYnNvbHV0ZV9wYWRfaW1hZ2UYHSABKAsyLy5v", + "YmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5SYW5kb21BYnNvbHV0ZVBhZEltYWdl", + "SAASUgoYcmFuZG9tX3NlbGZfY29uY2F0X2ltYWdlGB4gASgLMi4ub2JqZWN0", + "X2RldGVjdGlvbi5wcm90b3MuUmFuZG9tU2VsZkNvbmNhdEltYWdlSAASRgoR", + "YXV0b2F1Z21lbnRfaW1hZ2UYHyABKAsyKS5vYmplY3RfZGV0ZWN0aW9uLnBy", + "b3Rvcy5BdXRvQXVnbWVudEltYWdlSAASWwocZHJvcF9sYWJlbF9wcm9iYWJp", + "bGlzdGljYWxseRggIAEoCzIzLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLkRy", + "b3BMYWJlbFByb2JhYmlsaXN0aWNhbGx5SAASPAoMcmVtYXBfbGFiZWxzGCEg", + "ASgLMiQub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuUmVtYXBMYWJlbHNIAEIU", + "ChJwcmVwcm9jZXNzaW5nX3N0ZXAicAoOTm9ybWFsaXplSW1hZ2USFwoPb3Jp", + "Z2luYWxfbWludmFsGAEgASgCEhcKD29yaWdpbmFsX21heHZhbBgCIAEoAhIV", + "Cg10YXJnZXRfbWludmFsGAMgASgCEhUKDXRhcmdldF9tYXh2YWwYBCABKAIi", + "OQoUUmFuZG9tSG9yaXpvbnRhbEZsaXASIQoZa2V5cG9pbnRfZmxpcF9wZXJt", + "dXRhdGlvbhgBIAMoBSI3ChJSYW5kb21WZXJ0aWNhbEZsaXASIQoZa2V5cG9p", + "bnRfZmxpcF9wZXJtdXRhdGlvbhgBIAMoBSISChBSYW5kb21Sb3RhdGlvbjkw", + "IjcKFVJhbmRvbVBpeGVsVmFsdWVTY2FsZRIOCgZtaW52YWwYASABKAISDgoG", + "bWF4dmFsGAIgASgCIkQKEFJhbmRvbUltYWdlU2NhbGUSFwoPbWluX3NjYWxl", + "X3JhdGlvGAEgASgCEhcKD21heF9zY2FsZV9yYXRpbxgCIAEoAiImCg9SYW5k", + "b21SR0J0b0dyYXkSEwoLcHJvYmFiaWxpdHkYASABKAIiKwoWUmFuZG9tQWRq", + "dXN0QnJpZ2h0bmVzcxIRCgltYXhfZGVsdGEYASABKAIiPAoUUmFuZG9tQWRq", + "dXN0Q29udHJhc3QSEQoJbWluX2RlbHRhGAEgASgCEhEKCW1heF9kZWx0YRgC", + "IAEoAiIkCg9SYW5kb21BZGp1c3RIdWUSEQoJbWF4X2RlbHRhGAEgASgCIj4K", + "FlJhbmRvbUFkanVzdFNhdHVyYXRpb24SEQoJbWluX2RlbHRhGAEgASgCEhEK", + "CW1heF9kZWx0YRgCIAEoAiIsChJSYW5kb21EaXN0b3J0Q29sb3ISFgoOY29s", + "b3Jfb3JkZXJpbmcYASABKAUiIgoRUmFuZG9tSml0dGVyQm94ZXMSDQoFcmF0", + "aW8YASABKAIixgEKD1JhbmRvbUNyb3BJbWFnZRIaChJtaW5fb2JqZWN0X2Nv", + "dmVyZWQYASABKAISGAoQbWluX2FzcGVjdF9yYXRpbxgCIAEoAhIYChBtYXhf", + "YXNwZWN0X3JhdGlvGAMgASgCEhAKCG1pbl9hcmVhGAQgASgCEhAKCG1heF9h", + "cmVhGAUgASgCEhYKDm92ZXJsYXBfdGhyZXNoGAYgASgCEhIKCmNsaXBfYm94", + "ZXMYCCABKAgSEwoLcmFuZG9tX2NvZWYYByABKAIiiQEKDlJhbmRvbVBhZElt", + "YWdlEhgKEG1pbl9pbWFnZV9oZWlnaHQYASABKAUSFwoPbWluX2ltYWdlX3dp", + "ZHRoGAIgASgFEhgKEG1heF9pbWFnZV9oZWlnaHQYAyABKAUSFwoPbWF4X2lt", + "YWdlX3dpZHRoGAQgASgFEhEKCXBhZF9jb2xvchgFIAMoAiJiChZSYW5kb21B", + "YnNvbHV0ZVBhZEltYWdlEhoKEm1heF9oZWlnaHRfcGFkZGluZxgBIAEoBRIZ", + "ChFtYXhfd2lkdGhfcGFkZGluZxgCIAEoBRIRCglwYWRfY29sb3IYAyADKAIi", + "mgIKElJhbmRvbUNyb3BQYWRJbWFnZRIaChJtaW5fb2JqZWN0X2NvdmVyZWQY", + "ASABKAISGAoQbWluX2FzcGVjdF9yYXRpbxgCIAEoAhIYChBtYXhfYXNwZWN0", + "X3JhdGlvGAMgASgCEhAKCG1pbl9hcmVhGAQgASgCEhAKCG1heF9hcmVhGAUg", + "ASgCEhYKDm92ZXJsYXBfdGhyZXNoGAYgASgCEhIKCmNsaXBfYm94ZXMYCyAB", + "KAgSEwoLcmFuZG9tX2NvZWYYByABKAISHQoVbWluX3BhZGRlZF9zaXplX3Jh", + "dGlvGAggAygCEh0KFW1heF9wYWRkZWRfc2l6ZV9yYXRpbxgJIAMoAhIRCglw", + "YWRfY29sb3IYCiADKAIiWwoXUmFuZG9tQ3JvcFRvQXNwZWN0UmF0aW8SFAoM", + 
"YXNwZWN0X3JhdGlvGAEgASgCEhYKDm92ZXJsYXBfdGhyZXNoGAIgASgCEhIK", + "CmNsaXBfYm94ZXMYAyABKAgiYQoSUmFuZG9tQmxhY2tQYXRjaGVzEhkKEW1h", + "eF9ibGFja19wYXRjaGVzGAEgASgFEhMKC3Byb2JhYmlsaXR5GAIgASgCEhsK", + "E3NpemVfdG9faW1hZ2VfcmF0aW8YAyABKAIiQQoSUmFuZG9tUmVzaXplTWV0", + "aG9kEhUKDXRhcmdldF9oZWlnaHQYASABKAUSFAoMdGFyZ2V0X3dpZHRoGAIg", + "ASgFIgsKCVJHQnRvR3JheSIeChxTY2FsZUJveGVzVG9QaXhlbENvb3JkaW5h", + "dGVzIsABCgtSZXNpemVJbWFnZRISCgpuZXdfaGVpZ2h0GAEgASgFEhEKCW5l", + "d193aWR0aBgCIAEoBRI7CgZtZXRob2QYAyABKA4yKy5vYmplY3RfZGV0ZWN0", + "aW9uLnByb3Rvcy5SZXNpemVJbWFnZS5NZXRob2QiTQoGTWV0aG9kEggKBE5P", + "TkUQABIICgRBUkVBEAESCwoHQklDVUJJQxACEgwKCEJJTElORUFSEAMSFAoQ", + "TkVBUkVTVF9ORUlHSEJPUhAEIiQKE1N1YnRyYWN0Q2hhbm5lbE1lYW4SDQoF", + "bWVhbnMYASADKAIizQEKFlNTRFJhbmRvbUNyb3BPcGVyYXRpb24SGgoSbWlu", + "X29iamVjdF9jb3ZlcmVkGAEgASgCEhgKEG1pbl9hc3BlY3RfcmF0aW8YAiAB", + "KAISGAoQbWF4X2FzcGVjdF9yYXRpbxgDIAEoAhIQCghtaW5fYXJlYRgEIAEo", + "AhIQCghtYXhfYXJlYRgFIAEoAhIWCg5vdmVybGFwX3RocmVzaBgGIAEoAhIS", + "CgpjbGlwX2JveGVzGAggASgIEhMKC3JhbmRvbV9jb2VmGAcgASgCIlQKDVNT", + "RFJhbmRvbUNyb3ASQwoKb3BlcmF0aW9ucxgBIAMoCzIvLm9iamVjdF9kZXRl", + "Y3Rpb24ucHJvdG9zLlNTRFJhbmRvbUNyb3BPcGVyYXRpb24izQIKGVNTRFJh", + "bmRvbUNyb3BQYWRPcGVyYXRpb24SGgoSbWluX29iamVjdF9jb3ZlcmVkGAEg", + "ASgCEhgKEG1pbl9hc3BlY3RfcmF0aW8YAiABKAISGAoQbWF4X2FzcGVjdF9y", + "YXRpbxgDIAEoAhIQCghtaW5fYXJlYRgEIAEoAhIQCghtYXhfYXJlYRgFIAEo", + "AhIWCg5vdmVybGFwX3RocmVzaBgGIAEoAhISCgpjbGlwX2JveGVzGA0gASgI", + "EhMKC3JhbmRvbV9jb2VmGAcgASgCEh0KFW1pbl9wYWRkZWRfc2l6ZV9yYXRp", + "bxgIIAMoAhIdChVtYXhfcGFkZGVkX3NpemVfcmF0aW8YCSADKAISEwoLcGFk", + "X2NvbG9yX3IYCiABKAISEwoLcGFkX2NvbG9yX2cYCyABKAISEwoLcGFkX2Nv", + "bG9yX2IYDCABKAIiWgoQU1NEUmFuZG9tQ3JvcFBhZBJGCgpvcGVyYXRpb25z", + "GAEgAygLMjIub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuU1NEUmFuZG9tQ3Jv", + "cFBhZE9wZXJhdGlvbiKpAQomU1NEUmFuZG9tQ3JvcEZpeGVkQXNwZWN0UmF0", + "aW9PcGVyYXRpb24SGgoSbWluX29iamVjdF9jb3ZlcmVkGAEgASgCEhAKCG1p", + "bl9hcmVhGAQgASgCEhAKCG1heF9hcmVhGAUgASgCEhYKDm92ZXJsYXBfdGhy", + "ZXNoGAYgASgCEhIKCmNsaXBfYm94ZXMYCCABKAgSEwoLcmFuZG9tX2NvZWYY", + "ByABKAIiigEKHVNTRFJhbmRvbUNyb3BGaXhlZEFzcGVjdFJhdGlvElMKCm9w", + "ZXJhdGlvbnMYASADKAsyPy5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5TU0RS", + "YW5kb21Dcm9wRml4ZWRBc3BlY3RSYXRpb09wZXJhdGlvbhIUCgxhc3BlY3Rf", + "cmF0aW8YAiABKAIi4AEKKVNTRFJhbmRvbUNyb3BQYWRGaXhlZEFzcGVjdFJh", + "dGlvT3BlcmF0aW9uEhoKEm1pbl9vYmplY3RfY292ZXJlZBgBIAEoAhIYChBt", + "aW5fYXNwZWN0X3JhdGlvGAIgASgCEhgKEG1heF9hc3BlY3RfcmF0aW8YAyAB", + "KAISEAoIbWluX2FyZWEYBCABKAISEAoIbWF4X2FyZWEYBSABKAISFgoOb3Zl", + "cmxhcF90aHJlc2gYBiABKAISEgoKY2xpcF9ib3hlcxgIIAEoCBITCgtyYW5k", + "b21fY29lZhgHIAEoAiLOAQogU1NEUmFuZG9tQ3JvcFBhZEZpeGVkQXNwZWN0", + "UmF0aW8SVgoKb3BlcmF0aW9ucxgBIAMoCzJCLm9iamVjdF9kZXRlY3Rpb24u", + "cHJvdG9zLlNTRFJhbmRvbUNyb3BQYWRGaXhlZEFzcGVjdFJhdGlvT3BlcmF0", + "aW9uEhQKDGFzcGVjdF9yYXRpbxgCIAEoAhIdChVtaW5fcGFkZGVkX3NpemVf", + "cmF0aW8YAyADKAISHQoVbWF4X3BhZGRlZF9zaXplX3JhdGlvGAQgAygCIjIK", + "G0NvbnZlcnRDbGFzc0xvZ2l0c1RvU29mdG1heBITCgt0ZW1wZXJhdHVyZRgB", + "IAEoAiJjChVSYW5kb21TZWxmQ29uY2F0SW1hZ2USIwobY29uY2F0X3ZlcnRp", + "Y2FsX3Byb2JhYmlsaXR5GAEgASgCEiUKHWNvbmNhdF9ob3Jpem9udGFsX3By", + "b2JhYmlsaXR5GAIgASgCIicKEEF1dG9BdWdtZW50SW1hZ2USEwoLcG9saWN5", + "X25hbWUYASABKAkiRQoaRHJvcExhYmVsUHJvYmFiaWxpc3RpY2FsbHkSDQoF", + "bGFiZWwYASABKAUSGAoQZHJvcF9wcm9iYWJpbGl0eRgCIAEoAiI5CgtSZW1h", + "cExhYmVscxIXCg9vcmlnaW5hbF9sYWJlbHMYASADKAUSEQoJbmV3X2xhYmVs", + "GAIgASgFYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new 
pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.PreprocessingStep), global::Tensorflow.Models.ObjectDetection.Protos.PreprocessingStep.Parser, new[]{ "NormalizeImage", "RandomHorizontalFlip", "RandomPixelValueScale", "RandomImageScale", "RandomRgbToGray", "RandomAdjustBrightness", "RandomAdjustContrast", "RandomAdjustHue", "RandomAdjustSaturation", "RandomDistortColor", "RandomJitterBoxes", "RandomCropImage", "RandomPadImage", "RandomCropPadImage", "RandomCropToAspectRatio", "RandomBlackPatches", "RandomResizeMethod", "ScaleBoxesToPixelCoordinates", "ResizeImage", "SubtractChannelMean", "SsdRandomCrop", "SsdRandomCropPad", "SsdRandomCropFixedAspectRatio", "SsdRandomCropPadFixedAspectRatio", "RandomVerticalFlip", "RandomRotation90", "RgbToGray", "ConvertClassLogitsToSoftmax", "RandomAbsolutePadImage", "RandomSelfConcatImage", "AutoaugmentImage", "DropLabelProbabilistically", "RemapLabels" }, new[]{ "PreprocessingStep" }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.NormalizeImage), global::Tensorflow.Models.ObjectDetection.Protos.NormalizeImage.Parser, new[]{ "OriginalMinval", "OriginalMaxval", "TargetMinval", "TargetMaxval" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomHorizontalFlip), global::Tensorflow.Models.ObjectDetection.Protos.RandomHorizontalFlip.Parser, new[]{ "KeypointFlipPermutation" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomVerticalFlip), global::Tensorflow.Models.ObjectDetection.Protos.RandomVerticalFlip.Parser, new[]{ "KeypointFlipPermutation" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomRotation90), global::Tensorflow.Models.ObjectDetection.Protos.RandomRotation90.Parser, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomPixelValueScale), global::Tensorflow.Models.ObjectDetection.Protos.RandomPixelValueScale.Parser, new[]{ "Minval", "Maxval" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomImageScale), global::Tensorflow.Models.ObjectDetection.Protos.RandomImageScale.Parser, new[]{ "MinScaleRatio", "MaxScaleRatio" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomRGBtoGray), global::Tensorflow.Models.ObjectDetection.Protos.RandomRGBtoGray.Parser, new[]{ "Probability" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustBrightness), global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustBrightness.Parser, new[]{ "MaxDelta" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustContrast), global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustContrast.Parser, new[]{ "MinDelta", "MaxDelta" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustHue), global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustHue.Parser, new[]{ "MaxDelta" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustSaturation), 
global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustSaturation.Parser, new[]{ "MinDelta", "MaxDelta" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomDistortColor), global::Tensorflow.Models.ObjectDetection.Protos.RandomDistortColor.Parser, new[]{ "ColorOrdering" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomJitterBoxes), global::Tensorflow.Models.ObjectDetection.Protos.RandomJitterBoxes.Parser, new[]{ "Ratio" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomCropImage), global::Tensorflow.Models.ObjectDetection.Protos.RandomCropImage.Parser, new[]{ "MinObjectCovered", "MinAspectRatio", "MaxAspectRatio", "MinArea", "MaxArea", "OverlapThresh", "ClipBoxes", "RandomCoef" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomPadImage), global::Tensorflow.Models.ObjectDetection.Protos.RandomPadImage.Parser, new[]{ "MinImageHeight", "MinImageWidth", "MaxImageHeight", "MaxImageWidth", "PadColor" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomAbsolutePadImage), global::Tensorflow.Models.ObjectDetection.Protos.RandomAbsolutePadImage.Parser, new[]{ "MaxHeightPadding", "MaxWidthPadding", "PadColor" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomCropPadImage), global::Tensorflow.Models.ObjectDetection.Protos.RandomCropPadImage.Parser, new[]{ "MinObjectCovered", "MinAspectRatio", "MaxAspectRatio", "MinArea", "MaxArea", "OverlapThresh", "ClipBoxes", "RandomCoef", "MinPaddedSizeRatio", "MaxPaddedSizeRatio", "PadColor" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomCropToAspectRatio), global::Tensorflow.Models.ObjectDetection.Protos.RandomCropToAspectRatio.Parser, new[]{ "AspectRatio", "OverlapThresh", "ClipBoxes" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomBlackPatches), global::Tensorflow.Models.ObjectDetection.Protos.RandomBlackPatches.Parser, new[]{ "MaxBlackPatches", "Probability", "SizeToImageRatio" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomResizeMethod), global::Tensorflow.Models.ObjectDetection.Protos.RandomResizeMethod.Parser, new[]{ "TargetHeight", "TargetWidth" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RGBtoGray), global::Tensorflow.Models.ObjectDetection.Protos.RGBtoGray.Parser, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.ScaleBoxesToPixelCoordinates), global::Tensorflow.Models.ObjectDetection.Protos.ScaleBoxesToPixelCoordinates.Parser, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.ResizeImage), global::Tensorflow.Models.ObjectDetection.Protos.ResizeImage.Parser, new[]{ "NewHeight", "NewWidth", "Method" }, null, new[]{ typeof(global::Tensorflow.Models.ObjectDetection.Protos.ResizeImage.Types.Method) }, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.SubtractChannelMean), 
global::Tensorflow.Models.ObjectDetection.Protos.SubtractChannelMean.Parser, new[]{ "Means" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropOperation), global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropOperation.Parser, new[]{ "MinObjectCovered", "MinAspectRatio", "MaxAspectRatio", "MinArea", "MaxArea", "OverlapThresh", "ClipBoxes", "RandomCoef" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCrop), global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCrop.Parser, new[]{ "Operations" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPadOperation), global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPadOperation.Parser, new[]{ "MinObjectCovered", "MinAspectRatio", "MaxAspectRatio", "MinArea", "MaxArea", "OverlapThresh", "ClipBoxes", "RandomCoef", "MinPaddedSizeRatio", "MaxPaddedSizeRatio", "PadColorR", "PadColorG", "PadColorB" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPad), global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPad.Parser, new[]{ "Operations" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropFixedAspectRatioOperation), global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropFixedAspectRatioOperation.Parser, new[]{ "MinObjectCovered", "MinArea", "MaxArea", "OverlapThresh", "ClipBoxes", "RandomCoef" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropFixedAspectRatio), global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropFixedAspectRatio.Parser, new[]{ "Operations", "AspectRatio" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPadFixedAspectRatioOperation), global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPadFixedAspectRatioOperation.Parser, new[]{ "MinObjectCovered", "MinAspectRatio", "MaxAspectRatio", "MinArea", "MaxArea", "OverlapThresh", "ClipBoxes", "RandomCoef" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPadFixedAspectRatio), global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPadFixedAspectRatio.Parser, new[]{ "Operations", "AspectRatio", "MinPaddedSizeRatio", "MaxPaddedSizeRatio" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.ConvertClassLogitsToSoftmax), global::Tensorflow.Models.ObjectDetection.Protos.ConvertClassLogitsToSoftmax.Parser, new[]{ "Temperature" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RandomSelfConcatImage), global::Tensorflow.Models.ObjectDetection.Protos.RandomSelfConcatImage.Parser, new[]{ "ConcatVerticalProbability", "ConcatHorizontalProbability" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.AutoAugmentImage), global::Tensorflow.Models.ObjectDetection.Protos.AutoAugmentImage.Parser, new[]{ "PolicyName" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.DropLabelProbabilistically), 
global::Tensorflow.Models.ObjectDetection.Protos.DropLabelProbabilistically.Parser, new[]{ "Label", "DropProbability" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RemapLabels), global::Tensorflow.Models.ObjectDetection.Protos.RemapLabels.Parser, new[]{ "OriginalLabels", "NewLabel" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Message for defining a preprocessing operation on input data. + /// See: //third_party/tensorflow_models/object_detection/core/preprocessor.py + /// + public sealed partial class PreprocessingStep : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new PreprocessingStep()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public PreprocessingStep() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public PreprocessingStep(PreprocessingStep other) : this() { + switch (other.PreprocessingStepCase) { + case PreprocessingStepOneofCase.NormalizeImage: + NormalizeImage = other.NormalizeImage.Clone(); + break; + case PreprocessingStepOneofCase.RandomHorizontalFlip: + RandomHorizontalFlip = other.RandomHorizontalFlip.Clone(); + break; + case PreprocessingStepOneofCase.RandomPixelValueScale: + RandomPixelValueScale = other.RandomPixelValueScale.Clone(); + break; + case PreprocessingStepOneofCase.RandomImageScale: + RandomImageScale = other.RandomImageScale.Clone(); + break; + case PreprocessingStepOneofCase.RandomRgbToGray: + RandomRgbToGray = other.RandomRgbToGray.Clone(); + break; + case PreprocessingStepOneofCase.RandomAdjustBrightness: + RandomAdjustBrightness = other.RandomAdjustBrightness.Clone(); + break; + case PreprocessingStepOneofCase.RandomAdjustContrast: + RandomAdjustContrast = other.RandomAdjustContrast.Clone(); + break; + case PreprocessingStepOneofCase.RandomAdjustHue: + RandomAdjustHue = other.RandomAdjustHue.Clone(); + break; + case PreprocessingStepOneofCase.RandomAdjustSaturation: + RandomAdjustSaturation = other.RandomAdjustSaturation.Clone(); + break; + case PreprocessingStepOneofCase.RandomDistortColor: + RandomDistortColor = other.RandomDistortColor.Clone(); + break; + case PreprocessingStepOneofCase.RandomJitterBoxes: + RandomJitterBoxes = other.RandomJitterBoxes.Clone(); + break; + case PreprocessingStepOneofCase.RandomCropImage: + RandomCropImage = other.RandomCropImage.Clone(); + break; + case PreprocessingStepOneofCase.RandomPadImage: + RandomPadImage = other.RandomPadImage.Clone(); + break; + case PreprocessingStepOneofCase.RandomCropPadImage: + RandomCropPadImage = other.RandomCropPadImage.Clone(); + break; + case PreprocessingStepOneofCase.RandomCropToAspectRatio: + RandomCropToAspectRatio = other.RandomCropToAspectRatio.Clone(); + break; + case PreprocessingStepOneofCase.RandomBlackPatches: + RandomBlackPatches = other.RandomBlackPatches.Clone(); + break; + case 
PreprocessingStepOneofCase.RandomResizeMethod: + RandomResizeMethod = other.RandomResizeMethod.Clone(); + break; + case PreprocessingStepOneofCase.ScaleBoxesToPixelCoordinates: + ScaleBoxesToPixelCoordinates = other.ScaleBoxesToPixelCoordinates.Clone(); + break; + case PreprocessingStepOneofCase.ResizeImage: + ResizeImage = other.ResizeImage.Clone(); + break; + case PreprocessingStepOneofCase.SubtractChannelMean: + SubtractChannelMean = other.SubtractChannelMean.Clone(); + break; + case PreprocessingStepOneofCase.SsdRandomCrop: + SsdRandomCrop = other.SsdRandomCrop.Clone(); + break; + case PreprocessingStepOneofCase.SsdRandomCropPad: + SsdRandomCropPad = other.SsdRandomCropPad.Clone(); + break; + case PreprocessingStepOneofCase.SsdRandomCropFixedAspectRatio: + SsdRandomCropFixedAspectRatio = other.SsdRandomCropFixedAspectRatio.Clone(); + break; + case PreprocessingStepOneofCase.SsdRandomCropPadFixedAspectRatio: + SsdRandomCropPadFixedAspectRatio = other.SsdRandomCropPadFixedAspectRatio.Clone(); + break; + case PreprocessingStepOneofCase.RandomVerticalFlip: + RandomVerticalFlip = other.RandomVerticalFlip.Clone(); + break; + case PreprocessingStepOneofCase.RandomRotation90: + RandomRotation90 = other.RandomRotation90.Clone(); + break; + case PreprocessingStepOneofCase.RgbToGray: + RgbToGray = other.RgbToGray.Clone(); + break; + case PreprocessingStepOneofCase.ConvertClassLogitsToSoftmax: + ConvertClassLogitsToSoftmax = other.ConvertClassLogitsToSoftmax.Clone(); + break; + case PreprocessingStepOneofCase.RandomAbsolutePadImage: + RandomAbsolutePadImage = other.RandomAbsolutePadImage.Clone(); + break; + case PreprocessingStepOneofCase.RandomSelfConcatImage: + RandomSelfConcatImage = other.RandomSelfConcatImage.Clone(); + break; + case PreprocessingStepOneofCase.AutoaugmentImage: + AutoaugmentImage = other.AutoaugmentImage.Clone(); + break; + case PreprocessingStepOneofCase.DropLabelProbabilistically: + DropLabelProbabilistically = other.DropLabelProbabilistically.Clone(); + break; + case PreprocessingStepOneofCase.RemapLabels: + RemapLabels = other.RemapLabels.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public PreprocessingStep Clone() { + return new PreprocessingStep(this); + } + + /// Field number for the "normalize_image" field. + public const int NormalizeImageFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.NormalizeImage NormalizeImage { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.NormalizeImage ? (global::Tensorflow.Models.ObjectDetection.Protos.NormalizeImage) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.NormalizeImage; + } + } + + /// Field number for the "random_horizontal_flip" field. + public const int RandomHorizontalFlipFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomHorizontalFlip RandomHorizontalFlip { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomHorizontalFlip ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomHorizontalFlip) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? 
PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomHorizontalFlip; + } + } + + /// Field number for the "random_pixel_value_scale" field. + public const int RandomPixelValueScaleFieldNumber = 3; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomPixelValueScale RandomPixelValueScale { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomPixelValueScale ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomPixelValueScale) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomPixelValueScale; + } + } + + /// Field number for the "random_image_scale" field. + public const int RandomImageScaleFieldNumber = 4; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomImageScale RandomImageScale { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomImageScale ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomImageScale) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomImageScale; + } + } + + /// Field number for the "random_rgb_to_gray" field. + public const int RandomRgbToGrayFieldNumber = 5; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomRGBtoGray RandomRgbToGray { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomRgbToGray ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomRGBtoGray) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomRgbToGray; + } + } + + /// Field number for the "random_adjust_brightness" field. + public const int RandomAdjustBrightnessFieldNumber = 6; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustBrightness RandomAdjustBrightness { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustBrightness ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustBrightness) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomAdjustBrightness; + } + } + + /// Field number for the "random_adjust_contrast" field. + public const int RandomAdjustContrastFieldNumber = 7; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustContrast RandomAdjustContrast { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustContrast ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustContrast) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomAdjustContrast; + } + } + + /// Field number for the "random_adjust_hue" field. 
+ public const int RandomAdjustHueFieldNumber = 8; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustHue RandomAdjustHue { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustHue ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustHue) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomAdjustHue; + } + } + + /// Field number for the "random_adjust_saturation" field. + public const int RandomAdjustSaturationFieldNumber = 9; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustSaturation RandomAdjustSaturation { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustSaturation ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustSaturation) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomAdjustSaturation; + } + } + + /// Field number for the "random_distort_color" field. + public const int RandomDistortColorFieldNumber = 10; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomDistortColor RandomDistortColor { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomDistortColor ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomDistortColor) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomDistortColor; + } + } + + /// Field number for the "random_jitter_boxes" field. + public const int RandomJitterBoxesFieldNumber = 11; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomJitterBoxes RandomJitterBoxes { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomJitterBoxes ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomJitterBoxes) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomJitterBoxes; + } + } + + /// Field number for the "random_crop_image" field. + public const int RandomCropImageFieldNumber = 12; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomCropImage RandomCropImage { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomCropImage ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomCropImage) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomCropImage; + } + } + + /// Field number for the "random_pad_image" field. + public const int RandomPadImageFieldNumber = 13; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomPadImage RandomPadImage { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomPadImage ? 
(global::Tensorflow.Models.ObjectDetection.Protos.RandomPadImage) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomPadImage; + } + } + + /// Field number for the "random_crop_pad_image" field. + public const int RandomCropPadImageFieldNumber = 14; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomCropPadImage RandomCropPadImage { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomCropPadImage ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomCropPadImage) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomCropPadImage; + } + } + + /// Field number for the "random_crop_to_aspect_ratio" field. + public const int RandomCropToAspectRatioFieldNumber = 15; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomCropToAspectRatio RandomCropToAspectRatio { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomCropToAspectRatio ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomCropToAspectRatio) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomCropToAspectRatio; + } + } + + /// Field number for the "random_black_patches" field. + public const int RandomBlackPatchesFieldNumber = 16; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomBlackPatches RandomBlackPatches { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomBlackPatches ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomBlackPatches) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomBlackPatches; + } + } + + /// Field number for the "random_resize_method" field. + public const int RandomResizeMethodFieldNumber = 17; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomResizeMethod RandomResizeMethod { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomResizeMethod ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomResizeMethod) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomResizeMethod; + } + } + + /// Field number for the "scale_boxes_to_pixel_coordinates" field. + public const int ScaleBoxesToPixelCoordinatesFieldNumber = 18; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ScaleBoxesToPixelCoordinates ScaleBoxesToPixelCoordinates { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.ScaleBoxesToPixelCoordinates ? (global::Tensorflow.Models.ObjectDetection.Protos.ScaleBoxesToPixelCoordinates) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? 
PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.ScaleBoxesToPixelCoordinates; + } + } + + /// Field number for the "resize_image" field. + public const int ResizeImageFieldNumber = 19; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ResizeImage ResizeImage { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.ResizeImage ? (global::Tensorflow.Models.ObjectDetection.Protos.ResizeImage) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.ResizeImage; + } + } + + /// Field number for the "subtract_channel_mean" field. + public const int SubtractChannelMeanFieldNumber = 20; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.SubtractChannelMean SubtractChannelMean { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.SubtractChannelMean ? (global::Tensorflow.Models.ObjectDetection.Protos.SubtractChannelMean) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.SubtractChannelMean; + } + } + + /// Field number for the "ssd_random_crop" field. + public const int SsdRandomCropFieldNumber = 21; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCrop SsdRandomCrop { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCrop ? (global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCrop) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.SsdRandomCrop; + } + } + + /// Field number for the "ssd_random_crop_pad" field. + public const int SsdRandomCropPadFieldNumber = 22; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPad SsdRandomCropPad { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCropPad ? (global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPad) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.SsdRandomCropPad; + } + } + + /// Field number for the "ssd_random_crop_fixed_aspect_ratio" field. + public const int SsdRandomCropFixedAspectRatioFieldNumber = 23; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropFixedAspectRatio SsdRandomCropFixedAspectRatio { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCropFixedAspectRatio ? (global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropFixedAspectRatio) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.SsdRandomCropFixedAspectRatio; + } + } + + /// Field number for the "ssd_random_crop_pad_fixed_aspect_ratio" field. 
+ public const int SsdRandomCropPadFixedAspectRatioFieldNumber = 24; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPadFixedAspectRatio SsdRandomCropPadFixedAspectRatio { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCropPadFixedAspectRatio ? (global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPadFixedAspectRatio) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.SsdRandomCropPadFixedAspectRatio; + } + } + + /// Field number for the "random_vertical_flip" field. + public const int RandomVerticalFlipFieldNumber = 25; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomVerticalFlip RandomVerticalFlip { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomVerticalFlip ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomVerticalFlip) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomVerticalFlip; + } + } + + /// Field number for the "random_rotation90" field. + public const int RandomRotation90FieldNumber = 26; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomRotation90 RandomRotation90 { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomRotation90 ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomRotation90) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomRotation90; + } + } + + /// Field number for the "rgb_to_gray" field. + public const int RgbToGrayFieldNumber = 27; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RGBtoGray RgbToGray { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RgbToGray ? (global::Tensorflow.Models.ObjectDetection.Protos.RGBtoGray) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RgbToGray; + } + } + + /// Field number for the "convert_class_logits_to_softmax" field. + public const int ConvertClassLogitsToSoftmaxFieldNumber = 28; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ConvertClassLogitsToSoftmax ConvertClassLogitsToSoftmax { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.ConvertClassLogitsToSoftmax ? (global::Tensorflow.Models.ObjectDetection.Protos.ConvertClassLogitsToSoftmax) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.ConvertClassLogitsToSoftmax; + } + } + + /// Field number for the "random_absolute_pad_image" field. + public const int RandomAbsolutePadImageFieldNumber = 29; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomAbsolutePadImage RandomAbsolutePadImage { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAbsolutePadImage ? 
(global::Tensorflow.Models.ObjectDetection.Protos.RandomAbsolutePadImage) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomAbsolutePadImage; + } + } + + /// Field number for the "random_self_concat_image" field. + public const int RandomSelfConcatImageFieldNumber = 30; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RandomSelfConcatImage RandomSelfConcatImage { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RandomSelfConcatImage ? (global::Tensorflow.Models.ObjectDetection.Protos.RandomSelfConcatImage) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RandomSelfConcatImage; + } + } + + /// Field number for the "autoaugment_image" field. + public const int AutoaugmentImageFieldNumber = 31; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.AutoAugmentImage AutoaugmentImage { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.AutoaugmentImage ? (global::Tensorflow.Models.ObjectDetection.Protos.AutoAugmentImage) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.AutoaugmentImage; + } + } + + /// Field number for the "drop_label_probabilistically" field. + public const int DropLabelProbabilisticallyFieldNumber = 32; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.DropLabelProbabilistically DropLabelProbabilistically { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.DropLabelProbabilistically ? (global::Tensorflow.Models.ObjectDetection.Protos.DropLabelProbabilistically) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.DropLabelProbabilistically; + } + } + + /// Field number for the "remap_labels" field. + public const int RemapLabelsFieldNumber = 33; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RemapLabels RemapLabels { + get { return preprocessingStepCase_ == PreprocessingStepOneofCase.RemapLabels ? (global::Tensorflow.Models.ObjectDetection.Protos.RemapLabels) preprocessingStep_ : null; } + set { + preprocessingStep_ = value; + preprocessingStepCase_ = value == null ? PreprocessingStepOneofCase.None : PreprocessingStepOneofCase.RemapLabels; + } + } + + private object preprocessingStep_; + /// Enum of possible cases for the "preprocessing_step" oneof. 
+ public enum PreprocessingStepOneofCase { + None = 0, + NormalizeImage = 1, + RandomHorizontalFlip = 2, + RandomPixelValueScale = 3, + RandomImageScale = 4, + RandomRgbToGray = 5, + RandomAdjustBrightness = 6, + RandomAdjustContrast = 7, + RandomAdjustHue = 8, + RandomAdjustSaturation = 9, + RandomDistortColor = 10, + RandomJitterBoxes = 11, + RandomCropImage = 12, + RandomPadImage = 13, + RandomCropPadImage = 14, + RandomCropToAspectRatio = 15, + RandomBlackPatches = 16, + RandomResizeMethod = 17, + ScaleBoxesToPixelCoordinates = 18, + ResizeImage = 19, + SubtractChannelMean = 20, + SsdRandomCrop = 21, + SsdRandomCropPad = 22, + SsdRandomCropFixedAspectRatio = 23, + SsdRandomCropPadFixedAspectRatio = 24, + RandomVerticalFlip = 25, + RandomRotation90 = 26, + RgbToGray = 27, + ConvertClassLogitsToSoftmax = 28, + RandomAbsolutePadImage = 29, + RandomSelfConcatImage = 30, + AutoaugmentImage = 31, + DropLabelProbabilistically = 32, + RemapLabels = 33, + } + private PreprocessingStepOneofCase preprocessingStepCase_ = PreprocessingStepOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public PreprocessingStepOneofCase PreprocessingStepCase { + get { return preprocessingStepCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearPreprocessingStep() { + preprocessingStepCase_ = PreprocessingStepOneofCase.None; + preprocessingStep_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as PreprocessingStep); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(PreprocessingStep other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(NormalizeImage, other.NormalizeImage)) return false; + if (!object.Equals(RandomHorizontalFlip, other.RandomHorizontalFlip)) return false; + if (!object.Equals(RandomPixelValueScale, other.RandomPixelValueScale)) return false; + if (!object.Equals(RandomImageScale, other.RandomImageScale)) return false; + if (!object.Equals(RandomRgbToGray, other.RandomRgbToGray)) return false; + if (!object.Equals(RandomAdjustBrightness, other.RandomAdjustBrightness)) return false; + if (!object.Equals(RandomAdjustContrast, other.RandomAdjustContrast)) return false; + if (!object.Equals(RandomAdjustHue, other.RandomAdjustHue)) return false; + if (!object.Equals(RandomAdjustSaturation, other.RandomAdjustSaturation)) return false; + if (!object.Equals(RandomDistortColor, other.RandomDistortColor)) return false; + if (!object.Equals(RandomJitterBoxes, other.RandomJitterBoxes)) return false; + if (!object.Equals(RandomCropImage, other.RandomCropImage)) return false; + if (!object.Equals(RandomPadImage, other.RandomPadImage)) return false; + if (!object.Equals(RandomCropPadImage, other.RandomCropPadImage)) return false; + if (!object.Equals(RandomCropToAspectRatio, other.RandomCropToAspectRatio)) return false; + if (!object.Equals(RandomBlackPatches, other.RandomBlackPatches)) return false; + if (!object.Equals(RandomResizeMethod, other.RandomResizeMethod)) return false; + if (!object.Equals(ScaleBoxesToPixelCoordinates, other.ScaleBoxesToPixelCoordinates)) return false; + if (!object.Equals(ResizeImage, other.ResizeImage)) return false; + if (!object.Equals(SubtractChannelMean, other.SubtractChannelMean)) return false; + if (!object.Equals(SsdRandomCrop, other.SsdRandomCrop)) return false; + if 
(!object.Equals(SsdRandomCropPad, other.SsdRandomCropPad)) return false; + if (!object.Equals(SsdRandomCropFixedAspectRatio, other.SsdRandomCropFixedAspectRatio)) return false; + if (!object.Equals(SsdRandomCropPadFixedAspectRatio, other.SsdRandomCropPadFixedAspectRatio)) return false; + if (!object.Equals(RandomVerticalFlip, other.RandomVerticalFlip)) return false; + if (!object.Equals(RandomRotation90, other.RandomRotation90)) return false; + if (!object.Equals(RgbToGray, other.RgbToGray)) return false; + if (!object.Equals(ConvertClassLogitsToSoftmax, other.ConvertClassLogitsToSoftmax)) return false; + if (!object.Equals(RandomAbsolutePadImage, other.RandomAbsolutePadImage)) return false; + if (!object.Equals(RandomSelfConcatImage, other.RandomSelfConcatImage)) return false; + if (!object.Equals(AutoaugmentImage, other.AutoaugmentImage)) return false; + if (!object.Equals(DropLabelProbabilistically, other.DropLabelProbabilistically)) return false; + if (!object.Equals(RemapLabels, other.RemapLabels)) return false; + if (PreprocessingStepCase != other.PreprocessingStepCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (preprocessingStepCase_ == PreprocessingStepOneofCase.NormalizeImage) hash ^= NormalizeImage.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomHorizontalFlip) hash ^= RandomHorizontalFlip.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomPixelValueScale) hash ^= RandomPixelValueScale.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomImageScale) hash ^= RandomImageScale.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomRgbToGray) hash ^= RandomRgbToGray.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustBrightness) hash ^= RandomAdjustBrightness.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustContrast) hash ^= RandomAdjustContrast.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustHue) hash ^= RandomAdjustHue.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustSaturation) hash ^= RandomAdjustSaturation.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomDistortColor) hash ^= RandomDistortColor.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomJitterBoxes) hash ^= RandomJitterBoxes.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomCropImage) hash ^= RandomCropImage.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomPadImage) hash ^= RandomPadImage.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomCropPadImage) hash ^= RandomCropPadImage.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomCropToAspectRatio) hash ^= RandomCropToAspectRatio.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomBlackPatches) hash ^= RandomBlackPatches.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomResizeMethod) hash ^= RandomResizeMethod.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.ScaleBoxesToPixelCoordinates) hash ^= ScaleBoxesToPixelCoordinates.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.ResizeImage) 
hash ^= ResizeImage.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SubtractChannelMean) hash ^= SubtractChannelMean.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCrop) hash ^= SsdRandomCrop.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCropPad) hash ^= SsdRandomCropPad.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCropFixedAspectRatio) hash ^= SsdRandomCropFixedAspectRatio.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCropPadFixedAspectRatio) hash ^= SsdRandomCropPadFixedAspectRatio.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomVerticalFlip) hash ^= RandomVerticalFlip.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomRotation90) hash ^= RandomRotation90.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RgbToGray) hash ^= RgbToGray.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.ConvertClassLogitsToSoftmax) hash ^= ConvertClassLogitsToSoftmax.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAbsolutePadImage) hash ^= RandomAbsolutePadImage.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomSelfConcatImage) hash ^= RandomSelfConcatImage.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.AutoaugmentImage) hash ^= AutoaugmentImage.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.DropLabelProbabilistically) hash ^= DropLabelProbabilistically.GetHashCode(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RemapLabels) hash ^= RemapLabels.GetHashCode(); + hash ^= (int) preprocessingStepCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (preprocessingStepCase_ == PreprocessingStepOneofCase.NormalizeImage) { + output.WriteRawTag(10); + output.WriteMessage(NormalizeImage); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomHorizontalFlip) { + output.WriteRawTag(18); + output.WriteMessage(RandomHorizontalFlip); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomPixelValueScale) { + output.WriteRawTag(26); + output.WriteMessage(RandomPixelValueScale); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomImageScale) { + output.WriteRawTag(34); + output.WriteMessage(RandomImageScale); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomRgbToGray) { + output.WriteRawTag(42); + output.WriteMessage(RandomRgbToGray); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustBrightness) { + output.WriteRawTag(50); + output.WriteMessage(RandomAdjustBrightness); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustContrast) { + output.WriteRawTag(58); + output.WriteMessage(RandomAdjustContrast); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustHue) { + output.WriteRawTag(66); + output.WriteMessage(RandomAdjustHue); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustSaturation) { + output.WriteRawTag(74); + 
output.WriteMessage(RandomAdjustSaturation); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomDistortColor) { + output.WriteRawTag(82); + output.WriteMessage(RandomDistortColor); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomJitterBoxes) { + output.WriteRawTag(90); + output.WriteMessage(RandomJitterBoxes); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomCropImage) { + output.WriteRawTag(98); + output.WriteMessage(RandomCropImage); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomPadImage) { + output.WriteRawTag(106); + output.WriteMessage(RandomPadImage); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomCropPadImage) { + output.WriteRawTag(114); + output.WriteMessage(RandomCropPadImage); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomCropToAspectRatio) { + output.WriteRawTag(122); + output.WriteMessage(RandomCropToAspectRatio); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomBlackPatches) { + output.WriteRawTag(130, 1); + output.WriteMessage(RandomBlackPatches); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomResizeMethod) { + output.WriteRawTag(138, 1); + output.WriteMessage(RandomResizeMethod); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.ScaleBoxesToPixelCoordinates) { + output.WriteRawTag(146, 1); + output.WriteMessage(ScaleBoxesToPixelCoordinates); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.ResizeImage) { + output.WriteRawTag(154, 1); + output.WriteMessage(ResizeImage); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SubtractChannelMean) { + output.WriteRawTag(162, 1); + output.WriteMessage(SubtractChannelMean); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCrop) { + output.WriteRawTag(170, 1); + output.WriteMessage(SsdRandomCrop); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCropPad) { + output.WriteRawTag(178, 1); + output.WriteMessage(SsdRandomCropPad); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCropFixedAspectRatio) { + output.WriteRawTag(186, 1); + output.WriteMessage(SsdRandomCropFixedAspectRatio); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCropPadFixedAspectRatio) { + output.WriteRawTag(194, 1); + output.WriteMessage(SsdRandomCropPadFixedAspectRatio); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomVerticalFlip) { + output.WriteRawTag(202, 1); + output.WriteMessage(RandomVerticalFlip); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomRotation90) { + output.WriteRawTag(210, 1); + output.WriteMessage(RandomRotation90); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RgbToGray) { + output.WriteRawTag(218, 1); + output.WriteMessage(RgbToGray); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.ConvertClassLogitsToSoftmax) { + output.WriteRawTag(226, 1); + output.WriteMessage(ConvertClassLogitsToSoftmax); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAbsolutePadImage) { + output.WriteRawTag(234, 1); + output.WriteMessage(RandomAbsolutePadImage); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomSelfConcatImage) { + output.WriteRawTag(242, 1); + output.WriteMessage(RandomSelfConcatImage); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.AutoaugmentImage) { + output.WriteRawTag(250, 1); + output.WriteMessage(AutoaugmentImage); + } + if 
(preprocessingStepCase_ == PreprocessingStepOneofCase.DropLabelProbabilistically) { + output.WriteRawTag(130, 2); + output.WriteMessage(DropLabelProbabilistically); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RemapLabels) { + output.WriteRawTag(138, 2); + output.WriteMessage(RemapLabels); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (preprocessingStepCase_ == PreprocessingStepOneofCase.NormalizeImage) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(NormalizeImage); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomHorizontalFlip) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RandomHorizontalFlip); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomPixelValueScale) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RandomPixelValueScale); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomImageScale) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RandomImageScale); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomRgbToGray) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RandomRgbToGray); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustBrightness) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RandomAdjustBrightness); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustContrast) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RandomAdjustContrast); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustHue) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RandomAdjustHue); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustSaturation) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RandomAdjustSaturation); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomDistortColor) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RandomDistortColor); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomJitterBoxes) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RandomJitterBoxes); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomCropImage) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RandomCropImage); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomPadImage) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RandomPadImage); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomCropPadImage) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RandomCropPadImage); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomCropToAspectRatio) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(RandomCropToAspectRatio); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomBlackPatches) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(RandomBlackPatches); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomResizeMethod) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(RandomResizeMethod); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.ScaleBoxesToPixelCoordinates) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(ScaleBoxesToPixelCoordinates); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.ResizeImage) { + size += 2 + 
pb::CodedOutputStream.ComputeMessageSize(ResizeImage); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SubtractChannelMean) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(SubtractChannelMean); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCrop) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(SsdRandomCrop); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCropPad) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(SsdRandomCropPad); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCropFixedAspectRatio) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(SsdRandomCropFixedAspectRatio); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCropPadFixedAspectRatio) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(SsdRandomCropPadFixedAspectRatio); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomVerticalFlip) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(RandomVerticalFlip); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomRotation90) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(RandomRotation90); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RgbToGray) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(RgbToGray); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.ConvertClassLogitsToSoftmax) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(ConvertClassLogitsToSoftmax); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAbsolutePadImage) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(RandomAbsolutePadImage); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomSelfConcatImage) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(RandomSelfConcatImage); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.AutoaugmentImage) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(AutoaugmentImage); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.DropLabelProbabilistically) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(DropLabelProbabilistically); + } + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RemapLabels) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(RemapLabels); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(PreprocessingStep other) { + if (other == null) { + return; + } + switch (other.PreprocessingStepCase) { + case PreprocessingStepOneofCase.NormalizeImage: + if (NormalizeImage == null) { + NormalizeImage = new global::Tensorflow.Models.ObjectDetection.Protos.NormalizeImage(); + } + NormalizeImage.MergeFrom(other.NormalizeImage); + break; + case PreprocessingStepOneofCase.RandomHorizontalFlip: + if (RandomHorizontalFlip == null) { + RandomHorizontalFlip = new global::Tensorflow.Models.ObjectDetection.Protos.RandomHorizontalFlip(); + } + RandomHorizontalFlip.MergeFrom(other.RandomHorizontalFlip); + break; + case PreprocessingStepOneofCase.RandomPixelValueScale: + if (RandomPixelValueScale == null) { + RandomPixelValueScale = new global::Tensorflow.Models.ObjectDetection.Protos.RandomPixelValueScale(); + } + RandomPixelValueScale.MergeFrom(other.RandomPixelValueScale); + break; + case PreprocessingStepOneofCase.RandomImageScale: + if (RandomImageScale == null) { + RandomImageScale = new 
global::Tensorflow.Models.ObjectDetection.Protos.RandomImageScale(); + } + RandomImageScale.MergeFrom(other.RandomImageScale); + break; + case PreprocessingStepOneofCase.RandomRgbToGray: + if (RandomRgbToGray == null) { + RandomRgbToGray = new global::Tensorflow.Models.ObjectDetection.Protos.RandomRGBtoGray(); + } + RandomRgbToGray.MergeFrom(other.RandomRgbToGray); + break; + case PreprocessingStepOneofCase.RandomAdjustBrightness: + if (RandomAdjustBrightness == null) { + RandomAdjustBrightness = new global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustBrightness(); + } + RandomAdjustBrightness.MergeFrom(other.RandomAdjustBrightness); + break; + case PreprocessingStepOneofCase.RandomAdjustContrast: + if (RandomAdjustContrast == null) { + RandomAdjustContrast = new global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustContrast(); + } + RandomAdjustContrast.MergeFrom(other.RandomAdjustContrast); + break; + case PreprocessingStepOneofCase.RandomAdjustHue: + if (RandomAdjustHue == null) { + RandomAdjustHue = new global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustHue(); + } + RandomAdjustHue.MergeFrom(other.RandomAdjustHue); + break; + case PreprocessingStepOneofCase.RandomAdjustSaturation: + if (RandomAdjustSaturation == null) { + RandomAdjustSaturation = new global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustSaturation(); + } + RandomAdjustSaturation.MergeFrom(other.RandomAdjustSaturation); + break; + case PreprocessingStepOneofCase.RandomDistortColor: + if (RandomDistortColor == null) { + RandomDistortColor = new global::Tensorflow.Models.ObjectDetection.Protos.RandomDistortColor(); + } + RandomDistortColor.MergeFrom(other.RandomDistortColor); + break; + case PreprocessingStepOneofCase.RandomJitterBoxes: + if (RandomJitterBoxes == null) { + RandomJitterBoxes = new global::Tensorflow.Models.ObjectDetection.Protos.RandomJitterBoxes(); + } + RandomJitterBoxes.MergeFrom(other.RandomJitterBoxes); + break; + case PreprocessingStepOneofCase.RandomCropImage: + if (RandomCropImage == null) { + RandomCropImage = new global::Tensorflow.Models.ObjectDetection.Protos.RandomCropImage(); + } + RandomCropImage.MergeFrom(other.RandomCropImage); + break; + case PreprocessingStepOneofCase.RandomPadImage: + if (RandomPadImage == null) { + RandomPadImage = new global::Tensorflow.Models.ObjectDetection.Protos.RandomPadImage(); + } + RandomPadImage.MergeFrom(other.RandomPadImage); + break; + case PreprocessingStepOneofCase.RandomCropPadImage: + if (RandomCropPadImage == null) { + RandomCropPadImage = new global::Tensorflow.Models.ObjectDetection.Protos.RandomCropPadImage(); + } + RandomCropPadImage.MergeFrom(other.RandomCropPadImage); + break; + case PreprocessingStepOneofCase.RandomCropToAspectRatio: + if (RandomCropToAspectRatio == null) { + RandomCropToAspectRatio = new global::Tensorflow.Models.ObjectDetection.Protos.RandomCropToAspectRatio(); + } + RandomCropToAspectRatio.MergeFrom(other.RandomCropToAspectRatio); + break; + case PreprocessingStepOneofCase.RandomBlackPatches: + if (RandomBlackPatches == null) { + RandomBlackPatches = new global::Tensorflow.Models.ObjectDetection.Protos.RandomBlackPatches(); + } + RandomBlackPatches.MergeFrom(other.RandomBlackPatches); + break; + case PreprocessingStepOneofCase.RandomResizeMethod: + if (RandomResizeMethod == null) { + RandomResizeMethod = new global::Tensorflow.Models.ObjectDetection.Protos.RandomResizeMethod(); + } + RandomResizeMethod.MergeFrom(other.RandomResizeMethod); + break; + case 
PreprocessingStepOneofCase.ScaleBoxesToPixelCoordinates: + if (ScaleBoxesToPixelCoordinates == null) { + ScaleBoxesToPixelCoordinates = new global::Tensorflow.Models.ObjectDetection.Protos.ScaleBoxesToPixelCoordinates(); + } + ScaleBoxesToPixelCoordinates.MergeFrom(other.ScaleBoxesToPixelCoordinates); + break; + case PreprocessingStepOneofCase.ResizeImage: + if (ResizeImage == null) { + ResizeImage = new global::Tensorflow.Models.ObjectDetection.Protos.ResizeImage(); + } + ResizeImage.MergeFrom(other.ResizeImage); + break; + case PreprocessingStepOneofCase.SubtractChannelMean: + if (SubtractChannelMean == null) { + SubtractChannelMean = new global::Tensorflow.Models.ObjectDetection.Protos.SubtractChannelMean(); + } + SubtractChannelMean.MergeFrom(other.SubtractChannelMean); + break; + case PreprocessingStepOneofCase.SsdRandomCrop: + if (SsdRandomCrop == null) { + SsdRandomCrop = new global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCrop(); + } + SsdRandomCrop.MergeFrom(other.SsdRandomCrop); + break; + case PreprocessingStepOneofCase.SsdRandomCropPad: + if (SsdRandomCropPad == null) { + SsdRandomCropPad = new global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPad(); + } + SsdRandomCropPad.MergeFrom(other.SsdRandomCropPad); + break; + case PreprocessingStepOneofCase.SsdRandomCropFixedAspectRatio: + if (SsdRandomCropFixedAspectRatio == null) { + SsdRandomCropFixedAspectRatio = new global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropFixedAspectRatio(); + } + SsdRandomCropFixedAspectRatio.MergeFrom(other.SsdRandomCropFixedAspectRatio); + break; + case PreprocessingStepOneofCase.SsdRandomCropPadFixedAspectRatio: + if (SsdRandomCropPadFixedAspectRatio == null) { + SsdRandomCropPadFixedAspectRatio = new global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPadFixedAspectRatio(); + } + SsdRandomCropPadFixedAspectRatio.MergeFrom(other.SsdRandomCropPadFixedAspectRatio); + break; + case PreprocessingStepOneofCase.RandomVerticalFlip: + if (RandomVerticalFlip == null) { + RandomVerticalFlip = new global::Tensorflow.Models.ObjectDetection.Protos.RandomVerticalFlip(); + } + RandomVerticalFlip.MergeFrom(other.RandomVerticalFlip); + break; + case PreprocessingStepOneofCase.RandomRotation90: + if (RandomRotation90 == null) { + RandomRotation90 = new global::Tensorflow.Models.ObjectDetection.Protos.RandomRotation90(); + } + RandomRotation90.MergeFrom(other.RandomRotation90); + break; + case PreprocessingStepOneofCase.RgbToGray: + if (RgbToGray == null) { + RgbToGray = new global::Tensorflow.Models.ObjectDetection.Protos.RGBtoGray(); + } + RgbToGray.MergeFrom(other.RgbToGray); + break; + case PreprocessingStepOneofCase.ConvertClassLogitsToSoftmax: + if (ConvertClassLogitsToSoftmax == null) { + ConvertClassLogitsToSoftmax = new global::Tensorflow.Models.ObjectDetection.Protos.ConvertClassLogitsToSoftmax(); + } + ConvertClassLogitsToSoftmax.MergeFrom(other.ConvertClassLogitsToSoftmax); + break; + case PreprocessingStepOneofCase.RandomAbsolutePadImage: + if (RandomAbsolutePadImage == null) { + RandomAbsolutePadImage = new global::Tensorflow.Models.ObjectDetection.Protos.RandomAbsolutePadImage(); + } + RandomAbsolutePadImage.MergeFrom(other.RandomAbsolutePadImage); + break; + case PreprocessingStepOneofCase.RandomSelfConcatImage: + if (RandomSelfConcatImage == null) { + RandomSelfConcatImage = new global::Tensorflow.Models.ObjectDetection.Protos.RandomSelfConcatImage(); + } + RandomSelfConcatImage.MergeFrom(other.RandomSelfConcatImage); + break; + case 
PreprocessingStepOneofCase.AutoaugmentImage: + if (AutoaugmentImage == null) { + AutoaugmentImage = new global::Tensorflow.Models.ObjectDetection.Protos.AutoAugmentImage(); + } + AutoaugmentImage.MergeFrom(other.AutoaugmentImage); + break; + case PreprocessingStepOneofCase.DropLabelProbabilistically: + if (DropLabelProbabilistically == null) { + DropLabelProbabilistically = new global::Tensorflow.Models.ObjectDetection.Protos.DropLabelProbabilistically(); + } + DropLabelProbabilistically.MergeFrom(other.DropLabelProbabilistically); + break; + case PreprocessingStepOneofCase.RemapLabels: + if (RemapLabels == null) { + RemapLabels = new global::Tensorflow.Models.ObjectDetection.Protos.RemapLabels(); + } + RemapLabels.MergeFrom(other.RemapLabels); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.Models.ObjectDetection.Protos.NormalizeImage subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.NormalizeImage(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.NormalizeImage) { + subBuilder.MergeFrom(NormalizeImage); + } + input.ReadMessage(subBuilder); + NormalizeImage = subBuilder; + break; + } + case 18: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomHorizontalFlip subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomHorizontalFlip(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomHorizontalFlip) { + subBuilder.MergeFrom(RandomHorizontalFlip); + } + input.ReadMessage(subBuilder); + RandomHorizontalFlip = subBuilder; + break; + } + case 26: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomPixelValueScale subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomPixelValueScale(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomPixelValueScale) { + subBuilder.MergeFrom(RandomPixelValueScale); + } + input.ReadMessage(subBuilder); + RandomPixelValueScale = subBuilder; + break; + } + case 34: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomImageScale subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomImageScale(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomImageScale) { + subBuilder.MergeFrom(RandomImageScale); + } + input.ReadMessage(subBuilder); + RandomImageScale = subBuilder; + break; + } + case 42: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomRGBtoGray subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomRGBtoGray(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomRgbToGray) { + subBuilder.MergeFrom(RandomRgbToGray); + } + input.ReadMessage(subBuilder); + RandomRgbToGray = subBuilder; + break; + } + case 50: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustBrightness subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustBrightness(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustBrightness) { + subBuilder.MergeFrom(RandomAdjustBrightness); + } + input.ReadMessage(subBuilder); + RandomAdjustBrightness = subBuilder; + break; + } + case 58: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustContrast subBuilder = new 
global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustContrast(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustContrast) { + subBuilder.MergeFrom(RandomAdjustContrast); + } + input.ReadMessage(subBuilder); + RandomAdjustContrast = subBuilder; + break; + } + case 66: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustHue subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustHue(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustHue) { + subBuilder.MergeFrom(RandomAdjustHue); + } + input.ReadMessage(subBuilder); + RandomAdjustHue = subBuilder; + break; + } + case 74: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustSaturation subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomAdjustSaturation(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAdjustSaturation) { + subBuilder.MergeFrom(RandomAdjustSaturation); + } + input.ReadMessage(subBuilder); + RandomAdjustSaturation = subBuilder; + break; + } + case 82: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomDistortColor subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomDistortColor(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomDistortColor) { + subBuilder.MergeFrom(RandomDistortColor); + } + input.ReadMessage(subBuilder); + RandomDistortColor = subBuilder; + break; + } + case 90: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomJitterBoxes subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomJitterBoxes(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomJitterBoxes) { + subBuilder.MergeFrom(RandomJitterBoxes); + } + input.ReadMessage(subBuilder); + RandomJitterBoxes = subBuilder; + break; + } + case 98: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomCropImage subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomCropImage(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomCropImage) { + subBuilder.MergeFrom(RandomCropImage); + } + input.ReadMessage(subBuilder); + RandomCropImage = subBuilder; + break; + } + case 106: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomPadImage subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomPadImage(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomPadImage) { + subBuilder.MergeFrom(RandomPadImage); + } + input.ReadMessage(subBuilder); + RandomPadImage = subBuilder; + break; + } + case 114: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomCropPadImage subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomCropPadImage(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomCropPadImage) { + subBuilder.MergeFrom(RandomCropPadImage); + } + input.ReadMessage(subBuilder); + RandomCropPadImage = subBuilder; + break; + } + case 122: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomCropToAspectRatio subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomCropToAspectRatio(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomCropToAspectRatio) { + subBuilder.MergeFrom(RandomCropToAspectRatio); + } + input.ReadMessage(subBuilder); + RandomCropToAspectRatio = subBuilder; + break; + } + case 130: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomBlackPatches subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomBlackPatches(); + if (preprocessingStepCase_ == 
PreprocessingStepOneofCase.RandomBlackPatches) { + subBuilder.MergeFrom(RandomBlackPatches); + } + input.ReadMessage(subBuilder); + RandomBlackPatches = subBuilder; + break; + } + case 138: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomResizeMethod subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomResizeMethod(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomResizeMethod) { + subBuilder.MergeFrom(RandomResizeMethod); + } + input.ReadMessage(subBuilder); + RandomResizeMethod = subBuilder; + break; + } + case 146: { + global::Tensorflow.Models.ObjectDetection.Protos.ScaleBoxesToPixelCoordinates subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.ScaleBoxesToPixelCoordinates(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.ScaleBoxesToPixelCoordinates) { + subBuilder.MergeFrom(ScaleBoxesToPixelCoordinates); + } + input.ReadMessage(subBuilder); + ScaleBoxesToPixelCoordinates = subBuilder; + break; + } + case 154: { + global::Tensorflow.Models.ObjectDetection.Protos.ResizeImage subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.ResizeImage(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.ResizeImage) { + subBuilder.MergeFrom(ResizeImage); + } + input.ReadMessage(subBuilder); + ResizeImage = subBuilder; + break; + } + case 162: { + global::Tensorflow.Models.ObjectDetection.Protos.SubtractChannelMean subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.SubtractChannelMean(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SubtractChannelMean) { + subBuilder.MergeFrom(SubtractChannelMean); + } + input.ReadMessage(subBuilder); + SubtractChannelMean = subBuilder; + break; + } + case 170: { + global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCrop subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCrop(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCrop) { + subBuilder.MergeFrom(SsdRandomCrop); + } + input.ReadMessage(subBuilder); + SsdRandomCrop = subBuilder; + break; + } + case 178: { + global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPad subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPad(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCropPad) { + subBuilder.MergeFrom(SsdRandomCropPad); + } + input.ReadMessage(subBuilder); + SsdRandomCropPad = subBuilder; + break; + } + case 186: { + global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropFixedAspectRatio subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropFixedAspectRatio(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCropFixedAspectRatio) { + subBuilder.MergeFrom(SsdRandomCropFixedAspectRatio); + } + input.ReadMessage(subBuilder); + SsdRandomCropFixedAspectRatio = subBuilder; + break; + } + case 194: { + global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPadFixedAspectRatio subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPadFixedAspectRatio(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.SsdRandomCropPadFixedAspectRatio) { + subBuilder.MergeFrom(SsdRandomCropPadFixedAspectRatio); + } + input.ReadMessage(subBuilder); + SsdRandomCropPadFixedAspectRatio = subBuilder; + break; + } + case 202: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomVerticalFlip subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomVerticalFlip(); + if (preprocessingStepCase_ == 
PreprocessingStepOneofCase.RandomVerticalFlip) { + subBuilder.MergeFrom(RandomVerticalFlip); + } + input.ReadMessage(subBuilder); + RandomVerticalFlip = subBuilder; + break; + } + case 210: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomRotation90 subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomRotation90(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomRotation90) { + subBuilder.MergeFrom(RandomRotation90); + } + input.ReadMessage(subBuilder); + RandomRotation90 = subBuilder; + break; + } + case 218: { + global::Tensorflow.Models.ObjectDetection.Protos.RGBtoGray subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RGBtoGray(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RgbToGray) { + subBuilder.MergeFrom(RgbToGray); + } + input.ReadMessage(subBuilder); + RgbToGray = subBuilder; + break; + } + case 226: { + global::Tensorflow.Models.ObjectDetection.Protos.ConvertClassLogitsToSoftmax subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.ConvertClassLogitsToSoftmax(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.ConvertClassLogitsToSoftmax) { + subBuilder.MergeFrom(ConvertClassLogitsToSoftmax); + } + input.ReadMessage(subBuilder); + ConvertClassLogitsToSoftmax = subBuilder; + break; + } + case 234: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomAbsolutePadImage subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomAbsolutePadImage(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomAbsolutePadImage) { + subBuilder.MergeFrom(RandomAbsolutePadImage); + } + input.ReadMessage(subBuilder); + RandomAbsolutePadImage = subBuilder; + break; + } + case 242: { + global::Tensorflow.Models.ObjectDetection.Protos.RandomSelfConcatImage subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RandomSelfConcatImage(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RandomSelfConcatImage) { + subBuilder.MergeFrom(RandomSelfConcatImage); + } + input.ReadMessage(subBuilder); + RandomSelfConcatImage = subBuilder; + break; + } + case 250: { + global::Tensorflow.Models.ObjectDetection.Protos.AutoAugmentImage subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.AutoAugmentImage(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.AutoaugmentImage) { + subBuilder.MergeFrom(AutoaugmentImage); + } + input.ReadMessage(subBuilder); + AutoaugmentImage = subBuilder; + break; + } + case 258: { + global::Tensorflow.Models.ObjectDetection.Protos.DropLabelProbabilistically subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.DropLabelProbabilistically(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.DropLabelProbabilistically) { + subBuilder.MergeFrom(DropLabelProbabilistically); + } + input.ReadMessage(subBuilder); + DropLabelProbabilistically = subBuilder; + break; + } + case 266: { + global::Tensorflow.Models.ObjectDetection.Protos.RemapLabels subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.RemapLabels(); + if (preprocessingStepCase_ == PreprocessingStepOneofCase.RemapLabels) { + subBuilder.MergeFrom(RemapLabels); + } + input.ReadMessage(subBuilder); + RemapLabels = subBuilder; + break; + } + } + } + } + + } + + /// + /// Normalizes pixel values in an image. + /// For every channel in the image, moves the pixel values from the range + /// [original_minval, original_maxval] to [target_minval, target_maxval]. 
+ /// + public sealed partial class NormalizeImage : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new NormalizeImage()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NormalizeImage() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NormalizeImage(NormalizeImage other) : this() { + originalMinval_ = other.originalMinval_; + originalMaxval_ = other.originalMaxval_; + targetMinval_ = other.targetMinval_; + targetMaxval_ = other.targetMaxval_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NormalizeImage Clone() { + return new NormalizeImage(this); + } + + /// Field number for the "original_minval" field. + public const int OriginalMinvalFieldNumber = 1; + private float originalMinval_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float OriginalMinval { + get { return originalMinval_; } + set { + originalMinval_ = value; + } + } + + /// Field number for the "original_maxval" field. + public const int OriginalMaxvalFieldNumber = 2; + private float originalMaxval_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float OriginalMaxval { + get { return originalMaxval_; } + set { + originalMaxval_ = value; + } + } + + /// Field number for the "target_minval" field. + public const int TargetMinvalFieldNumber = 3; + private float targetMinval_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float TargetMinval { + get { return targetMinval_; } + set { + targetMinval_ = value; + } + } + + /// Field number for the "target_maxval" field. 
+ public const int TargetMaxvalFieldNumber = 4; + private float targetMaxval_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float TargetMaxval { + get { return targetMaxval_; } + set { + targetMaxval_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as NormalizeImage); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(NormalizeImage other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(OriginalMinval, other.OriginalMinval)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(OriginalMaxval, other.OriginalMaxval)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(TargetMinval, other.TargetMinval)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(TargetMaxval, other.TargetMaxval)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (OriginalMinval != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(OriginalMinval); + if (OriginalMaxval != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(OriginalMaxval); + if (TargetMinval != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(TargetMinval); + if (TargetMaxval != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(TargetMaxval); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (OriginalMinval != 0F) { + output.WriteRawTag(13); + output.WriteFloat(OriginalMinval); + } + if (OriginalMaxval != 0F) { + output.WriteRawTag(21); + output.WriteFloat(OriginalMaxval); + } + if (TargetMinval != 0F) { + output.WriteRawTag(29); + output.WriteFloat(TargetMinval); + } + if (TargetMaxval != 0F) { + output.WriteRawTag(37); + output.WriteFloat(TargetMaxval); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (OriginalMinval != 0F) { + size += 1 + 4; + } + if (OriginalMaxval != 0F) { + size += 1 + 4; + } + if (TargetMinval != 0F) { + size += 1 + 4; + } + if (TargetMaxval != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(NormalizeImage other) { + if (other == null) { + return; + } + if (other.OriginalMinval != 0F) { + OriginalMinval = other.OriginalMinval; + } + if (other.OriginalMaxval != 0F) { + OriginalMaxval = other.OriginalMaxval; + } + if (other.TargetMinval != 0F) { + TargetMinval = other.TargetMinval; + } + if (other.TargetMaxval != 0F) { + TargetMaxval = other.TargetMaxval; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, 
other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + OriginalMinval = input.ReadFloat(); + break; + } + case 21: { + OriginalMaxval = input.ReadFloat(); + break; + } + case 29: { + TargetMinval = input.ReadFloat(); + break; + } + case 37: { + TargetMaxval = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Randomly horizontally flips the image and detections 50% of the time. + /// + public sealed partial class RandomHorizontalFlip : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomHorizontalFlip()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomHorizontalFlip() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomHorizontalFlip(RandomHorizontalFlip other) : this() { + keypointFlipPermutation_ = other.keypointFlipPermutation_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomHorizontalFlip Clone() { + return new RandomHorizontalFlip(this); + } + + /// Field number for the "keypoint_flip_permutation" field. + public const int KeypointFlipPermutationFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_keypointFlipPermutation_codec + = pb::FieldCodec.ForInt32(10); + private readonly pbc::RepeatedField keypointFlipPermutation_ = new pbc::RepeatedField(); + /// + /// Specifies a mapping from the original keypoint indices to horizontally + /// flipped indices. This is used in the event that keypoints are specified, + /// in which case when the image is horizontally flipped the keypoints will + /// need to be permuted. E.g. 
for keypoints representing left_eye, right_eye, + /// nose_tip, mouth, left_ear, right_ear (in that order), one might specify + /// the keypoint_flip_permutation below: + /// keypoint_flip_permutation: 1 + /// keypoint_flip_permutation: 0 + /// keypoint_flip_permutation: 2 + /// keypoint_flip_permutation: 3 + /// keypoint_flip_permutation: 5 + /// keypoint_flip_permutation: 4 + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField KeypointFlipPermutation { + get { return keypointFlipPermutation_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomHorizontalFlip); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomHorizontalFlip other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!keypointFlipPermutation_.Equals(other.keypointFlipPermutation_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= keypointFlipPermutation_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + keypointFlipPermutation_.WriteTo(output, _repeated_keypointFlipPermutation_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += keypointFlipPermutation_.CalculateSize(_repeated_keypointFlipPermutation_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomHorizontalFlip other) { + if (other == null) { + return; + } + keypointFlipPermutation_.Add(other.keypointFlipPermutation_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + keypointFlipPermutation_.AddEntriesFrom(input, _repeated_keypointFlipPermutation_codec); + break; + } + } + } + } + + } + + /// + /// Randomly vertically flips the image and detections 50% of the time. 
+ /// + public sealed partial class RandomVerticalFlip : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomVerticalFlip()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomVerticalFlip() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomVerticalFlip(RandomVerticalFlip other) : this() { + keypointFlipPermutation_ = other.keypointFlipPermutation_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomVerticalFlip Clone() { + return new RandomVerticalFlip(this); + } + + /// Field number for the "keypoint_flip_permutation" field. + public const int KeypointFlipPermutationFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_keypointFlipPermutation_codec + = pb::FieldCodec.ForInt32(10); + private readonly pbc::RepeatedField keypointFlipPermutation_ = new pbc::RepeatedField(); + /// + /// Specifies a mapping from the original keypoint indices to vertically + /// flipped indices. This is used in the event that keypoints are specified, + /// in which case when the image is vertically flipped the keypoints will + /// need to be permuted. E.g. 
for keypoints representing left_eye, right_eye, + /// nose_tip, mouth, left_ear, right_ear (in that order), one might specify + /// the keypoint_flip_permutation below: + /// keypoint_flip_permutation: 1 + /// keypoint_flip_permutation: 0 + /// keypoint_flip_permutation: 2 + /// keypoint_flip_permutation: 3 + /// keypoint_flip_permutation: 5 + /// keypoint_flip_permutation: 4 + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField KeypointFlipPermutation { + get { return keypointFlipPermutation_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomVerticalFlip); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomVerticalFlip other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!keypointFlipPermutation_.Equals(other.keypointFlipPermutation_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= keypointFlipPermutation_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + keypointFlipPermutation_.WriteTo(output, _repeated_keypointFlipPermutation_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += keypointFlipPermutation_.CalculateSize(_repeated_keypointFlipPermutation_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomVerticalFlip other) { + if (other == null) { + return; + } + keypointFlipPermutation_.Add(other.keypointFlipPermutation_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + keypointFlipPermutation_.AddEntriesFrom(input, _repeated_keypointFlipPermutation_codec); + break; + } + } + } + } + + } + + /// + /// Randomly rotates the image and detections by 90 degrees counter-clockwise + /// 50% of the time. 
+  /// </summary>
+  public sealed partial class RandomRotation90 : pb::IMessage<RandomRotation90> {
+    private static readonly pb::MessageParser<RandomRotation90> _parser = new pb::MessageParser<RandomRotation90>(() => new RandomRotation90());
+    private pb::UnknownFieldSet _unknownFields;
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public static pb::MessageParser<RandomRotation90> Parser { get { return _parser; } }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public static pbr::MessageDescriptor Descriptor {
+      get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[4]; }
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    pbr::MessageDescriptor pb::IMessage.Descriptor {
+      get { return Descriptor; }
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public RandomRotation90() {
+      OnConstruction();
+    }
+
+    partial void OnConstruction();
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public RandomRotation90(RandomRotation90 other) : this() {
+      _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public RandomRotation90 Clone() {
+      return new RandomRotation90(this);
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public override bool Equals(object other) {
+      return Equals(other as RandomRotation90);
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public bool Equals(RandomRotation90 other) {
+      if (ReferenceEquals(other, null)) {
+        return false;
+      }
+      if (ReferenceEquals(other, this)) {
+        return true;
+      }
+      return Equals(_unknownFields, other._unknownFields);
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public override int GetHashCode() {
+      int hash = 1;
+      if (_unknownFields != null) {
+        hash ^= _unknownFields.GetHashCode();
+      }
+      return hash;
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public override string ToString() {
+      return pb::JsonFormatter.ToDiagnosticString(this);
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public void WriteTo(pb::CodedOutputStream output) {
+      if (_unknownFields != null) {
+        _unknownFields.WriteTo(output);
+      }
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public int CalculateSize() {
+      int size = 0;
+      if (_unknownFields != null) {
+        size += _unknownFields.CalculateSize();
+      }
+      return size;
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public void MergeFrom(RandomRotation90 other) {
+      if (other == null) {
+        return;
+      }
+      _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public void MergeFrom(pb::CodedInputStream input) {
+      uint tag;
+      while ((tag = input.ReadTag()) != 0) {
+        switch(tag) {
+          default:
+            _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+            break;
+        }
+      }
+    }
+
+  }
+
+  /// <summary>
+  /// Randomly scales the values of all pixels in the image by some constant value
+  /// between [minval, maxval], then clip the value to a range between [0, 1.0].
+ /// + public sealed partial class RandomPixelValueScale : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomPixelValueScale()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomPixelValueScale() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomPixelValueScale(RandomPixelValueScale other) : this() { + minval_ = other.minval_; + maxval_ = other.maxval_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomPixelValueScale Clone() { + return new RandomPixelValueScale(this); + } + + /// Field number for the "minval" field. + public const int MinvalFieldNumber = 1; + private float minval_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Minval { + get { return minval_; } + set { + minval_ = value; + } + } + + /// Field number for the "maxval" field. + public const int MaxvalFieldNumber = 2; + private float maxval_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Maxval { + get { return maxval_; } + set { + maxval_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomPixelValueScale); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomPixelValueScale other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Minval, other.Minval)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Maxval, other.Maxval)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Minval != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Minval); + if (Maxval != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Maxval); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Minval != 0F) { + output.WriteRawTag(13); + output.WriteFloat(Minval); + } + if (Maxval != 0F) { + output.WriteRawTag(21); + output.WriteFloat(Maxval); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if 
(Minval != 0F) { + size += 1 + 4; + } + if (Maxval != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomPixelValueScale other) { + if (other == null) { + return; + } + if (other.Minval != 0F) { + Minval = other.Minval; + } + if (other.Maxval != 0F) { + Maxval = other.Maxval; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + Minval = input.ReadFloat(); + break; + } + case 21: { + Maxval = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Randomly enlarges or shrinks image (keeping aspect ratio). + /// + public sealed partial class RandomImageScale : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomImageScale()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[6]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomImageScale() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomImageScale(RandomImageScale other) : this() { + minScaleRatio_ = other.minScaleRatio_; + maxScaleRatio_ = other.maxScaleRatio_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomImageScale Clone() { + return new RandomImageScale(this); + } + + /// Field number for the "min_scale_ratio" field. + public const int MinScaleRatioFieldNumber = 1; + private float minScaleRatio_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinScaleRatio { + get { return minScaleRatio_; } + set { + minScaleRatio_ = value; + } + } + + /// Field number for the "max_scale_ratio" field. 
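For the `RandomPixelValueScale` message completed above, a small sketch of building a config and round-tripping it through the generated parser (the 0.9/1.1 range is illustrative, not a default):

```cs
using Google.Protobuf;
using Tensorflow.Models.ObjectDetection.Protos;

// Scale pixel values by a random factor in [0.9, 1.1].
var scale = new RandomPixelValueScale { Minval = 0.9f, Maxval = 1.1f };

// Proto3 scalar fields round-trip through the generated parser.
byte[] bytes = scale.ToByteArray();
var restored = RandomPixelValueScale.Parser.ParseFrom(bytes);
System.Console.WriteLine($"{restored.Minval}..{restored.Maxval}"); // 0.9..1.1
```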
+ public const int MaxScaleRatioFieldNumber = 2; + private float maxScaleRatio_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaxScaleRatio { + get { return maxScaleRatio_; } + set { + maxScaleRatio_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomImageScale); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomImageScale other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinScaleRatio, other.MinScaleRatio)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaxScaleRatio, other.MaxScaleRatio)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MinScaleRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinScaleRatio); + if (MaxScaleRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaxScaleRatio); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MinScaleRatio != 0F) { + output.WriteRawTag(13); + output.WriteFloat(MinScaleRatio); + } + if (MaxScaleRatio != 0F) { + output.WriteRawTag(21); + output.WriteFloat(MaxScaleRatio); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MinScaleRatio != 0F) { + size += 1 + 4; + } + if (MaxScaleRatio != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomImageScale other) { + if (other == null) { + return; + } + if (other.MinScaleRatio != 0F) { + MinScaleRatio = other.MinScaleRatio; + } + if (other.MaxScaleRatio != 0F) { + MaxScaleRatio = other.MaxScaleRatio; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + MinScaleRatio = input.ReadFloat(); + break; + } + case 21: { + MaxScaleRatio = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Randomly convert entire image to grey scale. 
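`RandomImageScale`, like the other generated messages here, follows proto3 semantics: fields left at their default of `0F` are skipped by `WriteTo` and not counted by `CalculateSize`. A quick sketch of that behavior, using illustrative values:

```cs
using Tensorflow.Models.ObjectDetection.Protos;

// Fields at their default (0F) are neither written nor counted.
var noOp   = new RandomImageScale();                         // both ratios default to 0
var scaled = new RandomImageScale { MaxScaleRatio = 1.5f };  // only one field set

System.Console.WriteLine(noOp.CalculateSize());   // 0
System.Console.WriteLine(scaled.CalculateSize()); // 5 (1 tag byte + 4-byte float)
```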
+ /// + public sealed partial class RandomRGBtoGray : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomRGBtoGray()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[7]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomRGBtoGray() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomRGBtoGray(RandomRGBtoGray other) : this() { + probability_ = other.probability_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomRGBtoGray Clone() { + return new RandomRGBtoGray(this); + } + + /// Field number for the "probability" field. + public const int ProbabilityFieldNumber = 1; + private float probability_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Probability { + get { return probability_; } + set { + probability_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomRGBtoGray); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomRGBtoGray other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Probability, other.Probability)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Probability != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Probability); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Probability != 0F) { + output.WriteRawTag(13); + output.WriteFloat(Probability); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Probability != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomRGBtoGray other) { + if (other == null) { + return; + } + if (other.Probability != 0F) { + Probability = other.Probability; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint 
tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + Probability = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Randomly changes image brightness by up to max_delta. Image outputs will be + /// saturated between 0 and 1. + /// + public sealed partial class RandomAdjustBrightness : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomAdjustBrightness()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[8]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomAdjustBrightness() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomAdjustBrightness(RandomAdjustBrightness other) : this() { + maxDelta_ = other.maxDelta_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomAdjustBrightness Clone() { + return new RandomAdjustBrightness(this); + } + + /// Field number for the "max_delta" field. + public const int MaxDeltaFieldNumber = 1; + private float maxDelta_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaxDelta { + get { return maxDelta_; } + set { + maxDelta_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomAdjustBrightness); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomAdjustBrightness other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaxDelta, other.MaxDelta)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MaxDelta != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaxDelta); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MaxDelta != 0F) { + output.WriteRawTag(13); + output.WriteFloat(MaxDelta); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MaxDelta != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + 
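`RandomRGBtoGray.MergeFrom` copies only fields that are non-default on the source message, as the generated guard shows, so merging never resets a value back to zero. A short illustration (variable names are arbitrary):

```cs
using Tensorflow.Models.ObjectDetection.Protos;

// Merging a message whose field is still at its default leaves the target untouched.
var target = new RandomRGBtoGray { Probability = 0.1f };
var source = new RandomRGBtoGray();           // Probability left at default 0F

target.MergeFrom(source);
System.Console.WriteLine(target.Probability); // still 0.1
```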
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomAdjustBrightness other) { + if (other == null) { + return; + } + if (other.MaxDelta != 0F) { + MaxDelta = other.MaxDelta; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + MaxDelta = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Randomly scales contract by a value between [min_delta, max_delta]. + /// + public sealed partial class RandomAdjustContrast : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomAdjustContrast()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[9]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomAdjustContrast() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomAdjustContrast(RandomAdjustContrast other) : this() { + minDelta_ = other.minDelta_; + maxDelta_ = other.maxDelta_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomAdjustContrast Clone() { + return new RandomAdjustContrast(this); + } + + /// Field number for the "min_delta" field. + public const int MinDeltaFieldNumber = 1; + private float minDelta_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinDelta { + get { return minDelta_; } + set { + minDelta_ = value; + } + } + + /// Field number for the "max_delta" field. 
+ public const int MaxDeltaFieldNumber = 2; + private float maxDelta_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaxDelta { + get { return maxDelta_; } + set { + maxDelta_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomAdjustContrast); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomAdjustContrast other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinDelta, other.MinDelta)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaxDelta, other.MaxDelta)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MinDelta != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinDelta); + if (MaxDelta != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaxDelta); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MinDelta != 0F) { + output.WriteRawTag(13); + output.WriteFloat(MinDelta); + } + if (MaxDelta != 0F) { + output.WriteRawTag(21); + output.WriteFloat(MaxDelta); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MinDelta != 0F) { + size += 1 + 4; + } + if (MaxDelta != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomAdjustContrast other) { + if (other == null) { + return; + } + if (other.MinDelta != 0F) { + MinDelta = other.MinDelta; + } + if (other.MaxDelta != 0F) { + MaxDelta = other.MaxDelta; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + MinDelta = input.ReadFloat(); + break; + } + case 21: { + MaxDelta = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Randomly alters hue by a value of up to max_delta. 
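The brightness and contrast messages share the same generated surface: `Clone` produces an independent copy and `Equals` compares the float fields bitwise via `ProtobufEqualityComparers`. A brief sketch with made-up deltas:

```cs
using Tensorflow.Models.ObjectDetection.Protos;

// Clone gives an independent copy; Equals compares float fields bitwise.
var contrast = new RandomAdjustContrast { MinDelta = 0.8f, MaxDelta = 1.25f };
var copy = contrast.Clone();

System.Console.WriteLine(contrast.Equals(copy)); // True
copy.MaxDelta = 1.5f;
System.Console.WriteLine(contrast.Equals(copy)); // False (original unchanged)
```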
+ /// + public sealed partial class RandomAdjustHue : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomAdjustHue()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[10]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomAdjustHue() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomAdjustHue(RandomAdjustHue other) : this() { + maxDelta_ = other.maxDelta_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomAdjustHue Clone() { + return new RandomAdjustHue(this); + } + + /// Field number for the "max_delta" field. + public const int MaxDeltaFieldNumber = 1; + private float maxDelta_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaxDelta { + get { return maxDelta_; } + set { + maxDelta_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomAdjustHue); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomAdjustHue other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaxDelta, other.MaxDelta)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MaxDelta != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaxDelta); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MaxDelta != 0F) { + output.WriteRawTag(13); + output.WriteFloat(MaxDelta); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MaxDelta != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomAdjustHue other) { + if (other == null) { + return; + } + if (other.MaxDelta != 0F) { + MaxDelta = other.MaxDelta; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + 
switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + MaxDelta = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Randomly changes saturation by a value between [min_delta, max_delta]. + /// + public sealed partial class RandomAdjustSaturation : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomAdjustSaturation()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[11]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomAdjustSaturation() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomAdjustSaturation(RandomAdjustSaturation other) : this() { + minDelta_ = other.minDelta_; + maxDelta_ = other.maxDelta_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomAdjustSaturation Clone() { + return new RandomAdjustSaturation(this); + } + + /// Field number for the "min_delta" field. + public const int MinDeltaFieldNumber = 1; + private float minDelta_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinDelta { + get { return minDelta_; } + set { + minDelta_ = value; + } + } + + /// Field number for the "max_delta" field. 
+ public const int MaxDeltaFieldNumber = 2; + private float maxDelta_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaxDelta { + get { return maxDelta_; } + set { + maxDelta_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomAdjustSaturation); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomAdjustSaturation other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinDelta, other.MinDelta)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaxDelta, other.MaxDelta)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MinDelta != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinDelta); + if (MaxDelta != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaxDelta); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MinDelta != 0F) { + output.WriteRawTag(13); + output.WriteFloat(MinDelta); + } + if (MaxDelta != 0F) { + output.WriteRawTag(21); + output.WriteFloat(MaxDelta); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MinDelta != 0F) { + size += 1 + 4; + } + if (MaxDelta != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomAdjustSaturation other) { + if (other == null) { + return; + } + if (other.MinDelta != 0F) { + MinDelta = other.MinDelta; + } + if (other.MaxDelta != 0F) { + MaxDelta = other.MaxDelta; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + MinDelta = input.ReadFloat(); + break; + } + case 21: { + MaxDelta = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Performs a random color distortion. color_orderings should either be 0 or 1. 
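`ToString` on these messages delegates to `JsonFormatter.ToDiagnosticString`, which is convenient when logging an augmentation config. For example (output shown approximately, values illustrative):

```cs
using Tensorflow.Models.ObjectDetection.Protos;

var saturation = new RandomAdjustSaturation { MinDelta = 0.8f, MaxDelta = 1.25f };

// Prints something like: { "minDelta": 0.8, "maxDelta": 1.25 }
System.Console.WriteLine(saturation);
```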
+ /// + public sealed partial class RandomDistortColor : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomDistortColor()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[12]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomDistortColor() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomDistortColor(RandomDistortColor other) : this() { + colorOrdering_ = other.colorOrdering_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomDistortColor Clone() { + return new RandomDistortColor(this); + } + + /// Field number for the "color_ordering" field. + public const int ColorOrderingFieldNumber = 1; + private int colorOrdering_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int ColorOrdering { + get { return colorOrdering_; } + set { + colorOrdering_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomDistortColor); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomDistortColor other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ColorOrdering != other.ColorOrdering) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (ColorOrdering != 0) hash ^= ColorOrdering.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (ColorOrdering != 0) { + output.WriteRawTag(8); + output.WriteInt32(ColorOrdering); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (ColorOrdering != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ColorOrdering); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomDistortColor other) { + if (other == null) { + return; + } + if (other.ColorOrdering != 0) { + ColorOrdering = other.ColorOrdering; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag 
= input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + ColorOrdering = input.ReadInt32(); + break; + } + } + } + } + + } + + /// + /// Randomly jitters corners of boxes in the image determined by ratio. + /// ie. If a box is [100, 200] and ratio is 0.02, the corners can move by [1, 4]. + /// + public sealed partial class RandomJitterBoxes : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomJitterBoxes()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[13]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomJitterBoxes() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomJitterBoxes(RandomJitterBoxes other) : this() { + ratio_ = other.ratio_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomJitterBoxes Clone() { + return new RandomJitterBoxes(this); + } + + /// Field number for the "ratio" field. + public const int RatioFieldNumber = 1; + private float ratio_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Ratio { + get { return ratio_; } + set { + ratio_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomJitterBoxes); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomJitterBoxes other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Ratio, other.Ratio)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Ratio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Ratio); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Ratio != 0F) { + output.WriteRawTag(13); + output.WriteFloat(Ratio); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Ratio != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomJitterBoxes other) { + if 
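`RandomDistortColor` stores an `int32` written as tag 8 (field 1, varint). The low-level round-trip below mirrors what the generated `WriteTo`/`MergeFrom` do internally; the buffer handling is a sketch, not prescribed usage:

```cs
using Google.Protobuf;
using Tensorflow.Models.ObjectDetection.Protos;

var distort = new RandomDistortColor { ColorOrdering = 1 };

// Write into an exactly-sized buffer, then read it back.
var buffer = new byte[distort.CalculateSize()];
var output = new CodedOutputStream(buffer);
distort.WriteTo(output);
output.Flush();

var reread = new RandomDistortColor();
reread.MergeFrom(new CodedInputStream(buffer));
System.Console.WriteLine(reread.ColorOrdering); // 1
```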
(other == null) { + return; + } + if (other.Ratio != 0F) { + Ratio = other.Ratio; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + Ratio = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Randomly crops the image and bounding boxes. + /// + public sealed partial class RandomCropImage : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomCropImage()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[14]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomCropImage() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomCropImage(RandomCropImage other) : this() { + minObjectCovered_ = other.minObjectCovered_; + minAspectRatio_ = other.minAspectRatio_; + maxAspectRatio_ = other.maxAspectRatio_; + minArea_ = other.minArea_; + maxArea_ = other.maxArea_; + overlapThresh_ = other.overlapThresh_; + clipBoxes_ = other.clipBoxes_; + randomCoef_ = other.randomCoef_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomCropImage Clone() { + return new RandomCropImage(this); + } + + /// Field number for the "min_object_covered" field. + public const int MinObjectCoveredFieldNumber = 1; + private float minObjectCovered_; + /// + /// Cropped image must cover at least one box by this fraction. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinObjectCovered { + get { return minObjectCovered_; } + set { + minObjectCovered_ = value; + } + } + + /// Field number for the "min_aspect_ratio" field. + public const int MinAspectRatioFieldNumber = 2; + private float minAspectRatio_; + /// + /// Aspect ratio bounds of cropped image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinAspectRatio { + get { return minAspectRatio_; } + set { + minAspectRatio_ = value; + } + } + + /// Field number for the "max_aspect_ratio" field. + public const int MaxAspectRatioFieldNumber = 3; + private float maxAspectRatio_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaxAspectRatio { + get { return maxAspectRatio_; } + set { + maxAspectRatio_ = value; + } + } + + /// Field number for the "min_area" field. + public const int MinAreaFieldNumber = 4; + private float minArea_; + /// + /// Allowed area ratio of cropped image to original image. 
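Unrecognized tags encountered by `RandomJitterBoxes.MergeFrom` land in `_unknownFields` and are re-emitted on write, so payloads produced by a newer proto definition survive a round-trip. A hand-assembled example (the second field is deliberately not part of this message and the exact bytes are illustrative):

```cs
using Google.Protobuf;
using Tensorflow.Models.ObjectDetection.Protos;

byte[] withExtraField =
{
    0x0D, 0xCD, 0xCC, 0x4C, 0x3D,  // field 1 (ratio), float 0.05
    0x10, 0x01                     // field 2 (unknown varint) = 1
};

var jitter = RandomJitterBoxes.Parser.ParseFrom(withExtraField);
System.Console.WriteLine(jitter.Ratio);                 // 0.05
System.Console.WriteLine(jitter.ToByteArray().Length);  // 7: the unknown field is preserved
```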
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinArea { + get { return minArea_; } + set { + minArea_ = value; + } + } + + /// Field number for the "max_area" field. + public const int MaxAreaFieldNumber = 5; + private float maxArea_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaxArea { + get { return maxArea_; } + set { + maxArea_ = value; + } + } + + /// Field number for the "overlap_thresh" field. + public const int OverlapThreshFieldNumber = 6; + private float overlapThresh_; + /// + /// Minimum overlap threshold of cropped boxes to keep in new image. If the + /// ratio between a cropped bounding box and the original is less than this + /// value, it is removed from the new image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float OverlapThresh { + get { return overlapThresh_; } + set { + overlapThresh_ = value; + } + } + + /// Field number for the "clip_boxes" field. + public const int ClipBoxesFieldNumber = 8; + private bool clipBoxes_; + /// + /// Whether to clip the boxes to the cropped image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ClipBoxes { + get { return clipBoxes_; } + set { + clipBoxes_ = value; + } + } + + /// Field number for the "random_coef" field. + public const int RandomCoefFieldNumber = 7; + private float randomCoef_; + /// + /// Probability of keeping the original image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float RandomCoef { + get { return randomCoef_; } + set { + randomCoef_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomCropImage); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomCropImage other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinObjectCovered, other.MinObjectCovered)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinAspectRatio, other.MinAspectRatio)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaxAspectRatio, other.MaxAspectRatio)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinArea, other.MinArea)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaxArea, other.MaxArea)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(OverlapThresh, other.OverlapThresh)) return false; + if (ClipBoxes != other.ClipBoxes) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(RandomCoef, other.RandomCoef)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MinObjectCovered != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinObjectCovered); + if (MinAspectRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinAspectRatio); + if (MaxAspectRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaxAspectRatio); + if (MinArea != 0F) hash ^= 
pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinArea); + if (MaxArea != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaxArea); + if (OverlapThresh != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(OverlapThresh); + if (ClipBoxes != false) hash ^= ClipBoxes.GetHashCode(); + if (RandomCoef != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(RandomCoef); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MinObjectCovered != 0F) { + output.WriteRawTag(13); + output.WriteFloat(MinObjectCovered); + } + if (MinAspectRatio != 0F) { + output.WriteRawTag(21); + output.WriteFloat(MinAspectRatio); + } + if (MaxAspectRatio != 0F) { + output.WriteRawTag(29); + output.WriteFloat(MaxAspectRatio); + } + if (MinArea != 0F) { + output.WriteRawTag(37); + output.WriteFloat(MinArea); + } + if (MaxArea != 0F) { + output.WriteRawTag(45); + output.WriteFloat(MaxArea); + } + if (OverlapThresh != 0F) { + output.WriteRawTag(53); + output.WriteFloat(OverlapThresh); + } + if (RandomCoef != 0F) { + output.WriteRawTag(61); + output.WriteFloat(RandomCoef); + } + if (ClipBoxes != false) { + output.WriteRawTag(64); + output.WriteBool(ClipBoxes); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MinObjectCovered != 0F) { + size += 1 + 4; + } + if (MinAspectRatio != 0F) { + size += 1 + 4; + } + if (MaxAspectRatio != 0F) { + size += 1 + 4; + } + if (MinArea != 0F) { + size += 1 + 4; + } + if (MaxArea != 0F) { + size += 1 + 4; + } + if (OverlapThresh != 0F) { + size += 1 + 4; + } + if (ClipBoxes != false) { + size += 1 + 1; + } + if (RandomCoef != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomCropImage other) { + if (other == null) { + return; + } + if (other.MinObjectCovered != 0F) { + MinObjectCovered = other.MinObjectCovered; + } + if (other.MinAspectRatio != 0F) { + MinAspectRatio = other.MinAspectRatio; + } + if (other.MaxAspectRatio != 0F) { + MaxAspectRatio = other.MaxAspectRatio; + } + if (other.MinArea != 0F) { + MinArea = other.MinArea; + } + if (other.MaxArea != 0F) { + MaxArea = other.MaxArea; + } + if (other.OverlapThresh != 0F) { + OverlapThresh = other.OverlapThresh; + } + if (other.ClipBoxes != false) { + ClipBoxes = other.ClipBoxes; + } + if (other.RandomCoef != 0F) { + RandomCoef = other.RandomCoef; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + MinObjectCovered = input.ReadFloat(); + break; + } + case 21: { + MinAspectRatio = input.ReadFloat(); + break; + } + case 29: { + MaxAspectRatio = 
input.ReadFloat(); + break; + } + case 37: { + MinArea = input.ReadFloat(); + break; + } + case 45: { + MaxArea = input.ReadFloat(); + break; + } + case 53: { + OverlapThresh = input.ReadFloat(); + break; + } + case 61: { + RandomCoef = input.ReadFloat(); + break; + } + case 64: { + ClipBoxes = input.ReadBool(); + break; + } + } + } + } + + } + + /// + /// Randomly adds padding to the image. + /// + public sealed partial class RandomPadImage : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomPadImage()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[15]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomPadImage() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomPadImage(RandomPadImage other) : this() { + minImageHeight_ = other.minImageHeight_; + minImageWidth_ = other.minImageWidth_; + maxImageHeight_ = other.maxImageHeight_; + maxImageWidth_ = other.maxImageWidth_; + padColor_ = other.padColor_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomPadImage Clone() { + return new RandomPadImage(this); + } + + /// Field number for the "min_image_height" field. + public const int MinImageHeightFieldNumber = 1; + private int minImageHeight_; + /// + /// Minimum dimensions for padded image. If unset, will use original image + /// dimension as a lower bound. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MinImageHeight { + get { return minImageHeight_; } + set { + minImageHeight_ = value; + } + } + + /// Field number for the "min_image_width" field. + public const int MinImageWidthFieldNumber = 2; + private int minImageWidth_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MinImageWidth { + get { return minImageWidth_; } + set { + minImageWidth_ = value; + } + } + + /// Field number for the "max_image_height" field. + public const int MaxImageHeightFieldNumber = 3; + private int maxImageHeight_; + /// + /// Maximum dimensions for padded image. If unset, will use double the original + /// image dimension as a lower bound. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxImageHeight { + get { return maxImageHeight_; } + set { + maxImageHeight_ = value; + } + } + + /// Field number for the "max_image_width" field. + public const int MaxImageWidthFieldNumber = 4; + private int maxImageWidth_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxImageWidth { + get { return maxImageWidth_; } + set { + maxImageWidth_ = value; + } + } + + /// Field number for the "pad_color" field. 
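`RandomCropImage` bundles all the crop constraints into one message. A possible configuration might look like the following; the values are illustrative, not library defaults:

```cs
using Tensorflow.Models.ObjectDetection.Protos;

var crop = new RandomCropImage
{
    MinObjectCovered = 1.0f,   // every kept box must be fully covered by the crop
    MinAspectRatio   = 0.75f,
    MaxAspectRatio   = 1.33f,
    MinArea          = 0.1f,
    MaxArea          = 1.0f,
    OverlapThresh    = 0.3f,   // drop boxes whose overlap with the crop falls below this
    ClipBoxes        = true,
    RandomCoef       = 0.25f   // 25% chance of keeping the original image
};

// Seven non-zero floats (5 bytes each) plus one bool (2 bytes).
System.Console.WriteLine(crop.CalculateSize()); // 37
```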
+ public const int PadColorFieldNumber = 5; + private static readonly pb::FieldCodec _repeated_padColor_codec + = pb::FieldCodec.ForFloat(42); + private readonly pbc::RepeatedField padColor_ = new pbc::RepeatedField(); + /// + /// Color of the padding. If unset, will pad using average color of the input + /// image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField PadColor { + get { return padColor_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomPadImage); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomPadImage other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (MinImageHeight != other.MinImageHeight) return false; + if (MinImageWidth != other.MinImageWidth) return false; + if (MaxImageHeight != other.MaxImageHeight) return false; + if (MaxImageWidth != other.MaxImageWidth) return false; + if(!padColor_.Equals(other.padColor_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MinImageHeight != 0) hash ^= MinImageHeight.GetHashCode(); + if (MinImageWidth != 0) hash ^= MinImageWidth.GetHashCode(); + if (MaxImageHeight != 0) hash ^= MaxImageHeight.GetHashCode(); + if (MaxImageWidth != 0) hash ^= MaxImageWidth.GetHashCode(); + hash ^= padColor_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MinImageHeight != 0) { + output.WriteRawTag(8); + output.WriteInt32(MinImageHeight); + } + if (MinImageWidth != 0) { + output.WriteRawTag(16); + output.WriteInt32(MinImageWidth); + } + if (MaxImageHeight != 0) { + output.WriteRawTag(24); + output.WriteInt32(MaxImageHeight); + } + if (MaxImageWidth != 0) { + output.WriteRawTag(32); + output.WriteInt32(MaxImageWidth); + } + padColor_.WriteTo(output, _repeated_padColor_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MinImageHeight != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MinImageHeight); + } + if (MinImageWidth != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MinImageWidth); + } + if (MaxImageHeight != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxImageHeight); + } + if (MaxImageWidth != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxImageWidth); + } + size += padColor_.CalculateSize(_repeated_padColor_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomPadImage other) { + if (other == null) { + return; + } + if (other.MinImageHeight != 0) { + MinImageHeight = other.MinImageHeight; + } + if (other.MinImageWidth != 0) { + MinImageWidth = other.MinImageWidth; + } + if (other.MaxImageHeight != 0) { + MaxImageHeight = other.MaxImageHeight; + } + if (other.MaxImageWidth != 0) 
{ + MaxImageWidth = other.MaxImageWidth; + } + padColor_.Add(other.padColor_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + MinImageHeight = input.ReadInt32(); + break; + } + case 16: { + MinImageWidth = input.ReadInt32(); + break; + } + case 24: { + MaxImageHeight = input.ReadInt32(); + break; + } + case 32: { + MaxImageWidth = input.ReadInt32(); + break; + } + case 42: + case 45: { + padColor_.AddEntriesFrom(input, _repeated_padColor_codec); + break; + } + } + } + } + + } + + /// + /// Randomly adds a padding of size [0, max_height_padding), [0, max_width_padding). + /// + public sealed partial class RandomAbsolutePadImage : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomAbsolutePadImage()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[16]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomAbsolutePadImage() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomAbsolutePadImage(RandomAbsolutePadImage other) : this() { + maxHeightPadding_ = other.maxHeightPadding_; + maxWidthPadding_ = other.maxWidthPadding_; + padColor_ = other.padColor_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomAbsolutePadImage Clone() { + return new RandomAbsolutePadImage(this); + } + + /// Field number for the "max_height_padding" field. + public const int MaxHeightPaddingFieldNumber = 1; + private int maxHeightPadding_; + /// + /// Height will be padded uniformly at random from [0, max_height_padding). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxHeightPadding { + get { return maxHeightPadding_; } + set { + maxHeightPadding_ = value; + } + } + + /// Field number for the "max_width_padding" field. + public const int MaxWidthPaddingFieldNumber = 2; + private int maxWidthPadding_; + /// + /// Width will be padded uniformly at random from [0, max_width_padding). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxWidthPadding { + get { return maxWidthPadding_; } + set { + maxWidthPadding_ = value; + } + } + + /// Field number for the "pad_color" field. + public const int PadColorFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_padColor_codec + = pb::FieldCodec.ForFloat(26); + private readonly pbc::RepeatedField padColor_ = new pbc::RepeatedField(); + /// + /// Color of the padding. If unset, will pad using average color of the input + /// image. 
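`RandomPadImage.PadColor` is a generated `RepeatedField<float>`, so collection-initializer syntax works when building a config. For instance (dimensions and the mid-grey padding are chosen arbitrarily):

```cs
using Tensorflow.Models.ObjectDetection.Protos;

var pad = new RandomPadImage
{
    MinImageHeight = 300,
    MinImageWidth  = 300,
    MaxImageHeight = 600,
    MaxImageWidth  = 600,
    PadColor = { 0.5f, 0.5f, 0.5f }  // RGB padding color
};
System.Console.WriteLine(pad.PadColor.Count); // 3
```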
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField PadColor { + get { return padColor_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomAbsolutePadImage); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomAbsolutePadImage other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (MaxHeightPadding != other.MaxHeightPadding) return false; + if (MaxWidthPadding != other.MaxWidthPadding) return false; + if(!padColor_.Equals(other.padColor_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MaxHeightPadding != 0) hash ^= MaxHeightPadding.GetHashCode(); + if (MaxWidthPadding != 0) hash ^= MaxWidthPadding.GetHashCode(); + hash ^= padColor_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MaxHeightPadding != 0) { + output.WriteRawTag(8); + output.WriteInt32(MaxHeightPadding); + } + if (MaxWidthPadding != 0) { + output.WriteRawTag(16); + output.WriteInt32(MaxWidthPadding); + } + padColor_.WriteTo(output, _repeated_padColor_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MaxHeightPadding != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxHeightPadding); + } + if (MaxWidthPadding != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxWidthPadding); + } + size += padColor_.CalculateSize(_repeated_padColor_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomAbsolutePadImage other) { + if (other == null) { + return; + } + if (other.MaxHeightPadding != 0) { + MaxHeightPadding = other.MaxHeightPadding; + } + if (other.MaxWidthPadding != 0) { + MaxWidthPadding = other.MaxWidthPadding; + } + padColor_.Add(other.padColor_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + MaxHeightPadding = input.ReadInt32(); + break; + } + case 16: { + MaxWidthPadding = input.ReadInt32(); + break; + } + case 26: + case 29: { + padColor_.AddEntriesFrom(input, _repeated_padColor_codec); + break; + } + } + } + } + + } + + /// + /// Randomly crops an image followed by a random pad. 
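The repeated `pad_color` floats of `RandomAbsolutePadImage` are written packed, one length-delimited record rather than one tag per element, which the `FieldCodec.ForFloat(26)` above reflects. A size sketch with illustrative values:

```cs
using Google.Protobuf;
using Tensorflow.Models.ObjectDetection.Protos;

var pad = new RandomAbsolutePadImage
{
    MaxHeightPadding = 32,
    MaxWidthPadding  = 32,
    PadColor = { 0f, 0f, 0f }  // pad with black
};

// Two varint fields (2 bytes each) + packed floats (tag + length + 3 * 4 bytes = 14) = 18.
System.Console.WriteLine(pad.ToByteArray().Length); // 18
```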
+ /// + public sealed partial class RandomCropPadImage : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomCropPadImage()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[17]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomCropPadImage() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomCropPadImage(RandomCropPadImage other) : this() { + minObjectCovered_ = other.minObjectCovered_; + minAspectRatio_ = other.minAspectRatio_; + maxAspectRatio_ = other.maxAspectRatio_; + minArea_ = other.minArea_; + maxArea_ = other.maxArea_; + overlapThresh_ = other.overlapThresh_; + clipBoxes_ = other.clipBoxes_; + randomCoef_ = other.randomCoef_; + minPaddedSizeRatio_ = other.minPaddedSizeRatio_.Clone(); + maxPaddedSizeRatio_ = other.maxPaddedSizeRatio_.Clone(); + padColor_ = other.padColor_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomCropPadImage Clone() { + return new RandomCropPadImage(this); + } + + /// Field number for the "min_object_covered" field. + public const int MinObjectCoveredFieldNumber = 1; + private float minObjectCovered_; + /// + /// Cropping operation must cover at least one box by this fraction. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinObjectCovered { + get { return minObjectCovered_; } + set { + minObjectCovered_ = value; + } + } + + /// Field number for the "min_aspect_ratio" field. + public const int MinAspectRatioFieldNumber = 2; + private float minAspectRatio_; + /// + /// Aspect ratio bounds of image after cropping operation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinAspectRatio { + get { return minAspectRatio_; } + set { + minAspectRatio_ = value; + } + } + + /// Field number for the "max_aspect_ratio" field. + public const int MaxAspectRatioFieldNumber = 3; + private float maxAspectRatio_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaxAspectRatio { + get { return maxAspectRatio_; } + set { + maxAspectRatio_ = value; + } + } + + /// Field number for the "min_area" field. + public const int MinAreaFieldNumber = 4; + private float minArea_; + /// + /// Allowed area ratio of image after cropping operation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinArea { + get { return minArea_; } + set { + minArea_ = value; + } + } + + /// Field number for the "max_area" field. + public const int MaxAreaFieldNumber = 5; + private float maxArea_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaxArea { + get { return maxArea_; } + set { + maxArea_ = value; + } + } + + /// Field number for the "overlap_thresh" field. 
+ public const int OverlapThreshFieldNumber = 6; + private float overlapThresh_; + /// + /// Minimum overlap threshold of cropped boxes to keep in new image. If the + /// ratio between a cropped bounding box and the original is less than this + /// value, it is removed from the new image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float OverlapThresh { + get { return overlapThresh_; } + set { + overlapThresh_ = value; + } + } + + /// Field number for the "clip_boxes" field. + public const int ClipBoxesFieldNumber = 11; + private bool clipBoxes_; + /// + /// Whether to clip the boxes to the cropped image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ClipBoxes { + get { return clipBoxes_; } + set { + clipBoxes_ = value; + } + } + + /// Field number for the "random_coef" field. + public const int RandomCoefFieldNumber = 7; + private float randomCoef_; + /// + /// Probability of keeping the original image during the crop operation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float RandomCoef { + get { return randomCoef_; } + set { + randomCoef_ = value; + } + } + + /// Field number for the "min_padded_size_ratio" field. + public const int MinPaddedSizeRatioFieldNumber = 8; + private static readonly pb::FieldCodec _repeated_minPaddedSizeRatio_codec + = pb::FieldCodec.ForFloat(66); + private readonly pbc::RepeatedField minPaddedSizeRatio_ = new pbc::RepeatedField(); + /// + /// Maximum dimensions for padded image. If unset, will use double the original + /// image dimension as a lower bound. Both of the following fields should be + /// length 2. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField MinPaddedSizeRatio { + get { return minPaddedSizeRatio_; } + } + + /// Field number for the "max_padded_size_ratio" field. + public const int MaxPaddedSizeRatioFieldNumber = 9; + private static readonly pb::FieldCodec _repeated_maxPaddedSizeRatio_codec + = pb::FieldCodec.ForFloat(74); + private readonly pbc::RepeatedField maxPaddedSizeRatio_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField MaxPaddedSizeRatio { + get { return maxPaddedSizeRatio_; } + } + + /// Field number for the "pad_color" field. + public const int PadColorFieldNumber = 10; + private static readonly pb::FieldCodec _repeated_padColor_codec + = pb::FieldCodec.ForFloat(82); + private readonly pbc::RepeatedField padColor_ = new pbc::RepeatedField(); + /// + /// Color of the padding. If unset, will pad using average color of the input + /// image. This field should be of length 3. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField PadColor { + get { return padColor_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomCropPadImage); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomCropPadImage other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinObjectCovered, other.MinObjectCovered)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinAspectRatio, other.MinAspectRatio)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaxAspectRatio, other.MaxAspectRatio)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinArea, other.MinArea)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaxArea, other.MaxArea)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(OverlapThresh, other.OverlapThresh)) return false; + if (ClipBoxes != other.ClipBoxes) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(RandomCoef, other.RandomCoef)) return false; + if(!minPaddedSizeRatio_.Equals(other.minPaddedSizeRatio_)) return false; + if(!maxPaddedSizeRatio_.Equals(other.maxPaddedSizeRatio_)) return false; + if(!padColor_.Equals(other.padColor_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MinObjectCovered != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinObjectCovered); + if (MinAspectRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinAspectRatio); + if (MaxAspectRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaxAspectRatio); + if (MinArea != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinArea); + if (MaxArea != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaxArea); + if (OverlapThresh != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(OverlapThresh); + if (ClipBoxes != false) hash ^= ClipBoxes.GetHashCode(); + if (RandomCoef != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(RandomCoef); + hash ^= minPaddedSizeRatio_.GetHashCode(); + hash ^= maxPaddedSizeRatio_.GetHashCode(); + hash ^= padColor_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MinObjectCovered != 0F) { + output.WriteRawTag(13); + output.WriteFloat(MinObjectCovered); + } + if (MinAspectRatio != 0F) { + output.WriteRawTag(21); + output.WriteFloat(MinAspectRatio); + } + if (MaxAspectRatio != 0F) { + output.WriteRawTag(29); + output.WriteFloat(MaxAspectRatio); + } + if (MinArea != 0F) { + 
output.WriteRawTag(37); + output.WriteFloat(MinArea); + } + if (MaxArea != 0F) { + output.WriteRawTag(45); + output.WriteFloat(MaxArea); + } + if (OverlapThresh != 0F) { + output.WriteRawTag(53); + output.WriteFloat(OverlapThresh); + } + if (RandomCoef != 0F) { + output.WriteRawTag(61); + output.WriteFloat(RandomCoef); + } + minPaddedSizeRatio_.WriteTo(output, _repeated_minPaddedSizeRatio_codec); + maxPaddedSizeRatio_.WriteTo(output, _repeated_maxPaddedSizeRatio_codec); + padColor_.WriteTo(output, _repeated_padColor_codec); + if (ClipBoxes != false) { + output.WriteRawTag(88); + output.WriteBool(ClipBoxes); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MinObjectCovered != 0F) { + size += 1 + 4; + } + if (MinAspectRatio != 0F) { + size += 1 + 4; + } + if (MaxAspectRatio != 0F) { + size += 1 + 4; + } + if (MinArea != 0F) { + size += 1 + 4; + } + if (MaxArea != 0F) { + size += 1 + 4; + } + if (OverlapThresh != 0F) { + size += 1 + 4; + } + if (ClipBoxes != false) { + size += 1 + 1; + } + if (RandomCoef != 0F) { + size += 1 + 4; + } + size += minPaddedSizeRatio_.CalculateSize(_repeated_minPaddedSizeRatio_codec); + size += maxPaddedSizeRatio_.CalculateSize(_repeated_maxPaddedSizeRatio_codec); + size += padColor_.CalculateSize(_repeated_padColor_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomCropPadImage other) { + if (other == null) { + return; + } + if (other.MinObjectCovered != 0F) { + MinObjectCovered = other.MinObjectCovered; + } + if (other.MinAspectRatio != 0F) { + MinAspectRatio = other.MinAspectRatio; + } + if (other.MaxAspectRatio != 0F) { + MaxAspectRatio = other.MaxAspectRatio; + } + if (other.MinArea != 0F) { + MinArea = other.MinArea; + } + if (other.MaxArea != 0F) { + MaxArea = other.MaxArea; + } + if (other.OverlapThresh != 0F) { + OverlapThresh = other.OverlapThresh; + } + if (other.ClipBoxes != false) { + ClipBoxes = other.ClipBoxes; + } + if (other.RandomCoef != 0F) { + RandomCoef = other.RandomCoef; + } + minPaddedSizeRatio_.Add(other.minPaddedSizeRatio_); + maxPaddedSizeRatio_.Add(other.maxPaddedSizeRatio_); + padColor_.Add(other.padColor_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + MinObjectCovered = input.ReadFloat(); + break; + } + case 21: { + MinAspectRatio = input.ReadFloat(); + break; + } + case 29: { + MaxAspectRatio = input.ReadFloat(); + break; + } + case 37: { + MinArea = input.ReadFloat(); + break; + } + case 45: { + MaxArea = input.ReadFloat(); + break; + } + case 53: { + OverlapThresh = input.ReadFloat(); + break; + } + case 61: { + RandomCoef = input.ReadFloat(); + break; + } + case 66: + case 69: { + minPaddedSizeRatio_.AddEntriesFrom(input, _repeated_minPaddedSizeRatio_codec); + break; + } + case 74: + case 77: { + maxPaddedSizeRatio_.AddEntriesFrom(input, _repeated_maxPaddedSizeRatio_codec); + break; + } + case 82: + case 85: { + padColor_.AddEntriesFrom(input, _repeated_padColor_codec); + break; + } + case 88: { + ClipBoxes 
= input.ReadBool(); + break; + } + } + } + + } + + /// + /// Randomly crops an image to a given aspect ratio. + /// + public sealed partial class RandomCropToAspectRatio : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomCropToAspectRatio()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[18]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomCropToAspectRatio() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomCropToAspectRatio(RandomCropToAspectRatio other) : this() { + aspectRatio_ = other.aspectRatio_; + overlapThresh_ = other.overlapThresh_; + clipBoxes_ = other.clipBoxes_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomCropToAspectRatio Clone() { + return new RandomCropToAspectRatio(this); + } + + /// Field number for the "aspect_ratio" field. + public const int AspectRatioFieldNumber = 1; + private float aspectRatio_; + /// + /// Aspect ratio. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float AspectRatio { + get { return aspectRatio_; } + set { + aspectRatio_ = value; + } + } + + /// Field number for the "overlap_thresh" field. + public const int OverlapThreshFieldNumber = 2; + private float overlapThresh_; + /// + /// Minimum overlap threshold of cropped boxes to keep in new image. If the + /// ratio between a cropped bounding box and the original is less than this + /// value, it is removed from the new image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float OverlapThresh { + get { return overlapThresh_; } + set { + overlapThresh_ = value; + } + } + + /// Field number for the "clip_boxes" field. + public const int ClipBoxesFieldNumber = 3; + private bool clipBoxes_; + /// + /// Whether to clip the boxes to the cropped image.
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ClipBoxes { + get { return clipBoxes_; } + set { + clipBoxes_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomCropToAspectRatio); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomCropToAspectRatio other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(AspectRatio, other.AspectRatio)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(OverlapThresh, other.OverlapThresh)) return false; + if (ClipBoxes != other.ClipBoxes) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (AspectRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(AspectRatio); + if (OverlapThresh != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(OverlapThresh); + if (ClipBoxes != false) hash ^= ClipBoxes.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (AspectRatio != 0F) { + output.WriteRawTag(13); + output.WriteFloat(AspectRatio); + } + if (OverlapThresh != 0F) { + output.WriteRawTag(21); + output.WriteFloat(OverlapThresh); + } + if (ClipBoxes != false) { + output.WriteRawTag(24); + output.WriteBool(ClipBoxes); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (AspectRatio != 0F) { + size += 1 + 4; + } + if (OverlapThresh != 0F) { + size += 1 + 4; + } + if (ClipBoxes != false) { + size += 1 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomCropToAspectRatio other) { + if (other == null) { + return; + } + if (other.AspectRatio != 0F) { + AspectRatio = other.AspectRatio; + } + if (other.OverlapThresh != 0F) { + OverlapThresh = other.OverlapThresh; + } + if (other.ClipBoxes != false) { + ClipBoxes = other.ClipBoxes; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + AspectRatio = input.ReadFloat(); + break; + } + case 21: { + OverlapThresh = input.ReadFloat(); + break; + } + case 24: { + ClipBoxes = input.ReadBool(); + break; + } + } + } + } + + } + + /// + /// Randomly adds black square patches to an image. 
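// Illustrative sketch (hypothetical helper, not from this patch): MergeFrom on the generated
// messages copies only fields that differ from their defaults, so a partially populated message
// can act as an override layer on top of a baseline configuration.
public static class RandomCropToAspectRatioMergeExample {
    public static void Run() {
        var baseline = new Tensorflow.Models.ObjectDetection.Protos.RandomCropToAspectRatio {
            AspectRatio = 0.85f,
            OverlapThresh = 0.3f,
            ClipBoxes = true
        };
        var overrides = new Tensorflow.Models.ObjectDetection.Protos.RandomCropToAspectRatio {
            OverlapThresh = 0.5f   // AspectRatio stays at its default (0) and is therefore not merged
        };
        baseline.MergeFrom(overrides);
        // baseline.AspectRatio == 0.85f, baseline.OverlapThresh == 0.5f, baseline.ClipBoxes == true
    }
}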
+ /// + public sealed partial class RandomBlackPatches : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomBlackPatches()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[19]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomBlackPatches() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomBlackPatches(RandomBlackPatches other) : this() { + maxBlackPatches_ = other.maxBlackPatches_; + probability_ = other.probability_; + sizeToImageRatio_ = other.sizeToImageRatio_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomBlackPatches Clone() { + return new RandomBlackPatches(this); + } + + /// Field number for the "max_black_patches" field. + public const int MaxBlackPatchesFieldNumber = 1; + private int maxBlackPatches_; + /// + /// The maximum number of black patches to add. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxBlackPatches { + get { return maxBlackPatches_; } + set { + maxBlackPatches_ = value; + } + } + + /// Field number for the "probability" field. + public const int ProbabilityFieldNumber = 2; + private float probability_; + /// + /// The probability of a black patch being added to an image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Probability { + get { return probability_; } + set { + probability_ = value; + } + } + + /// Field number for the "size_to_image_ratio" field. + public const int SizeToImageRatioFieldNumber = 3; + private float sizeToImageRatio_; + /// + /// Ratio between the dimension of the black patch to the minimum dimension of + /// the image (patch_width = patch_height = min(image_height, image_width)). 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float SizeToImageRatio { + get { return sizeToImageRatio_; } + set { + sizeToImageRatio_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomBlackPatches); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomBlackPatches other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (MaxBlackPatches != other.MaxBlackPatches) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Probability, other.Probability)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(SizeToImageRatio, other.SizeToImageRatio)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MaxBlackPatches != 0) hash ^= MaxBlackPatches.GetHashCode(); + if (Probability != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Probability); + if (SizeToImageRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(SizeToImageRatio); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MaxBlackPatches != 0) { + output.WriteRawTag(8); + output.WriteInt32(MaxBlackPatches); + } + if (Probability != 0F) { + output.WriteRawTag(21); + output.WriteFloat(Probability); + } + if (SizeToImageRatio != 0F) { + output.WriteRawTag(29); + output.WriteFloat(SizeToImageRatio); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MaxBlackPatches != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxBlackPatches); + } + if (Probability != 0F) { + size += 1 + 4; + } + if (SizeToImageRatio != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomBlackPatches other) { + if (other == null) { + return; + } + if (other.MaxBlackPatches != 0) { + MaxBlackPatches = other.MaxBlackPatches; + } + if (other.Probability != 0F) { + Probability = other.Probability; + } + if (other.SizeToImageRatio != 0F) { + SizeToImageRatio = other.SizeToImageRatio; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + MaxBlackPatches = input.ReadInt32(); + break; + } + case 21: { + Probability = input.ReadFloat(); + break; + } + case 29: { + SizeToImageRatio = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Randomly resizes the image up to 
[target_height, target_width]. + /// + public sealed partial class RandomResizeMethod : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomResizeMethod()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[20]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomResizeMethod() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomResizeMethod(RandomResizeMethod other) : this() { + targetHeight_ = other.targetHeight_; + targetWidth_ = other.targetWidth_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomResizeMethod Clone() { + return new RandomResizeMethod(this); + } + + /// Field number for the "target_height" field. + public const int TargetHeightFieldNumber = 1; + private int targetHeight_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int TargetHeight { + get { return targetHeight_; } + set { + targetHeight_ = value; + } + } + + /// Field number for the "target_width" field. + public const int TargetWidthFieldNumber = 2; + private int targetWidth_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int TargetWidth { + get { return targetWidth_; } + set { + targetWidth_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomResizeMethod); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomResizeMethod other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (TargetHeight != other.TargetHeight) return false; + if (TargetWidth != other.TargetWidth) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (TargetHeight != 0) hash ^= TargetHeight.GetHashCode(); + if (TargetWidth != 0) hash ^= TargetWidth.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (TargetHeight != 0) { + output.WriteRawTag(8); + output.WriteInt32(TargetHeight); + } + if (TargetWidth != 0) { + output.WriteRawTag(16); + output.WriteInt32(TargetWidth); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (TargetHeight != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(TargetHeight); + } + if (TargetWidth != 
0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(TargetWidth); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomResizeMethod other) { + if (other == null) { + return; + } + if (other.TargetHeight != 0) { + TargetHeight = other.TargetHeight; + } + if (other.TargetWidth != 0) { + TargetWidth = other.TargetWidth; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + TargetHeight = input.ReadInt32(); + break; + } + case 16: { + TargetWidth = input.ReadInt32(); + break; + } + } + } + } + + } + + /// + /// Converts the RGB image to a grayscale image. This also converts the image + /// depth from 3 to 1, unlike RandomRGBtoGray which does not change the image + /// depth. + /// + public sealed partial class RGBtoGray : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RGBtoGray()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[21]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RGBtoGray() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RGBtoGray(RGBtoGray other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RGBtoGray Clone() { + return new RGBtoGray(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RGBtoGray); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RGBtoGray other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + 
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RGBtoGray other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + } + + } + + /// + /// Scales boxes from normalized coordinates to pixel coordinates. + /// + public sealed partial class ScaleBoxesToPixelCoordinates : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ScaleBoxesToPixelCoordinates()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[22]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ScaleBoxesToPixelCoordinates() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ScaleBoxesToPixelCoordinates(ScaleBoxesToPixelCoordinates other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ScaleBoxesToPixelCoordinates Clone() { + return new ScaleBoxesToPixelCoordinates(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ScaleBoxesToPixelCoordinates); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ScaleBoxesToPixelCoordinates other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ScaleBoxesToPixelCoordinates other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 
0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + } + + } + + /// + /// Resizes images to [new_height, new_width]. + /// + public sealed partial class ResizeImage : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ResizeImage()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[23]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ResizeImage() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ResizeImage(ResizeImage other) : this() { + newHeight_ = other.newHeight_; + newWidth_ = other.newWidth_; + method_ = other.method_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ResizeImage Clone() { + return new ResizeImage(this); + } + + /// Field number for the "new_height" field. + public const int NewHeightFieldNumber = 1; + private int newHeight_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NewHeight { + get { return newHeight_; } + set { + newHeight_ = value; + } + } + + /// Field number for the "new_width" field. + public const int NewWidthFieldNumber = 2; + private int newWidth_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NewWidth { + get { return newWidth_; } + set { + newWidth_ = value; + } + } + + /// Field number for the "method" field. 
+ public const int MethodFieldNumber = 3; + private global::Tensorflow.Models.ObjectDetection.Protos.ResizeImage.Types.Method method_ = 0; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ResizeImage.Types.Method Method { + get { return method_; } + set { + method_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ResizeImage); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ResizeImage other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (NewHeight != other.NewHeight) return false; + if (NewWidth != other.NewWidth) return false; + if (Method != other.Method) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (NewHeight != 0) hash ^= NewHeight.GetHashCode(); + if (NewWidth != 0) hash ^= NewWidth.GetHashCode(); + if (Method != 0) hash ^= Method.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (NewHeight != 0) { + output.WriteRawTag(8); + output.WriteInt32(NewHeight); + } + if (NewWidth != 0) { + output.WriteRawTag(16); + output.WriteInt32(NewWidth); + } + if (Method != 0) { + output.WriteRawTag(24); + output.WriteEnum((int) Method); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (NewHeight != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NewHeight); + } + if (NewWidth != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NewWidth); + } + if (Method != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Method); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ResizeImage other) { + if (other == null) { + return; + } + if (other.NewHeight != 0) { + NewHeight = other.NewHeight; + } + if (other.NewWidth != 0) { + NewWidth = other.NewWidth; + } + if (other.Method != 0) { + Method = other.Method; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + NewHeight = input.ReadInt32(); + break; + } + case 16: { + NewWidth = input.ReadInt32(); + break; + } + case 24: { + method_ = (global::Tensorflow.Models.ObjectDetection.Protos.ResizeImage.Types.Method) input.ReadEnum(); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the ResizeImage message type. 
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + public enum Method { + [pbr::OriginalName("NONE")] None = 0, + [pbr::OriginalName("AREA")] Area = 1, + [pbr::OriginalName("BICUBIC")] Bicubic = 2, + [pbr::OriginalName("BILINEAR")] Bilinear = 3, + [pbr::OriginalName("NEAREST_NEIGHBOR")] NearestNeighbor = 4, + } + + } + #endregion + + } + + /// + /// Normalizes an image by subtracting a mean from each channel. + /// + public sealed partial class SubtractChannelMean : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SubtractChannelMean()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[24]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SubtractChannelMean() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SubtractChannelMean(SubtractChannelMean other) : this() { + means_ = other.means_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SubtractChannelMean Clone() { + return new SubtractChannelMean(this); + } + + /// Field number for the "means" field. + public const int MeansFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_means_codec + = pb::FieldCodec.ForFloat(10); + private readonly pbc::RepeatedField means_ = new pbc::RepeatedField(); + /// + /// The mean to subtract from each channel. Should be of same dimension of + /// channels in the input image. 
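// Illustrative sketch (hypothetical helper, not from this patch): the nested ResizeImage.Types.Method
// enum is assigned like any other field, and the generated ToString() emits a JSON-style diagnostic
// string via JsonFormatter.ToDiagnosticString.
public static class ResizeImageExample {
    public static void Run() {
        var resize = new Tensorflow.Models.ObjectDetection.Protos.ResizeImage {
            NewHeight = 300,
            NewWidth = 300,
            Method = Tensorflow.Models.ObjectDetection.Protos.ResizeImage.Types.Method.Bilinear
        };
        // Prints something like { "newHeight": 300, "newWidth": 300, "method": "BILINEAR" }
        System.Console.WriteLine(resize);
    }
}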
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Means { + get { return means_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SubtractChannelMean); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SubtractChannelMean other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!means_.Equals(other.means_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= means_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + means_.WriteTo(output, _repeated_means_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += means_.CalculateSize(_repeated_means_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SubtractChannelMean other) { + if (other == null) { + return; + } + means_.Add(other.means_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 13: { + means_.AddEntriesFrom(input, _repeated_means_codec); + break; + } + } + } + } + + } + + public sealed partial class SSDRandomCropOperation : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SSDRandomCropOperation()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[25]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropOperation() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropOperation(SSDRandomCropOperation other) : this() { + minObjectCovered_ = other.minObjectCovered_; + minAspectRatio_ = other.minAspectRatio_; + maxAspectRatio_ = other.maxAspectRatio_; + minArea_ = other.minArea_; + maxArea_ = other.maxArea_; + overlapThresh_ = other.overlapThresh_; + clipBoxes_ = other.clipBoxes_; + randomCoef_ = other.randomCoef_; + _unknownFields = 
pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropOperation Clone() { + return new SSDRandomCropOperation(this); + } + + /// Field number for the "min_object_covered" field. + public const int MinObjectCoveredFieldNumber = 1; + private float minObjectCovered_; + /// + /// Cropped image must cover at least this fraction of one original bounding + /// box. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinObjectCovered { + get { return minObjectCovered_; } + set { + minObjectCovered_ = value; + } + } + + /// Field number for the "min_aspect_ratio" field. + public const int MinAspectRatioFieldNumber = 2; + private float minAspectRatio_; + /// + /// The aspect ratio of the cropped image must be within the range of + /// [min_aspect_ratio, max_aspect_ratio]. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinAspectRatio { + get { return minAspectRatio_; } + set { + minAspectRatio_ = value; + } + } + + /// Field number for the "max_aspect_ratio" field. + public const int MaxAspectRatioFieldNumber = 3; + private float maxAspectRatio_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaxAspectRatio { + get { return maxAspectRatio_; } + set { + maxAspectRatio_ = value; + } + } + + /// Field number for the "min_area" field. + public const int MinAreaFieldNumber = 4; + private float minArea_; + /// + /// The area of the cropped image must be within the range of + /// [min_area, max_area]. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinArea { + get { return minArea_; } + set { + minArea_ = value; + } + } + + /// Field number for the "max_area" field. + public const int MaxAreaFieldNumber = 5; + private float maxArea_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaxArea { + get { return maxArea_; } + set { + maxArea_ = value; + } + } + + /// Field number for the "overlap_thresh" field. + public const int OverlapThreshFieldNumber = 6; + private float overlapThresh_; + /// + /// Cropped box area ratio must be above this threshold to be kept. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float OverlapThresh { + get { return overlapThresh_; } + set { + overlapThresh_ = value; + } + } + + /// Field number for the "clip_boxes" field. + public const int ClipBoxesFieldNumber = 8; + private bool clipBoxes_; + /// + /// Whether to clip the boxes to the cropped image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ClipBoxes { + get { return clipBoxes_; } + set { + clipBoxes_ = value; + } + } + + /// Field number for the "random_coef" field. + public const int RandomCoefFieldNumber = 7; + private float randomCoef_; + /// + /// Probability a crop operation is skipped.
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float RandomCoef { + get { return randomCoef_; } + set { + randomCoef_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SSDRandomCropOperation); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SSDRandomCropOperation other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinObjectCovered, other.MinObjectCovered)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinAspectRatio, other.MinAspectRatio)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaxAspectRatio, other.MaxAspectRatio)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinArea, other.MinArea)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaxArea, other.MaxArea)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(OverlapThresh, other.OverlapThresh)) return false; + if (ClipBoxes != other.ClipBoxes) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(RandomCoef, other.RandomCoef)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MinObjectCovered != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinObjectCovered); + if (MinAspectRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinAspectRatio); + if (MaxAspectRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaxAspectRatio); + if (MinArea != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinArea); + if (MaxArea != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaxArea); + if (OverlapThresh != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(OverlapThresh); + if (ClipBoxes != false) hash ^= ClipBoxes.GetHashCode(); + if (RandomCoef != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(RandomCoef); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MinObjectCovered != 0F) { + output.WriteRawTag(13); + output.WriteFloat(MinObjectCovered); + } + if (MinAspectRatio != 0F) { + output.WriteRawTag(21); + output.WriteFloat(MinAspectRatio); + } + if (MaxAspectRatio != 0F) { + output.WriteRawTag(29); + output.WriteFloat(MaxAspectRatio); + } + if (MinArea != 0F) { + output.WriteRawTag(37); + output.WriteFloat(MinArea); + } + if (MaxArea != 0F) { + output.WriteRawTag(45); + output.WriteFloat(MaxArea); + } + if (OverlapThresh != 0F) { + output.WriteRawTag(53); + output.WriteFloat(OverlapThresh); + } + if (RandomCoef != 0F) { + output.WriteRawTag(61); + 
output.WriteFloat(RandomCoef); + } + if (ClipBoxes != false) { + output.WriteRawTag(64); + output.WriteBool(ClipBoxes); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MinObjectCovered != 0F) { + size += 1 + 4; + } + if (MinAspectRatio != 0F) { + size += 1 + 4; + } + if (MaxAspectRatio != 0F) { + size += 1 + 4; + } + if (MinArea != 0F) { + size += 1 + 4; + } + if (MaxArea != 0F) { + size += 1 + 4; + } + if (OverlapThresh != 0F) { + size += 1 + 4; + } + if (ClipBoxes != false) { + size += 1 + 1; + } + if (RandomCoef != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SSDRandomCropOperation other) { + if (other == null) { + return; + } + if (other.MinObjectCovered != 0F) { + MinObjectCovered = other.MinObjectCovered; + } + if (other.MinAspectRatio != 0F) { + MinAspectRatio = other.MinAspectRatio; + } + if (other.MaxAspectRatio != 0F) { + MaxAspectRatio = other.MaxAspectRatio; + } + if (other.MinArea != 0F) { + MinArea = other.MinArea; + } + if (other.MaxArea != 0F) { + MaxArea = other.MaxArea; + } + if (other.OverlapThresh != 0F) { + OverlapThresh = other.OverlapThresh; + } + if (other.ClipBoxes != false) { + ClipBoxes = other.ClipBoxes; + } + if (other.RandomCoef != 0F) { + RandomCoef = other.RandomCoef; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + MinObjectCovered = input.ReadFloat(); + break; + } + case 21: { + MinAspectRatio = input.ReadFloat(); + break; + } + case 29: { + MaxAspectRatio = input.ReadFloat(); + break; + } + case 37: { + MinArea = input.ReadFloat(); + break; + } + case 45: { + MaxArea = input.ReadFloat(); + break; + } + case 53: { + OverlapThresh = input.ReadFloat(); + break; + } + case 61: { + RandomCoef = input.ReadFloat(); + break; + } + case 64: { + ClipBoxes = input.ReadBool(); + break; + } + } + } + } + + } + + /// + /// Randomly crops an image according to: + /// Liu et al., SSD: Single shot multibox detector. + /// This preprocessing step defines multiple SSDRandomCropOperations. Only one + /// operation (chosen at random) is actually performed on an image.
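// Illustrative sketch (hypothetical helper, not from this patch): WriteTo emits only non-default
// fields onto a CodedOutputStream, so the number of bytes written matches CalculateSize().
public static class SSDRandomCropOperationWriteExample {
    public static void Run() {
        var op = new Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropOperation {
            MinObjectCovered = 0.5f,
            MinAspectRatio = 0.5f,
            MaxAspectRatio = 2.0f,
            MinArea = 0.1f,
            MaxArea = 1.0f,
            OverlapThresh = 0.3f,
            ClipBoxes = true,
            RandomCoef = 0.15f
        };
        using (var stream = new System.IO.MemoryStream()) {
            var output = new Google.Protobuf.CodedOutputStream(stream);
            op.WriteTo(output);
            output.Flush();
            System.Diagnostics.Debug.Assert(stream.Length == op.CalculateSize());
        }
    }
}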
+ /// + public sealed partial class SSDRandomCrop : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SSDRandomCrop()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[26]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCrop() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCrop(SSDRandomCrop other) : this() { + operations_ = other.operations_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCrop Clone() { + return new SSDRandomCrop(this); + } + + /// Field number for the "operations" field. + public const int OperationsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_operations_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropOperation.Parser); + private readonly pbc::RepeatedField operations_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Operations { + get { return operations_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SSDRandomCrop); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SSDRandomCrop other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!operations_.Equals(other.operations_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= operations_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + operations_.WriteTo(output, _repeated_operations_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += operations_.CalculateSize(_repeated_operations_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SSDRandomCrop other) { + if (other == null) { + return; + } + operations_.Add(other.operations_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + 
while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + operations_.AddEntriesFrom(input, _repeated_operations_codec); + break; + } + } + } + } + + } + + public sealed partial class SSDRandomCropPadOperation : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SSDRandomCropPadOperation()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[27]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropPadOperation() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropPadOperation(SSDRandomCropPadOperation other) : this() { + minObjectCovered_ = other.minObjectCovered_; + minAspectRatio_ = other.minAspectRatio_; + maxAspectRatio_ = other.maxAspectRatio_; + minArea_ = other.minArea_; + maxArea_ = other.maxArea_; + overlapThresh_ = other.overlapThresh_; + clipBoxes_ = other.clipBoxes_; + randomCoef_ = other.randomCoef_; + minPaddedSizeRatio_ = other.minPaddedSizeRatio_.Clone(); + maxPaddedSizeRatio_ = other.maxPaddedSizeRatio_.Clone(); + padColorR_ = other.padColorR_; + padColorG_ = other.padColorG_; + padColorB_ = other.padColorB_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropPadOperation Clone() { + return new SSDRandomCropPadOperation(this); + } + + /// Field number for the "min_object_covered" field. + public const int MinObjectCoveredFieldNumber = 1; + private float minObjectCovered_; + /// + /// Cropped image must cover at least this fraction of one original bounding + /// box. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinObjectCovered { + get { return minObjectCovered_; } + set { + minObjectCovered_ = value; + } + } + + /// Field number for the "min_aspect_ratio" field. + public const int MinAspectRatioFieldNumber = 2; + private float minAspectRatio_; + /// + /// The aspect ratio of the cropped image must be within the range of + /// [min_aspect_ratio, max_aspect_ratio]. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinAspectRatio { + get { return minAspectRatio_; } + set { + minAspectRatio_ = value; + } + } + + /// Field number for the "max_aspect_ratio" field. + public const int MaxAspectRatioFieldNumber = 3; + private float maxAspectRatio_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaxAspectRatio { + get { return maxAspectRatio_; } + set { + maxAspectRatio_ = value; + } + } + + /// Field number for the "min_area" field. + public const int MinAreaFieldNumber = 4; + private float minArea_; + /// + /// The area of the cropped image must be within the range of + /// [min_area, max_area]. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinArea { + get { return minArea_; } + set { + minArea_ = value; + } + } + + /// Field number for the "max_area" field. + public const int MaxAreaFieldNumber = 5; + private float maxArea_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaxArea { + get { return maxArea_; } + set { + maxArea_ = value; + } + } + + /// Field number for the "overlap_thresh" field. + public const int OverlapThreshFieldNumber = 6; + private float overlapThresh_; + /// + /// Cropped box area ratio must be above this threhold to be kept. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float OverlapThresh { + get { return overlapThresh_; } + set { + overlapThresh_ = value; + } + } + + /// Field number for the "clip_boxes" field. + public const int ClipBoxesFieldNumber = 13; + private bool clipBoxes_; + /// + /// Whether to clip the boxes to the cropped image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ClipBoxes { + get { return clipBoxes_; } + set { + clipBoxes_ = value; + } + } + + /// Field number for the "random_coef" field. + public const int RandomCoefFieldNumber = 7; + private float randomCoef_; + /// + /// Probability a crop operation is skipped. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float RandomCoef { + get { return randomCoef_; } + set { + randomCoef_ = value; + } + } + + /// Field number for the "min_padded_size_ratio" field. + public const int MinPaddedSizeRatioFieldNumber = 8; + private static readonly pb::FieldCodec _repeated_minPaddedSizeRatio_codec + = pb::FieldCodec.ForFloat(66); + private readonly pbc::RepeatedField minPaddedSizeRatio_ = new pbc::RepeatedField(); + /// + /// Min ratio of padded image height and width to the input image's height and + /// width. Two entries per operation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField MinPaddedSizeRatio { + get { return minPaddedSizeRatio_; } + } + + /// Field number for the "max_padded_size_ratio" field. + public const int MaxPaddedSizeRatioFieldNumber = 9; + private static readonly pb::FieldCodec _repeated_maxPaddedSizeRatio_codec + = pb::FieldCodec.ForFloat(74); + private readonly pbc::RepeatedField maxPaddedSizeRatio_ = new pbc::RepeatedField(); + /// + /// Max ratio of padded image height and width to the input image's height and + /// width. Two entries per operation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField MaxPaddedSizeRatio { + get { return maxPaddedSizeRatio_; } + } + + /// Field number for the "pad_color_r" field. + public const int PadColorRFieldNumber = 10; + private float padColorR_; + /// + /// Padding color. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float PadColorR { + get { return padColorR_; } + set { + padColorR_ = value; + } + } + + /// Field number for the "pad_color_g" field. + public const int PadColorGFieldNumber = 11; + private float padColorG_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float PadColorG { + get { return padColorG_; } + set { + padColorG_ = value; + } + } + + /// Field number for the "pad_color_b" field. 
+ public const int PadColorBFieldNumber = 12; + private float padColorB_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float PadColorB { + get { return padColorB_; } + set { + padColorB_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SSDRandomCropPadOperation); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SSDRandomCropPadOperation other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinObjectCovered, other.MinObjectCovered)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinAspectRatio, other.MinAspectRatio)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaxAspectRatio, other.MaxAspectRatio)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinArea, other.MinArea)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaxArea, other.MaxArea)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(OverlapThresh, other.OverlapThresh)) return false; + if (ClipBoxes != other.ClipBoxes) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(RandomCoef, other.RandomCoef)) return false; + if(!minPaddedSizeRatio_.Equals(other.minPaddedSizeRatio_)) return false; + if(!maxPaddedSizeRatio_.Equals(other.maxPaddedSizeRatio_)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(PadColorR, other.PadColorR)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(PadColorG, other.PadColorG)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(PadColorB, other.PadColorB)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MinObjectCovered != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinObjectCovered); + if (MinAspectRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinAspectRatio); + if (MaxAspectRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaxAspectRatio); + if (MinArea != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinArea); + if (MaxArea != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaxArea); + if (OverlapThresh != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(OverlapThresh); + if (ClipBoxes != false) hash ^= ClipBoxes.GetHashCode(); + if (RandomCoef != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(RandomCoef); + hash ^= minPaddedSizeRatio_.GetHashCode(); + hash ^= maxPaddedSizeRatio_.GetHashCode(); + if (PadColorR != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(PadColorR); + if (PadColorG != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(PadColorG); + if (PadColorB != 0F) hash ^= 
pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(PadColorB); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MinObjectCovered != 0F) { + output.WriteRawTag(13); + output.WriteFloat(MinObjectCovered); + } + if (MinAspectRatio != 0F) { + output.WriteRawTag(21); + output.WriteFloat(MinAspectRatio); + } + if (MaxAspectRatio != 0F) { + output.WriteRawTag(29); + output.WriteFloat(MaxAspectRatio); + } + if (MinArea != 0F) { + output.WriteRawTag(37); + output.WriteFloat(MinArea); + } + if (MaxArea != 0F) { + output.WriteRawTag(45); + output.WriteFloat(MaxArea); + } + if (OverlapThresh != 0F) { + output.WriteRawTag(53); + output.WriteFloat(OverlapThresh); + } + if (RandomCoef != 0F) { + output.WriteRawTag(61); + output.WriteFloat(RandomCoef); + } + minPaddedSizeRatio_.WriteTo(output, _repeated_minPaddedSizeRatio_codec); + maxPaddedSizeRatio_.WriteTo(output, _repeated_maxPaddedSizeRatio_codec); + if (PadColorR != 0F) { + output.WriteRawTag(85); + output.WriteFloat(PadColorR); + } + if (PadColorG != 0F) { + output.WriteRawTag(93); + output.WriteFloat(PadColorG); + } + if (PadColorB != 0F) { + output.WriteRawTag(101); + output.WriteFloat(PadColorB); + } + if (ClipBoxes != false) { + output.WriteRawTag(104); + output.WriteBool(ClipBoxes); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MinObjectCovered != 0F) { + size += 1 + 4; + } + if (MinAspectRatio != 0F) { + size += 1 + 4; + } + if (MaxAspectRatio != 0F) { + size += 1 + 4; + } + if (MinArea != 0F) { + size += 1 + 4; + } + if (MaxArea != 0F) { + size += 1 + 4; + } + if (OverlapThresh != 0F) { + size += 1 + 4; + } + if (ClipBoxes != false) { + size += 1 + 1; + } + if (RandomCoef != 0F) { + size += 1 + 4; + } + size += minPaddedSizeRatio_.CalculateSize(_repeated_minPaddedSizeRatio_codec); + size += maxPaddedSizeRatio_.CalculateSize(_repeated_maxPaddedSizeRatio_codec); + if (PadColorR != 0F) { + size += 1 + 4; + } + if (PadColorG != 0F) { + size += 1 + 4; + } + if (PadColorB != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SSDRandomCropPadOperation other) { + if (other == null) { + return; + } + if (other.MinObjectCovered != 0F) { + MinObjectCovered = other.MinObjectCovered; + } + if (other.MinAspectRatio != 0F) { + MinAspectRatio = other.MinAspectRatio; + } + if (other.MaxAspectRatio != 0F) { + MaxAspectRatio = other.MaxAspectRatio; + } + if (other.MinArea != 0F) { + MinArea = other.MinArea; + } + if (other.MaxArea != 0F) { + MaxArea = other.MaxArea; + } + if (other.OverlapThresh != 0F) { + OverlapThresh = other.OverlapThresh; + } + if (other.ClipBoxes != false) { + ClipBoxes = other.ClipBoxes; + } + if (other.RandomCoef != 0F) { + RandomCoef = other.RandomCoef; + } + minPaddedSizeRatio_.Add(other.minPaddedSizeRatio_); + maxPaddedSizeRatio_.Add(other.maxPaddedSizeRatio_); + if (other.PadColorR != 0F) { + PadColorR = other.PadColorR; + } + if (other.PadColorG != 0F) { + PadColorG = other.PadColorG; + } + 
if (other.PadColorB != 0F) { + PadColorB = other.PadColorB; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + MinObjectCovered = input.ReadFloat(); + break; + } + case 21: { + MinAspectRatio = input.ReadFloat(); + break; + } + case 29: { + MaxAspectRatio = input.ReadFloat(); + break; + } + case 37: { + MinArea = input.ReadFloat(); + break; + } + case 45: { + MaxArea = input.ReadFloat(); + break; + } + case 53: { + OverlapThresh = input.ReadFloat(); + break; + } + case 61: { + RandomCoef = input.ReadFloat(); + break; + } + case 66: + case 69: { + minPaddedSizeRatio_.AddEntriesFrom(input, _repeated_minPaddedSizeRatio_codec); + break; + } + case 74: + case 77: { + maxPaddedSizeRatio_.AddEntriesFrom(input, _repeated_maxPaddedSizeRatio_codec); + break; + } + case 85: { + PadColorR = input.ReadFloat(); + break; + } + case 93: { + PadColorG = input.ReadFloat(); + break; + } + case 101: { + PadColorB = input.ReadFloat(); + break; + } + case 104: { + ClipBoxes = input.ReadBool(); + break; + } + } + } + } + + } + + /// + /// Randomly crops and pads an image according to: + /// Liu et al., SSD: Single shot multibox detector. + /// This preprocessing step defines multiple SSDRandomCropPadOperations. Only one + /// operation (chosen at random) is actually performed on an image. + /// + public sealed partial class SSDRandomCropPad : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SSDRandomCropPad()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[28]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropPad() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropPad(SSDRandomCropPad other) : this() { + operations_ = other.operations_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropPad Clone() { + return new SSDRandomCropPad(this); + } + + /// Field number for the "operations" field. 
+ public const int OperationsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_operations_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPadOperation.Parser); + private readonly pbc::RepeatedField operations_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Operations { + get { return operations_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SSDRandomCropPad); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SSDRandomCropPad other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!operations_.Equals(other.operations_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= operations_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + operations_.WriteTo(output, _repeated_operations_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += operations_.CalculateSize(_repeated_operations_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SSDRandomCropPad other) { + if (other == null) { + return; + } + operations_.Add(other.operations_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + operations_.AddEntriesFrom(input, _repeated_operations_codec); + break; + } + } + } + } + + } + + public sealed partial class SSDRandomCropFixedAspectRatioOperation : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SSDRandomCropFixedAspectRatioOperation()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[29]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropFixedAspectRatioOperation() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + 
public SSDRandomCropFixedAspectRatioOperation(SSDRandomCropFixedAspectRatioOperation other) : this() { + minObjectCovered_ = other.minObjectCovered_; + minArea_ = other.minArea_; + maxArea_ = other.maxArea_; + overlapThresh_ = other.overlapThresh_; + clipBoxes_ = other.clipBoxes_; + randomCoef_ = other.randomCoef_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropFixedAspectRatioOperation Clone() { + return new SSDRandomCropFixedAspectRatioOperation(this); + } + + /// Field number for the "min_object_covered" field. + public const int MinObjectCoveredFieldNumber = 1; + private float minObjectCovered_; + /// + /// Cropped image must cover at least this fraction of one original bounding + /// box. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinObjectCovered { + get { return minObjectCovered_; } + set { + minObjectCovered_ = value; + } + } + + /// Field number for the "min_area" field. + public const int MinAreaFieldNumber = 4; + private float minArea_; + /// + /// The area of the cropped image must be within the range of + /// [min_area, max_area]. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinArea { + get { return minArea_; } + set { + minArea_ = value; + } + } + + /// Field number for the "max_area" field. + public const int MaxAreaFieldNumber = 5; + private float maxArea_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaxArea { + get { return maxArea_; } + set { + maxArea_ = value; + } + } + + /// Field number for the "overlap_thresh" field. + public const int OverlapThreshFieldNumber = 6; + private float overlapThresh_; + /// + /// Cropped box area ratio must be above this threhold to be kept. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float OverlapThresh { + get { return overlapThresh_; } + set { + overlapThresh_ = value; + } + } + + /// Field number for the "clip_boxes" field. + public const int ClipBoxesFieldNumber = 8; + private bool clipBoxes_; + /// + /// Whether to clip the boxes to the cropped image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ClipBoxes { + get { return clipBoxes_; } + set { + clipBoxes_ = value; + } + } + + /// Field number for the "random_coef" field. + public const int RandomCoefFieldNumber = 7; + private float randomCoef_; + /// + /// Probability a crop operation is skipped. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float RandomCoef { + get { return randomCoef_; } + set { + randomCoef_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SSDRandomCropFixedAspectRatioOperation); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SSDRandomCropFixedAspectRatioOperation other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinObjectCovered, other.MinObjectCovered)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinArea, other.MinArea)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaxArea, other.MaxArea)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(OverlapThresh, other.OverlapThresh)) return false; + if (ClipBoxes != other.ClipBoxes) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(RandomCoef, other.RandomCoef)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MinObjectCovered != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinObjectCovered); + if (MinArea != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinArea); + if (MaxArea != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaxArea); + if (OverlapThresh != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(OverlapThresh); + if (ClipBoxes != false) hash ^= ClipBoxes.GetHashCode(); + if (RandomCoef != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(RandomCoef); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MinObjectCovered != 0F) { + output.WriteRawTag(13); + output.WriteFloat(MinObjectCovered); + } + if (MinArea != 0F) { + output.WriteRawTag(37); + output.WriteFloat(MinArea); + } + if (MaxArea != 0F) { + output.WriteRawTag(45); + output.WriteFloat(MaxArea); + } + if (OverlapThresh != 0F) { + output.WriteRawTag(53); + output.WriteFloat(OverlapThresh); + } + if (RandomCoef != 0F) { + output.WriteRawTag(61); + output.WriteFloat(RandomCoef); + } + if (ClipBoxes != false) { + output.WriteRawTag(64); + output.WriteBool(ClipBoxes); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MinObjectCovered != 0F) { + size += 1 + 4; + } + if (MinArea != 0F) { + size += 1 + 4; + } + if (MaxArea != 0F) { + size += 1 + 4; + } + if (OverlapThresh != 0F) { + size += 1 + 4; + } + if (ClipBoxes != false) { + size += 1 + 1; + } + if (RandomCoef != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + 
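+      // Example (illustrative only; the field values below are placeholders, not defaults):
+      // like any message built on the Google.Protobuf runtime, an operation can be created
+      // with an object initializer and round-tripped through binary serialization using the
+      // generated Parser and the runtime's ToByteArray extension method.
+      //
+      //   var op = new SSDRandomCropFixedAspectRatioOperation {
+      //       MinObjectCovered = 0.5f,
+      //       MinArea = 0.1f,
+      //       MaxArea = 1.0f,
+      //       OverlapThresh = 0.3f,
+      //       ClipBoxes = true
+      //   };
+      //   byte[] payload = op.ToByteArray();   // Google.Protobuf.MessageExtensions
+      //   var parsed = SSDRandomCropFixedAspectRatioOperation.Parser.ParseFrom(payload);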
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SSDRandomCropFixedAspectRatioOperation other) { + if (other == null) { + return; + } + if (other.MinObjectCovered != 0F) { + MinObjectCovered = other.MinObjectCovered; + } + if (other.MinArea != 0F) { + MinArea = other.MinArea; + } + if (other.MaxArea != 0F) { + MaxArea = other.MaxArea; + } + if (other.OverlapThresh != 0F) { + OverlapThresh = other.OverlapThresh; + } + if (other.ClipBoxes != false) { + ClipBoxes = other.ClipBoxes; + } + if (other.RandomCoef != 0F) { + RandomCoef = other.RandomCoef; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + MinObjectCovered = input.ReadFloat(); + break; + } + case 37: { + MinArea = input.ReadFloat(); + break; + } + case 45: { + MaxArea = input.ReadFloat(); + break; + } + case 53: { + OverlapThresh = input.ReadFloat(); + break; + } + case 61: { + RandomCoef = input.ReadFloat(); + break; + } + case 64: { + ClipBoxes = input.ReadBool(); + break; + } + } + } + } + + } + + /// + /// Randomly crops a image to a fixed aspect ratio according to: + /// Liu et al., SSD: Single shot multibox detector. + /// Multiple SSDRandomCropFixedAspectRatioOperations are defined by this + /// preprocessing step. Only one operation (chosen at random) is actually + /// performed on an image. + /// + public sealed partial class SSDRandomCropFixedAspectRatio : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SSDRandomCropFixedAspectRatio()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[30]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropFixedAspectRatio() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropFixedAspectRatio(SSDRandomCropFixedAspectRatio other) : this() { + operations_ = other.operations_.Clone(); + aspectRatio_ = other.aspectRatio_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropFixedAspectRatio Clone() { + return new SSDRandomCropFixedAspectRatio(this); + } + + /// Field number for the "operations" field. 
+ public const int OperationsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_operations_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropFixedAspectRatioOperation.Parser); + private readonly pbc::RepeatedField operations_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Operations { + get { return operations_; } + } + + /// Field number for the "aspect_ratio" field. + public const int AspectRatioFieldNumber = 2; + private float aspectRatio_; + /// + /// Aspect ratio to crop to. This value is used for all crop operations. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float AspectRatio { + get { return aspectRatio_; } + set { + aspectRatio_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SSDRandomCropFixedAspectRatio); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SSDRandomCropFixedAspectRatio other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!operations_.Equals(other.operations_)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(AspectRatio, other.AspectRatio)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= operations_.GetHashCode(); + if (AspectRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(AspectRatio); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + operations_.WriteTo(output, _repeated_operations_codec); + if (AspectRatio != 0F) { + output.WriteRawTag(21); + output.WriteFloat(AspectRatio); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += operations_.CalculateSize(_repeated_operations_codec); + if (AspectRatio != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SSDRandomCropFixedAspectRatio other) { + if (other == null) { + return; + } + operations_.Add(other.operations_); + if (other.AspectRatio != 0F) { + AspectRatio = other.AspectRatio; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + operations_.AddEntriesFrom(input, _repeated_operations_codec); + break; + } + case 21: { + AspectRatio = input.ReadFloat(); + break; + } + } + } + } + + } + + public sealed partial class SSDRandomCropPadFixedAspectRatioOperation 
: pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SSDRandomCropPadFixedAspectRatioOperation()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[31]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropPadFixedAspectRatioOperation() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropPadFixedAspectRatioOperation(SSDRandomCropPadFixedAspectRatioOperation other) : this() { + minObjectCovered_ = other.minObjectCovered_; + minAspectRatio_ = other.minAspectRatio_; + maxAspectRatio_ = other.maxAspectRatio_; + minArea_ = other.minArea_; + maxArea_ = other.maxArea_; + overlapThresh_ = other.overlapThresh_; + clipBoxes_ = other.clipBoxes_; + randomCoef_ = other.randomCoef_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropPadFixedAspectRatioOperation Clone() { + return new SSDRandomCropPadFixedAspectRatioOperation(this); + } + + /// Field number for the "min_object_covered" field. + public const int MinObjectCoveredFieldNumber = 1; + private float minObjectCovered_; + /// + /// Cropped image must cover at least this fraction of one original bounding + /// box. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinObjectCovered { + get { return minObjectCovered_; } + set { + minObjectCovered_ = value; + } + } + + /// Field number for the "min_aspect_ratio" field. + public const int MinAspectRatioFieldNumber = 2; + private float minAspectRatio_; + /// + /// The aspect ratio of the cropped image must be within the range of + /// [min_aspect_ratio, max_aspect_ratio]. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinAspectRatio { + get { return minAspectRatio_; } + set { + minAspectRatio_ = value; + } + } + + /// Field number for the "max_aspect_ratio" field. + public const int MaxAspectRatioFieldNumber = 3; + private float maxAspectRatio_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaxAspectRatio { + get { return maxAspectRatio_; } + set { + maxAspectRatio_ = value; + } + } + + /// Field number for the "min_area" field. + public const int MinAreaFieldNumber = 4; + private float minArea_; + /// + /// The area of the cropped image must be within the range of + /// [min_area, max_area]. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinArea { + get { return minArea_; } + set { + minArea_ = value; + } + } + + /// Field number for the "max_area" field. + public const int MaxAreaFieldNumber = 5; + private float maxArea_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaxArea { + get { return maxArea_; } + set { + maxArea_ = value; + } + } + + /// Field number for the "overlap_thresh" field. 
+ public const int OverlapThreshFieldNumber = 6; + private float overlapThresh_; + /// + /// Cropped box area ratio must be above this threhold to be kept. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float OverlapThresh { + get { return overlapThresh_; } + set { + overlapThresh_ = value; + } + } + + /// Field number for the "clip_boxes" field. + public const int ClipBoxesFieldNumber = 8; + private bool clipBoxes_; + /// + /// Whether to clip the boxes to the cropped image. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ClipBoxes { + get { return clipBoxes_; } + set { + clipBoxes_ = value; + } + } + + /// Field number for the "random_coef" field. + public const int RandomCoefFieldNumber = 7; + private float randomCoef_; + /// + /// Probability a crop operation is skipped. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float RandomCoef { + get { return randomCoef_; } + set { + randomCoef_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SSDRandomCropPadFixedAspectRatioOperation); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SSDRandomCropPadFixedAspectRatioOperation other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinObjectCovered, other.MinObjectCovered)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinAspectRatio, other.MinAspectRatio)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaxAspectRatio, other.MaxAspectRatio)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinArea, other.MinArea)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaxArea, other.MaxArea)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(OverlapThresh, other.OverlapThresh)) return false; + if (ClipBoxes != other.ClipBoxes) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(RandomCoef, other.RandomCoef)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MinObjectCovered != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinObjectCovered); + if (MinAspectRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinAspectRatio); + if (MaxAspectRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaxAspectRatio); + if (MinArea != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinArea); + if (MaxArea != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaxArea); + if (OverlapThresh != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(OverlapThresh); + if (ClipBoxes != false) hash ^= ClipBoxes.GetHashCode(); + if (RandomCoef != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(RandomCoef); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + 
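+      // Note (illustrative only; values are placeholders): MergeFrom copies only fields whose
+      // source value differs from the proto3 default, and float fields are compared bitwise
+      // via BitwiseSingleEqualityComparer, so merging two partially populated operations
+      // combines their non-zero settings:
+      //
+      //   var a = new SSDRandomCropPadFixedAspectRatioOperation { MinAspectRatio = 0.5f };
+      //   var b = new SSDRandomCropPadFixedAspectRatioOperation { MaxAspectRatio = 2.0f };
+      //   a.MergeFrom(b);   // a now has MinAspectRatio = 0.5f and MaxAspectRatio = 2.0f;
+      //                     // zero-valued fields in b leave a unchanged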
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MinObjectCovered != 0F) { + output.WriteRawTag(13); + output.WriteFloat(MinObjectCovered); + } + if (MinAspectRatio != 0F) { + output.WriteRawTag(21); + output.WriteFloat(MinAspectRatio); + } + if (MaxAspectRatio != 0F) { + output.WriteRawTag(29); + output.WriteFloat(MaxAspectRatio); + } + if (MinArea != 0F) { + output.WriteRawTag(37); + output.WriteFloat(MinArea); + } + if (MaxArea != 0F) { + output.WriteRawTag(45); + output.WriteFloat(MaxArea); + } + if (OverlapThresh != 0F) { + output.WriteRawTag(53); + output.WriteFloat(OverlapThresh); + } + if (RandomCoef != 0F) { + output.WriteRawTag(61); + output.WriteFloat(RandomCoef); + } + if (ClipBoxes != false) { + output.WriteRawTag(64); + output.WriteBool(ClipBoxes); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MinObjectCovered != 0F) { + size += 1 + 4; + } + if (MinAspectRatio != 0F) { + size += 1 + 4; + } + if (MaxAspectRatio != 0F) { + size += 1 + 4; + } + if (MinArea != 0F) { + size += 1 + 4; + } + if (MaxArea != 0F) { + size += 1 + 4; + } + if (OverlapThresh != 0F) { + size += 1 + 4; + } + if (ClipBoxes != false) { + size += 1 + 1; + } + if (RandomCoef != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SSDRandomCropPadFixedAspectRatioOperation other) { + if (other == null) { + return; + } + if (other.MinObjectCovered != 0F) { + MinObjectCovered = other.MinObjectCovered; + } + if (other.MinAspectRatio != 0F) { + MinAspectRatio = other.MinAspectRatio; + } + if (other.MaxAspectRatio != 0F) { + MaxAspectRatio = other.MaxAspectRatio; + } + if (other.MinArea != 0F) { + MinArea = other.MinArea; + } + if (other.MaxArea != 0F) { + MaxArea = other.MaxArea; + } + if (other.OverlapThresh != 0F) { + OverlapThresh = other.OverlapThresh; + } + if (other.ClipBoxes != false) { + ClipBoxes = other.ClipBoxes; + } + if (other.RandomCoef != 0F) { + RandomCoef = other.RandomCoef; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + MinObjectCovered = input.ReadFloat(); + break; + } + case 21: { + MinAspectRatio = input.ReadFloat(); + break; + } + case 29: { + MaxAspectRatio = input.ReadFloat(); + break; + } + case 37: { + MinArea = input.ReadFloat(); + break; + } + case 45: { + MaxArea = input.ReadFloat(); + break; + } + case 53: { + OverlapThresh = input.ReadFloat(); + break; + } + case 61: { + RandomCoef = input.ReadFloat(); + break; + } + case 64: { + ClipBoxes = input.ReadBool(); + break; + } + } + } + } + + } + + /// + /// Randomly crops and pads an image to a fixed aspect ratio according to: + /// Liu et al., SSD: Single shot multibox detector. 
+ /// Multiple SSDRandomCropPadFixedAspectRatioOperations are defined by this + /// preprocessing step. Only one operation (chosen at random) is actually + /// performed on an image. + /// + public sealed partial class SSDRandomCropPadFixedAspectRatio : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SSDRandomCropPadFixedAspectRatio()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[32]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropPadFixedAspectRatio() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropPadFixedAspectRatio(SSDRandomCropPadFixedAspectRatio other) : this() { + operations_ = other.operations_.Clone(); + aspectRatio_ = other.aspectRatio_; + minPaddedSizeRatio_ = other.minPaddedSizeRatio_.Clone(); + maxPaddedSizeRatio_ = other.maxPaddedSizeRatio_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SSDRandomCropPadFixedAspectRatio Clone() { + return new SSDRandomCropPadFixedAspectRatio(this); + } + + /// Field number for the "operations" field. + public const int OperationsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_operations_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.Models.ObjectDetection.Protos.SSDRandomCropPadFixedAspectRatioOperation.Parser); + private readonly pbc::RepeatedField operations_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Operations { + get { return operations_; } + } + + /// Field number for the "aspect_ratio" field. + public const int AspectRatioFieldNumber = 2; + private float aspectRatio_; + /// + /// Aspect ratio to pad to. This value is used for all crop and pad operations. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float AspectRatio { + get { return aspectRatio_; } + set { + aspectRatio_ = value; + } + } + + /// Field number for the "min_padded_size_ratio" field. + public const int MinPaddedSizeRatioFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_minPaddedSizeRatio_codec + = pb::FieldCodec.ForFloat(26); + private readonly pbc::RepeatedField minPaddedSizeRatio_ = new pbc::RepeatedField(); + /// + /// Min ratio of padded image height and width to the input image's height and + /// width. Two entries per operation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField MinPaddedSizeRatio { + get { return minPaddedSizeRatio_; } + } + + /// Field number for the "max_padded_size_ratio" field. 
+ public const int MaxPaddedSizeRatioFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_maxPaddedSizeRatio_codec + = pb::FieldCodec.ForFloat(34); + private readonly pbc::RepeatedField maxPaddedSizeRatio_ = new pbc::RepeatedField(); + /// + /// Max ratio of padded image height and width to the input image's height and + /// width. Two entries per operation. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField MaxPaddedSizeRatio { + get { return maxPaddedSizeRatio_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SSDRandomCropPadFixedAspectRatio); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SSDRandomCropPadFixedAspectRatio other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!operations_.Equals(other.operations_)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(AspectRatio, other.AspectRatio)) return false; + if(!minPaddedSizeRatio_.Equals(other.minPaddedSizeRatio_)) return false; + if(!maxPaddedSizeRatio_.Equals(other.maxPaddedSizeRatio_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= operations_.GetHashCode(); + if (AspectRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(AspectRatio); + hash ^= minPaddedSizeRatio_.GetHashCode(); + hash ^= maxPaddedSizeRatio_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + operations_.WriteTo(output, _repeated_operations_codec); + if (AspectRatio != 0F) { + output.WriteRawTag(21); + output.WriteFloat(AspectRatio); + } + minPaddedSizeRatio_.WriteTo(output, _repeated_minPaddedSizeRatio_codec); + maxPaddedSizeRatio_.WriteTo(output, _repeated_maxPaddedSizeRatio_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += operations_.CalculateSize(_repeated_operations_codec); + if (AspectRatio != 0F) { + size += 1 + 4; + } + size += minPaddedSizeRatio_.CalculateSize(_repeated_minPaddedSizeRatio_codec); + size += maxPaddedSizeRatio_.CalculateSize(_repeated_maxPaddedSizeRatio_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SSDRandomCropPadFixedAspectRatio other) { + if (other == null) { + return; + } + operations_.Add(other.operations_); + if (other.AspectRatio != 0F) { + AspectRatio = other.AspectRatio; + } + minPaddedSizeRatio_.Add(other.minPaddedSizeRatio_); + maxPaddedSizeRatio_.Add(other.maxPaddedSizeRatio_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = 
input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + operations_.AddEntriesFrom(input, _repeated_operations_codec); + break; + } + case 21: { + AspectRatio = input.ReadFloat(); + break; + } + case 26: + case 29: { + minPaddedSizeRatio_.AddEntriesFrom(input, _repeated_minPaddedSizeRatio_codec); + break; + } + case 34: + case 37: { + maxPaddedSizeRatio_.AddEntriesFrom(input, _repeated_maxPaddedSizeRatio_codec); + break; + } + } + } + } + + } + + /// + /// Converts class logits to softmax optionally scaling the values by temperature + /// first. + /// + public sealed partial class ConvertClassLogitsToSoftmax : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ConvertClassLogitsToSoftmax()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[33]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ConvertClassLogitsToSoftmax() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ConvertClassLogitsToSoftmax(ConvertClassLogitsToSoftmax other) : this() { + temperature_ = other.temperature_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ConvertClassLogitsToSoftmax Clone() { + return new ConvertClassLogitsToSoftmax(this); + } + + /// Field number for the "temperature" field. + public const int TemperatureFieldNumber = 1; + private float temperature_; + /// + /// Scale to use on logits before applying softmax. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float Temperature { + get { return temperature_; } + set { + temperature_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ConvertClassLogitsToSoftmax); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ConvertClassLogitsToSoftmax other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Temperature, other.Temperature)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Temperature != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Temperature); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Temperature != 0F) { + output.WriteRawTag(13); + output.WriteFloat(Temperature); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Temperature != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ConvertClassLogitsToSoftmax other) { + if (other == null) { + return; + } + if (other.Temperature != 0F) { + Temperature = other.Temperature; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + Temperature = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Randomly concatenates the image with itself horizontally and/or vertically. 
+ /// + public sealed partial class RandomSelfConcatImage : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RandomSelfConcatImage()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[34]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomSelfConcatImage() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomSelfConcatImage(RandomSelfConcatImage other) : this() { + concatVerticalProbability_ = other.concatVerticalProbability_; + concatHorizontalProbability_ = other.concatHorizontalProbability_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RandomSelfConcatImage Clone() { + return new RandomSelfConcatImage(this); + } + + /// Field number for the "concat_vertical_probability" field. + public const int ConcatVerticalProbabilityFieldNumber = 1; + private float concatVerticalProbability_; + /// + /// Probability of concatenating the image vertically. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float ConcatVerticalProbability { + get { return concatVerticalProbability_; } + set { + concatVerticalProbability_ = value; + } + } + + /// Field number for the "concat_horizontal_probability" field. + public const int ConcatHorizontalProbabilityFieldNumber = 2; + private float concatHorizontalProbability_; + /// + /// Probability of concatenating the image horizontally. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float ConcatHorizontalProbability { + get { return concatHorizontalProbability_; } + set { + concatHorizontalProbability_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RandomSelfConcatImage); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RandomSelfConcatImage other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(ConcatVerticalProbability, other.ConcatVerticalProbability)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(ConcatHorizontalProbability, other.ConcatHorizontalProbability)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (ConcatVerticalProbability != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(ConcatVerticalProbability); + if (ConcatHorizontalProbability != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(ConcatHorizontalProbability); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (ConcatVerticalProbability != 0F) { + output.WriteRawTag(13); + output.WriteFloat(ConcatVerticalProbability); + } + if (ConcatHorizontalProbability != 0F) { + output.WriteRawTag(21); + output.WriteFloat(ConcatHorizontalProbability); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (ConcatVerticalProbability != 0F) { + size += 1 + 4; + } + if (ConcatHorizontalProbability != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RandomSelfConcatImage other) { + if (other == null) { + return; + } + if (other.ConcatVerticalProbability != 0F) { + ConcatVerticalProbability = other.ConcatVerticalProbability; + } + if (other.ConcatHorizontalProbability != 0F) { + ConcatHorizontalProbability = other.ConcatHorizontalProbability; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + ConcatVerticalProbability = input.ReadFloat(); + break; + } + case 21: { + ConcatHorizontalProbability = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + /// Apply an Autoaugment policy to the image and bounding boxes. 
+ /// + public sealed partial class AutoAugmentImage : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new AutoAugmentImage()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[35]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AutoAugmentImage() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AutoAugmentImage(AutoAugmentImage other) : this() { + policyName_ = other.policyName_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AutoAugmentImage Clone() { + return new AutoAugmentImage(this); + } + + /// Field number for the "policy_name" field. + public const int PolicyNameFieldNumber = 1; + private string policyName_ = ""; + /// + /// What AutoAugment policy to apply to the Image + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string PolicyName { + get { return policyName_; } + set { + policyName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as AutoAugmentImage); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(AutoAugmentImage other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (PolicyName != other.PolicyName) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (PolicyName.Length != 0) hash ^= PolicyName.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (PolicyName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(PolicyName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (PolicyName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PolicyName); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(AutoAugmentImage other) { + if (other == null) { + return; + } + if (other.PolicyName.Length != 0) { + PolicyName = other.PolicyName; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + 
public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + PolicyName = input.ReadString(); + break; + } + } + } + } + + } + + /// + /// Randomly drops ground truth boxes for a label with some probability. + /// + public sealed partial class DropLabelProbabilistically : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new DropLabelProbabilistically()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[36]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public DropLabelProbabilistically() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public DropLabelProbabilistically(DropLabelProbabilistically other) : this() { + label_ = other.label_; + dropProbability_ = other.dropProbability_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public DropLabelProbabilistically Clone() { + return new DropLabelProbabilistically(this); + } + + /// Field number for the "label" field. + public const int LabelFieldNumber = 1; + private int label_; + /// + /// The label that should be dropped. This corresponds to one of the entries + /// in the label map. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Label { + get { return label_; } + set { + label_ = value; + } + } + + /// Field number for the "drop_probability" field. + public const int DropProbabilityFieldNumber = 2; + private float dropProbability_; + /// + /// Probability of dropping the label. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float DropProbability { + get { return dropProbability_; } + set { + dropProbability_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as DropLabelProbabilistically); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(DropLabelProbabilistically other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Label != other.Label) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(DropProbability, other.DropProbability)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Label != 0) hash ^= Label.GetHashCode(); + if (DropProbability != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(DropProbability); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Label != 0) { + output.WriteRawTag(8); + output.WriteInt32(Label); + } + if (DropProbability != 0F) { + output.WriteRawTag(21); + output.WriteFloat(DropProbability); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Label != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Label); + } + if (DropProbability != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(DropLabelProbabilistically other) { + if (other == null) { + return; + } + if (other.Label != 0) { + Label = other.Label; + } + if (other.DropProbability != 0F) { + DropProbability = other.DropProbability; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Label = input.ReadInt32(); + break; + } + case 21: { + DropProbability = input.ReadFloat(); + break; + } + } + } + } + + } + + /// + ///Remap a set of labels to a new label. 
+ /// + public sealed partial class RemapLabels : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RemapLabels()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor.MessageTypes[37]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RemapLabels() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RemapLabels(RemapLabels other) : this() { + originalLabels_ = other.originalLabels_.Clone(); + newLabel_ = other.newLabel_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RemapLabels Clone() { + return new RemapLabels(this); + } + + /// Field number for the "original_labels" field. + public const int OriginalLabelsFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_originalLabels_codec + = pb::FieldCodec.ForInt32(10); + private readonly pbc::RepeatedField originalLabels_ = new pbc::RepeatedField(); + /// + /// Labels to be remapped. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField OriginalLabels { + get { return originalLabels_; } + } + + /// Field number for the "new_label" field. + public const int NewLabelFieldNumber = 2; + private int newLabel_; + /// + /// Label to map to. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NewLabel { + get { return newLabel_; } + set { + newLabel_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RemapLabels); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RemapLabels other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!originalLabels_.Equals(other.originalLabels_)) return false; + if (NewLabel != other.NewLabel) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= originalLabels_.GetHashCode(); + if (NewLabel != 0) hash ^= NewLabel.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + originalLabels_.WriteTo(output, _repeated_originalLabels_codec); + if (NewLabel != 0) { + output.WriteRawTag(16); + output.WriteInt32(NewLabel); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += originalLabels_.CalculateSize(_repeated_originalLabels_codec); + if (NewLabel != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NewLabel); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RemapLabels other) { + if (other == null) { + return; + } + originalLabels_.Add(other.originalLabels_); + if (other.NewLabel != 0) { + NewLabel = other.NewLabel; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: + case 8: { + originalLabels_.AddEntriesFrom(input, _repeated_originalLabels_codec); + break; + } + case 16: { + NewLabel = input.ReadInt32(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/RegionSimilarityCalculator.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/RegionSimilarityCalculator.cs new file mode 100644 index 00000000..9b06eebc --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/RegionSimilarityCalculator.cs @@ -0,0 +1,791 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: object_detection/protos/region_similarity_calculator.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/region_similarity_calculator.proto + public static partial class RegionSimilarityCalculatorReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/region_similarity_calculator.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static RegionSimilarityCalculatorReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjpvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9yZWdpb25fc2ltaWxhcml0eV9j", + "YWxjdWxhdG9yLnByb3RvEhdvYmplY3RfZGV0ZWN0aW9uLnByb3RvcyLeAgoa", + "UmVnaW9uU2ltaWxhcml0eUNhbGN1bGF0b3ISTgoWbmVnX3NxX2Rpc3Rfc2lt", + "aWxhcml0eRgBIAEoCzIsLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLk5lZ1Nx", + "RGlzdFNpbWlsYXJpdHlIABJACg5pb3Vfc2ltaWxhcml0eRgCIAEoCzImLm9i", + "amVjdF9kZXRlY3Rpb24ucHJvdG9zLklvdVNpbWlsYXJpdHlIABJACg5pb2Ff", + "c2ltaWxhcml0eRgDIAEoCzImLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLklv", + "YVNpbWlsYXJpdHlIABJXChp0aHJlc2hvbGRlZF9pb3Vfc2ltaWxhcml0eRgE", + "IAEoCzIxLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLlRocmVzaG9sZGVkSW91", + "U2ltaWxhcml0eUgAQhMKEXJlZ2lvbl9zaW1pbGFyaXR5IhUKE05lZ1NxRGlz", + "dFNpbWlsYXJpdHkiDwoNSW91U2ltaWxhcml0eSIPCg1Jb2FTaW1pbGFyaXR5", + "IjEKGFRocmVzaG9sZGVkSW91U2ltaWxhcml0eRIVCg1pb3VfdGhyZXNob2xk", + "GAEgASgCYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.RegionSimilarityCalculator), global::Tensorflow.Models.ObjectDetection.Protos.RegionSimilarityCalculator.Parser, new[]{ "NegSqDistSimilarity", "IouSimilarity", "IoaSimilarity", "ThresholdedIouSimilarity" }, new[]{ "RegionSimilarity" }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.NegSqDistSimilarity), global::Tensorflow.Models.ObjectDetection.Protos.NegSqDistSimilarity.Parser, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.IouSimilarity), global::Tensorflow.Models.ObjectDetection.Protos.IouSimilarity.Parser, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.IoaSimilarity), global::Tensorflow.Models.ObjectDetection.Protos.IoaSimilarity.Parser, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.ThresholdedIouSimilarity), global::Tensorflow.Models.ObjectDetection.Protos.ThresholdedIouSimilarity.Parser, new[]{ "IouThreshold" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration proto for region similarity calculators. See + /// core/region_similarity_calculator.py for details. 
+ /// + public sealed partial class RegionSimilarityCalculator : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RegionSimilarityCalculator()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.RegionSimilarityCalculatorReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RegionSimilarityCalculator() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RegionSimilarityCalculator(RegionSimilarityCalculator other) : this() { + switch (other.RegionSimilarityCase) { + case RegionSimilarityOneofCase.NegSqDistSimilarity: + NegSqDistSimilarity = other.NegSqDistSimilarity.Clone(); + break; + case RegionSimilarityOneofCase.IouSimilarity: + IouSimilarity = other.IouSimilarity.Clone(); + break; + case RegionSimilarityOneofCase.IoaSimilarity: + IoaSimilarity = other.IoaSimilarity.Clone(); + break; + case RegionSimilarityOneofCase.ThresholdedIouSimilarity: + ThresholdedIouSimilarity = other.ThresholdedIouSimilarity.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RegionSimilarityCalculator Clone() { + return new RegionSimilarityCalculator(this); + } + + /// Field number for the "neg_sq_dist_similarity" field. + public const int NegSqDistSimilarityFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.NegSqDistSimilarity NegSqDistSimilarity { + get { return regionSimilarityCase_ == RegionSimilarityOneofCase.NegSqDistSimilarity ? (global::Tensorflow.Models.ObjectDetection.Protos.NegSqDistSimilarity) regionSimilarity_ : null; } + set { + regionSimilarity_ = value; + regionSimilarityCase_ = value == null ? RegionSimilarityOneofCase.None : RegionSimilarityOneofCase.NegSqDistSimilarity; + } + } + + /// Field number for the "iou_similarity" field. + public const int IouSimilarityFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.IouSimilarity IouSimilarity { + get { return regionSimilarityCase_ == RegionSimilarityOneofCase.IouSimilarity ? (global::Tensorflow.Models.ObjectDetection.Protos.IouSimilarity) regionSimilarity_ : null; } + set { + regionSimilarity_ = value; + regionSimilarityCase_ = value == null ? RegionSimilarityOneofCase.None : RegionSimilarityOneofCase.IouSimilarity; + } + } + + /// Field number for the "ioa_similarity" field. + public const int IoaSimilarityFieldNumber = 3; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.IoaSimilarity IoaSimilarity { + get { return regionSimilarityCase_ == RegionSimilarityOneofCase.IoaSimilarity ? (global::Tensorflow.Models.ObjectDetection.Protos.IoaSimilarity) regionSimilarity_ : null; } + set { + regionSimilarity_ = value; + regionSimilarityCase_ = value == null ? 
RegionSimilarityOneofCase.None : RegionSimilarityOneofCase.IoaSimilarity; + } + } + + /// Field number for the "thresholded_iou_similarity" field. + public const int ThresholdedIouSimilarityFieldNumber = 4; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ThresholdedIouSimilarity ThresholdedIouSimilarity { + get { return regionSimilarityCase_ == RegionSimilarityOneofCase.ThresholdedIouSimilarity ? (global::Tensorflow.Models.ObjectDetection.Protos.ThresholdedIouSimilarity) regionSimilarity_ : null; } + set { + regionSimilarity_ = value; + regionSimilarityCase_ = value == null ? RegionSimilarityOneofCase.None : RegionSimilarityOneofCase.ThresholdedIouSimilarity; + } + } + + private object regionSimilarity_; + /// Enum of possible cases for the "region_similarity" oneof. + public enum RegionSimilarityOneofCase { + None = 0, + NegSqDistSimilarity = 1, + IouSimilarity = 2, + IoaSimilarity = 3, + ThresholdedIouSimilarity = 4, + } + private RegionSimilarityOneofCase regionSimilarityCase_ = RegionSimilarityOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RegionSimilarityOneofCase RegionSimilarityCase { + get { return regionSimilarityCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearRegionSimilarity() { + regionSimilarityCase_ = RegionSimilarityOneofCase.None; + regionSimilarity_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RegionSimilarityCalculator); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RegionSimilarityCalculator other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(NegSqDistSimilarity, other.NegSqDistSimilarity)) return false; + if (!object.Equals(IouSimilarity, other.IouSimilarity)) return false; + if (!object.Equals(IoaSimilarity, other.IoaSimilarity)) return false; + if (!object.Equals(ThresholdedIouSimilarity, other.ThresholdedIouSimilarity)) return false; + if (RegionSimilarityCase != other.RegionSimilarityCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (regionSimilarityCase_ == RegionSimilarityOneofCase.NegSqDistSimilarity) hash ^= NegSqDistSimilarity.GetHashCode(); + if (regionSimilarityCase_ == RegionSimilarityOneofCase.IouSimilarity) hash ^= IouSimilarity.GetHashCode(); + if (regionSimilarityCase_ == RegionSimilarityOneofCase.IoaSimilarity) hash ^= IoaSimilarity.GetHashCode(); + if (regionSimilarityCase_ == RegionSimilarityOneofCase.ThresholdedIouSimilarity) hash ^= ThresholdedIouSimilarity.GetHashCode(); + hash ^= (int) regionSimilarityCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (regionSimilarityCase_ == RegionSimilarityOneofCase.NegSqDistSimilarity) { + output.WriteRawTag(10); + output.WriteMessage(NegSqDistSimilarity); + } + if (regionSimilarityCase_ == RegionSimilarityOneofCase.IouSimilarity) { + output.WriteRawTag(18); + 
output.WriteMessage(IouSimilarity); + } + if (regionSimilarityCase_ == RegionSimilarityOneofCase.IoaSimilarity) { + output.WriteRawTag(26); + output.WriteMessage(IoaSimilarity); + } + if (regionSimilarityCase_ == RegionSimilarityOneofCase.ThresholdedIouSimilarity) { + output.WriteRawTag(34); + output.WriteMessage(ThresholdedIouSimilarity); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (regionSimilarityCase_ == RegionSimilarityOneofCase.NegSqDistSimilarity) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(NegSqDistSimilarity); + } + if (regionSimilarityCase_ == RegionSimilarityOneofCase.IouSimilarity) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(IouSimilarity); + } + if (regionSimilarityCase_ == RegionSimilarityOneofCase.IoaSimilarity) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(IoaSimilarity); + } + if (regionSimilarityCase_ == RegionSimilarityOneofCase.ThresholdedIouSimilarity) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ThresholdedIouSimilarity); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RegionSimilarityCalculator other) { + if (other == null) { + return; + } + switch (other.RegionSimilarityCase) { + case RegionSimilarityOneofCase.NegSqDistSimilarity: + if (NegSqDistSimilarity == null) { + NegSqDistSimilarity = new global::Tensorflow.Models.ObjectDetection.Protos.NegSqDistSimilarity(); + } + NegSqDistSimilarity.MergeFrom(other.NegSqDistSimilarity); + break; + case RegionSimilarityOneofCase.IouSimilarity: + if (IouSimilarity == null) { + IouSimilarity = new global::Tensorflow.Models.ObjectDetection.Protos.IouSimilarity(); + } + IouSimilarity.MergeFrom(other.IouSimilarity); + break; + case RegionSimilarityOneofCase.IoaSimilarity: + if (IoaSimilarity == null) { + IoaSimilarity = new global::Tensorflow.Models.ObjectDetection.Protos.IoaSimilarity(); + } + IoaSimilarity.MergeFrom(other.IoaSimilarity); + break; + case RegionSimilarityOneofCase.ThresholdedIouSimilarity: + if (ThresholdedIouSimilarity == null) { + ThresholdedIouSimilarity = new global::Tensorflow.Models.ObjectDetection.Protos.ThresholdedIouSimilarity(); + } + ThresholdedIouSimilarity.MergeFrom(other.ThresholdedIouSimilarity); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.Models.ObjectDetection.Protos.NegSqDistSimilarity subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.NegSqDistSimilarity(); + if (regionSimilarityCase_ == RegionSimilarityOneofCase.NegSqDistSimilarity) { + subBuilder.MergeFrom(NegSqDistSimilarity); + } + input.ReadMessage(subBuilder); + NegSqDistSimilarity = subBuilder; + break; + } + case 18: { + global::Tensorflow.Models.ObjectDetection.Protos.IouSimilarity subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.IouSimilarity(); + if (regionSimilarityCase_ == RegionSimilarityOneofCase.IouSimilarity) { + subBuilder.MergeFrom(IouSimilarity); + } + input.ReadMessage(subBuilder); + 
IouSimilarity = subBuilder; + break; + } + case 26: { + global::Tensorflow.Models.ObjectDetection.Protos.IoaSimilarity subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.IoaSimilarity(); + if (regionSimilarityCase_ == RegionSimilarityOneofCase.IoaSimilarity) { + subBuilder.MergeFrom(IoaSimilarity); + } + input.ReadMessage(subBuilder); + IoaSimilarity = subBuilder; + break; + } + case 34: { + global::Tensorflow.Models.ObjectDetection.Protos.ThresholdedIouSimilarity subBuilder = new global::Tensorflow.Models.ObjectDetection.Protos.ThresholdedIouSimilarity(); + if (regionSimilarityCase_ == RegionSimilarityOneofCase.ThresholdedIouSimilarity) { + subBuilder.MergeFrom(ThresholdedIouSimilarity); + } + input.ReadMessage(subBuilder); + ThresholdedIouSimilarity = subBuilder; + break; + } + } + } + } + + } + + /// + /// Configuration for negative squared distance similarity calculator. + /// + public sealed partial class NegSqDistSimilarity : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new NegSqDistSimilarity()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.RegionSimilarityCalculatorReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NegSqDistSimilarity() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NegSqDistSimilarity(NegSqDistSimilarity other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public NegSqDistSimilarity Clone() { + return new NegSqDistSimilarity(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as NegSqDistSimilarity); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(NegSqDistSimilarity other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(NegSqDistSimilarity other) { + if (other == null) { + return; + } + _unknownFields = 
pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + } + + } + + /// + /// Configuration for intersection-over-union (IOU) similarity calculator. + /// + public sealed partial class IouSimilarity : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new IouSimilarity()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.RegionSimilarityCalculatorReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public IouSimilarity() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public IouSimilarity(IouSimilarity other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public IouSimilarity Clone() { + return new IouSimilarity(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as IouSimilarity); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(IouSimilarity other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(IouSimilarity other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + } + + } + + /// + /// Configuration for intersection-over-area (IOA) similarity calculator. 
+ /// + public sealed partial class IoaSimilarity : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new IoaSimilarity()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.RegionSimilarityCalculatorReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public IoaSimilarity() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public IoaSimilarity(IoaSimilarity other) : this() { + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public IoaSimilarity Clone() { + return new IoaSimilarity(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as IoaSimilarity); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(IoaSimilarity other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(IoaSimilarity other) { + if (other == null) { + return; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + } + } + } + + } + + /// + /// Configuration for thresholded-intersection-over-union similarity calculator. 
+ /// + public sealed partial class ThresholdedIouSimilarity : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ThresholdedIouSimilarity()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.RegionSimilarityCalculatorReflection.Descriptor.MessageTypes[4]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ThresholdedIouSimilarity() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ThresholdedIouSimilarity(ThresholdedIouSimilarity other) : this() { + iouThreshold_ = other.iouThreshold_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ThresholdedIouSimilarity Clone() { + return new ThresholdedIouSimilarity(this); + } + + /// Field number for the "iou_threshold" field. + public const int IouThresholdFieldNumber = 1; + private float iouThreshold_; + /// + /// IOU threshold used for filtering scores. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float IouThreshold { + get { return iouThreshold_; } + set { + iouThreshold_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ThresholdedIouSimilarity); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ThresholdedIouSimilarity other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(IouThreshold, other.IouThreshold)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (IouThreshold != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(IouThreshold); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (IouThreshold != 0F) { + output.WriteRawTag(13); + output.WriteFloat(IouThreshold); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (IouThreshold != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ThresholdedIouSimilarity other) { + if (other == null) { + return; + } + if (other.IouThreshold != 0F) { + IouThreshold = other.IouThreshold; + } + _unknownFields = 
pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + IouThreshold = input.ReadFloat(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/SquareBoxCoder.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/SquareBoxCoder.cs new file mode 100644 index 00000000..a1f3d794 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/SquareBoxCoder.cs @@ -0,0 +1,240 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: object_detection/protos/square_box_coder.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/square_box_coder.proto + public static partial class SquareBoxCoderReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/square_box_coder.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static SquareBoxCoderReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Ci5vYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9zcXVhcmVfYm94X2NvZGVyLnBy", + "b3RvEhdvYmplY3RfZGV0ZWN0aW9uLnByb3RvcyJICg5TcXVhcmVCb3hDb2Rl", + "chIPCgd5X3NjYWxlGAEgASgCEg8KB3hfc2NhbGUYAiABKAISFAoMbGVuZ3Ro", + "X3NjYWxlGAMgASgCYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.SquareBoxCoder), global::Tensorflow.Models.ObjectDetection.Protos.SquareBoxCoder.Parser, new[]{ "YScale", "XScale", "LengthScale" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration proto for SquareBoxCoder. See + /// box_coders/square_box_coder.py for details. 
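// --- Illustrative usage sketch (editor's addition, not part of the generated file) ---
// A minimal sketch of how the region_similarity oneof generated above behaves:
// assigning one of the similarity sub-messages selects the corresponding case, and
// assigning a different one replaces it. Type and property names come from the
// generated code; the helper class and method here are hypothetical.
using Tensorflow.Models.ObjectDetection.Protos;

static class RegionSimilaritySketch
{
    static RegionSimilarityCalculator BuildThresholdedIou(float threshold)
    {
        var calc = new RegionSimilarityCalculator
        {
            // Select plain IOU similarity first...
            IouSimilarity = new IouSimilarity()
        };
        // ...then switch to thresholded IOU; this clears the previously set case.
        calc.ThresholdedIouSimilarity = new ThresholdedIouSimilarity { IouThreshold = threshold };
        // calc.RegionSimilarityCase is now ThresholdedIouSimilarity,
        // and calc.IouSimilarity returns null.
        return calc;
    }
}
// --- End of illustrative sketch ---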
+ /// + public sealed partial class SquareBoxCoder : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SquareBoxCoder()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.SquareBoxCoderReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SquareBoxCoder() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SquareBoxCoder(SquareBoxCoder other) : this() { + yScale_ = other.yScale_; + xScale_ = other.xScale_; + lengthScale_ = other.lengthScale_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SquareBoxCoder Clone() { + return new SquareBoxCoder(this); + } + + /// Field number for the "y_scale" field. + public const int YScaleFieldNumber = 1; + private float yScale_; + /// + /// Scale factor for anchor encoded box center. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float YScale { + get { return yScale_; } + set { + yScale_ = value; + } + } + + /// Field number for the "x_scale" field. + public const int XScaleFieldNumber = 2; + private float xScale_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float XScale { + get { return xScale_; } + set { + xScale_ = value; + } + } + + /// Field number for the "length_scale" field. + public const int LengthScaleFieldNumber = 3; + private float lengthScale_; + /// + /// Scale factor for anchor encoded box length. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float LengthScale { + get { return lengthScale_; } + set { + lengthScale_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SquareBoxCoder); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SquareBoxCoder other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(YScale, other.YScale)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(XScale, other.XScale)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(LengthScale, other.LengthScale)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (YScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(YScale); + if (XScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(XScale); + if (LengthScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(LengthScale); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (YScale != 0F) { + output.WriteRawTag(13); + output.WriteFloat(YScale); + } + if (XScale != 0F) { + output.WriteRawTag(21); + output.WriteFloat(XScale); + } + if (LengthScale != 0F) { + output.WriteRawTag(29); + output.WriteFloat(LengthScale); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (YScale != 0F) { + size += 1 + 4; + } + if (XScale != 0F) { + size += 1 + 4; + } + if (LengthScale != 0F) { + size += 1 + 4; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SquareBoxCoder other) { + if (other == null) { + return; + } + if (other.YScale != 0F) { + YScale = other.YScale; + } + if (other.XScale != 0F) { + XScale = other.XScale; + } + if (other.LengthScale != 0F) { + LengthScale = other.LengthScale; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 13: { + YScale = input.ReadFloat(); + break; + } + case 21: { + XScale = input.ReadFloat(); + break; + } + case 29: { + LengthScale = input.ReadFloat(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/Ssd.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/Ssd.cs new file mode 100644 index 
00000000..1a43ba6c --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/Ssd.cs @@ -0,0 +1,2028 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: object_detection/protos/ssd.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/ssd.proto + public static partial class SsdReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/ssd.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static SsdReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiFvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9zc2QucHJvdG8SF29iamVjdF9k", + "ZXRlY3Rpb24ucHJvdG9zGi5vYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9hbmNo", + "b3JfZ2VuZXJhdG9yLnByb3RvGidvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9i", + "b3hfY29kZXIucHJvdG8aK29iamVjdF9kZXRlY3Rpb24vcHJvdG9zL2JveF9w", + "cmVkaWN0b3IucHJvdG8aKW9iamVjdF9kZXRlY3Rpb24vcHJvdG9zL2h5cGVy", + "cGFyYW1zLnByb3RvGitvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9pbWFnZV9y", + "ZXNpemVyLnByb3RvGiRvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9sb3NzZXMu", + "cHJvdG8aJW9iamVjdF9kZXRlY3Rpb24vcHJvdG9zL21hdGNoZXIucHJvdG8a", + "LW9iamVjdF9kZXRlY3Rpb24vcHJvdG9zL3Bvc3RfcHJvY2Vzc2luZy5wcm90", + "bxo6b2JqZWN0X2RldGVjdGlvbi9wcm90b3MvcmVnaW9uX3NpbWlsYXJpdHlf", + "Y2FsY3VsYXRvci5wcm90byLDCgoDU3NkEhMKC251bV9jbGFzc2VzGAEgASgF", + "EjwKDWltYWdlX3Jlc2l6ZXIYAiABKAsyJS5vYmplY3RfZGV0ZWN0aW9uLnBy", + "b3Rvcy5JbWFnZVJlc2l6ZXISRwoRZmVhdHVyZV9leHRyYWN0b3IYAyABKAsy", + "LC5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5Tc2RGZWF0dXJlRXh0cmFjdG9y", + "EjQKCWJveF9jb2RlchgEIAEoCzIhLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9z", + "LkJveENvZGVyEjEKB21hdGNoZXIYBSABKAsyIC5vYmplY3RfZGV0ZWN0aW9u", + "LnByb3Rvcy5NYXRjaGVyElIKFXNpbWlsYXJpdHlfY2FsY3VsYXRvchgGIAEo", + "CzIzLm9iamVjdF9kZXRlY3Rpb24ucHJvdG9zLlJlZ2lvblNpbWlsYXJpdHlD", + "YWxjdWxhdG9yEiIKGmVuY29kZV9iYWNrZ3JvdW5kX2FzX3plcm9zGAwgASgI", + "Eh0KFW5lZ2F0aXZlX2NsYXNzX3dlaWdodBgNIAEoAhI8Cg1ib3hfcHJlZGlj", + "dG9yGAcgASgLMiUub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuQm94UHJlZGlj", + "dG9yEkIKEGFuY2hvcl9nZW5lcmF0b3IYCCABKAsyKC5vYmplY3RfZGV0ZWN0", + "aW9uLnByb3Rvcy5BbmNob3JHZW5lcmF0b3ISQAoPcG9zdF9wcm9jZXNzaW5n", + "GAkgASgLMicub2JqZWN0X2RldGVjdGlvbi5wcm90b3MuUG9zdFByb2Nlc3Np", + "bmcSJQodbm9ybWFsaXplX2xvc3NfYnlfbnVtX21hdGNoZXMYCiABKAgSJgoe", + "bm9ybWFsaXplX2xvY19sb3NzX2J5X2NvZGVzaXplGA4gASgIEisKBGxvc3MY", + "CyABKAsyHS5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5Mb3NzEhgKEGZyZWV6", + "ZV9iYXRjaG5vcm0YECABKAgSIAoYaW5wbGFjZV9iYXRjaG5vcm1fdXBkYXRl", + "GA8gASgIEhwKFGFkZF9iYWNrZ3JvdW5kX2NsYXNzGBUgASgIEiEKGWV4cGxp", + "Y2l0X2JhY2tncm91bmRfY2xhc3MYGCABKAgSIgoadXNlX2NvbmZpZGVuY2Vz", + "X2FzX3RhcmdldHMYFiABKAgSHwoXaW1wbGljaXRfZXhhbXBsZV93ZWlnaHQY", + "FyABKAISPwoQbWFza19oZWFkX2NvbmZpZxgZIAEoCzIlLm9iamVjdF9kZXRl", + "Y3Rpb24ucHJvdG9zLlNzZC5NYXNrSGVhZBrcAgoITWFza0hlYWQSEwoLbWFz", + "a19oZWlnaHQYASABKAUSEgoKbWFza193aWR0aBgCIAEoBRIgChhtYXNrc19h", + "cmVfY2xhc3NfYWdub3N0aWMYAyABKAgSIgoabWFza19wcmVkaWN0aW9uX2Nv", + "bnZfZGVwdGgYBCABKAUSJwofbWFza19wcmVkaWN0aW9uX251bV9jb252X2xh", + "eWVycxgFIAEoBRIkChxjb252b2x2ZV90aGVuX3Vwc2FtcGxlX21hc2tzGAYg", + "ASgIEhgKEG1hc2tfbG9zc193ZWlnaHQYByABKAISHQoVbWFza19sb3NzX3Nh", + 
"bXBsZV9zaXplGAggASgFEj4KEGNvbnZfaHlwZXJwYXJhbXMYCSABKAsyJC5v", + "YmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5IeXBlcnBhcmFtcxIZChFpbml0aWFs", + "X2Nyb3Bfc2l6ZRgKIAEoBSKaAwoTU3NkRmVhdHVyZUV4dHJhY3RvchIMCgR0", + "eXBlGAEgASgJEhgKEGRlcHRoX211bHRpcGxpZXIYAiABKAISEQoJbWluX2Rl", + "cHRoGAMgASgFEj4KEGNvbnZfaHlwZXJwYXJhbXMYBCABKAsyJC5vYmplY3Rf", + "ZGV0ZWN0aW9uLnByb3Rvcy5IeXBlcnBhcmFtcxIzCitvdmVycmlkZV9iYXNl", + "X2ZlYXR1cmVfZXh0cmFjdG9yX2h5cGVycGFyYW1zGAkgASgIEhcKD3BhZF90", + "b19tdWx0aXBsZRgFIAEoBRIcChR1c2VfZXhwbGljaXRfcGFkZGluZxgHIAEo", + "CBIVCg11c2VfZGVwdGh3aXNlGAggASgIEjwKA2ZwbhgKIAEoCzIvLm9iamVj", + "dF9kZXRlY3Rpb24ucHJvdG9zLkZlYXR1cmVQeXJhbWlkTmV0d29ya3MSLQol", + "cmVwbGFjZV9wcmVwcm9jZXNzb3Jfd2l0aF9wbGFjZWhvbGRlchgLIAEoCBIS", + "CgpudW1fbGF5ZXJzGAwgASgFSgQIBhAHIl4KFkZlYXR1cmVQeXJhbWlkTmV0", + "d29ya3MSEQoJbWluX2xldmVsGAEgASgFEhEKCW1heF9sZXZlbBgCIAEoBRIe", + "ChZhZGRpdGlvbmFsX2xheWVyX2RlcHRoGAMgASgFYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.Models.ObjectDetection.Protos.AnchorGeneratorReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.BoxCoderReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.BoxPredictorReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.HyperparamsReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.ImageResizerReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.LossesReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.MatcherReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.PostProcessingReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.RegionSimilarityCalculatorReflection.Descriptor, }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.Ssd), global::Tensorflow.Models.ObjectDetection.Protos.Ssd.Parser, new[]{ "NumClasses", "ImageResizer", "FeatureExtractor", "BoxCoder", "Matcher", "SimilarityCalculator", "EncodeBackgroundAsZeros", "NegativeClassWeight", "BoxPredictor", "AnchorGenerator", "PostProcessing", "NormalizeLossByNumMatches", "NormalizeLocLossByCodesize", "Loss", "FreezeBatchnorm", "InplaceBatchnormUpdate", "AddBackgroundClass", "ExplicitBackgroundClass", "UseConfidencesAsTargets", "ImplicitExampleWeight", "MaskHeadConfig" }, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.Ssd.Types.MaskHead), global::Tensorflow.Models.ObjectDetection.Protos.Ssd.Types.MaskHead.Parser, new[]{ "MaskHeight", "MaskWidth", "MasksAreClassAgnostic", "MaskPredictionConvDepth", "MaskPredictionNumConvLayers", "ConvolveThenUpsampleMasks", "MaskLossWeight", "MaskLossSampleSize", "ConvHyperparams", "InitialCropSize" }, null, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.SsdFeatureExtractor), global::Tensorflow.Models.ObjectDetection.Protos.SsdFeatureExtractor.Parser, new[]{ "Type", "DepthMultiplier", "MinDepth", "ConvHyperparams", "OverrideBaseFeatureExtractorHyperparams", "PadToMultiple", "UseExplicitPadding", "UseDepthwise", "Fpn", "ReplacePreprocessorWithPlaceholder", "NumLayers" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.FeaturePyramidNetworks), global::Tensorflow.Models.ObjectDetection.Protos.FeaturePyramidNetworks.Parser, new[]{ 
"MinLevel", "MaxLevel", "AdditionalLayerDepth" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration for Single Shot Detection (SSD) models. + /// Next id: 26 + /// + public sealed partial class Ssd : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new Ssd()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.SsdReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Ssd() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Ssd(Ssd other) : this() { + numClasses_ = other.numClasses_; + imageResizer_ = other.imageResizer_ != null ? other.imageResizer_.Clone() : null; + featureExtractor_ = other.featureExtractor_ != null ? other.featureExtractor_.Clone() : null; + boxCoder_ = other.boxCoder_ != null ? other.boxCoder_.Clone() : null; + matcher_ = other.matcher_ != null ? other.matcher_.Clone() : null; + similarityCalculator_ = other.similarityCalculator_ != null ? other.similarityCalculator_.Clone() : null; + encodeBackgroundAsZeros_ = other.encodeBackgroundAsZeros_; + negativeClassWeight_ = other.negativeClassWeight_; + boxPredictor_ = other.boxPredictor_ != null ? other.boxPredictor_.Clone() : null; + anchorGenerator_ = other.anchorGenerator_ != null ? other.anchorGenerator_.Clone() : null; + postProcessing_ = other.postProcessing_ != null ? other.postProcessing_.Clone() : null; + normalizeLossByNumMatches_ = other.normalizeLossByNumMatches_; + normalizeLocLossByCodesize_ = other.normalizeLocLossByCodesize_; + loss_ = other.loss_ != null ? other.loss_.Clone() : null; + freezeBatchnorm_ = other.freezeBatchnorm_; + inplaceBatchnormUpdate_ = other.inplaceBatchnormUpdate_; + addBackgroundClass_ = other.addBackgroundClass_; + explicitBackgroundClass_ = other.explicitBackgroundClass_; + useConfidencesAsTargets_ = other.useConfidencesAsTargets_; + implicitExampleWeight_ = other.implicitExampleWeight_; + maskHeadConfig_ = other.maskHeadConfig_ != null ? other.maskHeadConfig_.Clone() : null; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public Ssd Clone() { + return new Ssd(this); + } + + /// Field number for the "num_classes" field. + public const int NumClassesFieldNumber = 1; + private int numClasses_; + /// + /// Number of classes to predict. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumClasses { + get { return numClasses_; } + set { + numClasses_ = value; + } + } + + /// Field number for the "image_resizer" field. + public const int ImageResizerFieldNumber = 2; + private global::Tensorflow.Models.ObjectDetection.Protos.ImageResizer imageResizer_; + /// + /// Image resizer for preprocessing the input image. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.ImageResizer ImageResizer { + get { return imageResizer_; } + set { + imageResizer_ = value; + } + } + + /// Field number for the "feature_extractor" field. + public const int FeatureExtractorFieldNumber = 3; + private global::Tensorflow.Models.ObjectDetection.Protos.SsdFeatureExtractor featureExtractor_; + /// + /// Feature extractor config. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.SsdFeatureExtractor FeatureExtractor { + get { return featureExtractor_; } + set { + featureExtractor_ = value; + } + } + + /// Field number for the "box_coder" field. + public const int BoxCoderFieldNumber = 4; + private global::Tensorflow.Models.ObjectDetection.Protos.BoxCoder boxCoder_; + /// + /// Box coder to encode the boxes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.BoxCoder BoxCoder { + get { return boxCoder_; } + set { + boxCoder_ = value; + } + } + + /// Field number for the "matcher" field. + public const int MatcherFieldNumber = 5; + private global::Tensorflow.Models.ObjectDetection.Protos.Matcher matcher_; + /// + /// Matcher to match groundtruth with anchors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Matcher Matcher { + get { return matcher_; } + set { + matcher_ = value; + } + } + + /// Field number for the "similarity_calculator" field. + public const int SimilarityCalculatorFieldNumber = 6; + private global::Tensorflow.Models.ObjectDetection.Protos.RegionSimilarityCalculator similarityCalculator_; + /// + /// Region similarity calculator to compute similarity of boxes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.RegionSimilarityCalculator SimilarityCalculator { + get { return similarityCalculator_; } + set { + similarityCalculator_ = value; + } + } + + /// Field number for the "encode_background_as_zeros" field. + public const int EncodeBackgroundAsZerosFieldNumber = 12; + private bool encodeBackgroundAsZeros_; + /// + /// Whether background targets are to be encoded as an all + /// zeros vector or a one-hot vector (where background is the 0th class). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool EncodeBackgroundAsZeros { + get { return encodeBackgroundAsZeros_; } + set { + encodeBackgroundAsZeros_ = value; + } + } + + /// Field number for the "negative_class_weight" field. + public const int NegativeClassWeightFieldNumber = 13; + private float negativeClassWeight_; + /// + /// classification weight to be associated to negative + /// anchors (default: 1.0). The weight must be in [0., 1.]. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float NegativeClassWeight { + get { return negativeClassWeight_; } + set { + negativeClassWeight_ = value; + } + } + + /// Field number for the "box_predictor" field. + public const int BoxPredictorFieldNumber = 7; + private global::Tensorflow.Models.ObjectDetection.Protos.BoxPredictor boxPredictor_; + /// + /// Box predictor to attach to the features. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.BoxPredictor BoxPredictor { + get { return boxPredictor_; } + set { + boxPredictor_ = value; + } + } + + /// Field number for the "anchor_generator" field. + public const int AnchorGeneratorFieldNumber = 8; + private global::Tensorflow.Models.ObjectDetection.Protos.AnchorGenerator anchorGenerator_; + /// + /// Anchor generator to compute anchors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.AnchorGenerator AnchorGenerator { + get { return anchorGenerator_; } + set { + anchorGenerator_ = value; + } + } + + /// Field number for the "post_processing" field. + public const int PostProcessingFieldNumber = 9; + private global::Tensorflow.Models.ObjectDetection.Protos.PostProcessing postProcessing_; + /// + /// Post processing to apply on the predictions. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.PostProcessing PostProcessing { + get { return postProcessing_; } + set { + postProcessing_ = value; + } + } + + /// Field number for the "normalize_loss_by_num_matches" field. + public const int NormalizeLossByNumMatchesFieldNumber = 10; + private bool normalizeLossByNumMatches_; + /// + /// Whether to normalize the loss by number of groundtruth boxes that match to + /// the anchors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool NormalizeLossByNumMatches { + get { return normalizeLossByNumMatches_; } + set { + normalizeLossByNumMatches_ = value; + } + } + + /// Field number for the "normalize_loc_loss_by_codesize" field. + public const int NormalizeLocLossByCodesizeFieldNumber = 14; + private bool normalizeLocLossByCodesize_; + /// + /// Whether to normalize the localization loss by the code size of the box + /// encodings. This is applied along with other normalization factors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool NormalizeLocLossByCodesize { + get { return normalizeLocLossByCodesize_; } + set { + normalizeLocLossByCodesize_ = value; + } + } + + /// Field number for the "loss" field. + public const int LossFieldNumber = 11; + private global::Tensorflow.Models.ObjectDetection.Protos.Loss loss_; + /// + /// Loss configuration for training. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Loss Loss { + get { return loss_; } + set { + loss_ = value; + } + } + + /// Field number for the "freeze_batchnorm" field. + public const int FreezeBatchnormFieldNumber = 16; + private bool freezeBatchnorm_; + /// + /// Whether to update batch norm parameters during training or not. + /// When training with a relative small batch size (e.g. 1), it is + /// desirable to disable batch norm update and use pretrained batch norm + /// params. + /// + /// Note: Some feature extractors are used with canned arg_scopes + /// (e.g resnet arg scopes). In these cases training behavior of batch norm + /// variables may depend on both values of `batch_norm_trainable` and + /// `is_training`. + /// + /// When canned arg_scopes are used with feature extractors `conv_hyperparams` + /// will apply only to the additional layers that are added and are outside the + /// canned arg_scope. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool FreezeBatchnorm { + get { return freezeBatchnorm_; } + set { + freezeBatchnorm_ = value; + } + } + + /// Field number for the "inplace_batchnorm_update" field. + public const int InplaceBatchnormUpdateFieldNumber = 15; + private bool inplaceBatchnormUpdate_; + /// + /// Whether to update batch_norm inplace during training. This is required + /// for batch norm to work correctly on TPUs. When this is false, user must add + /// a control dependency on tf.GraphKeys.UPDATE_OPS for train/loss op in order + /// to update the batch norm moving average parameters. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool InplaceBatchnormUpdate { + get { return inplaceBatchnormUpdate_; } + set { + inplaceBatchnormUpdate_ = value; + } + } + + /// Field number for the "add_background_class" field. + public const int AddBackgroundClassFieldNumber = 21; + private bool addBackgroundClass_; + /// + /// Whether to add an implicit background class to one-hot encodings of + /// groundtruth labels. Set to false if training a single + /// class model or using an explicit background class. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool AddBackgroundClass { + get { return addBackgroundClass_; } + set { + addBackgroundClass_ = value; + } + } + + /// Field number for the "explicit_background_class" field. + public const int ExplicitBackgroundClassFieldNumber = 24; + private bool explicitBackgroundClass_; + /// + /// Whether to use an explicit background class. Set to true if using + /// groundtruth labels with an explicit background class, as in multiclass + /// scores. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ExplicitBackgroundClass { + get { return explicitBackgroundClass_; } + set { + explicitBackgroundClass_ = value; + } + } + + /// Field number for the "use_confidences_as_targets" field. + public const int UseConfidencesAsTargetsFieldNumber = 22; + private bool useConfidencesAsTargets_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseConfidencesAsTargets { + get { return useConfidencesAsTargets_; } + set { + useConfidencesAsTargets_ = value; + } + } + + /// Field number for the "implicit_example_weight" field. + public const int ImplicitExampleWeightFieldNumber = 23; + private float implicitExampleWeight_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float ImplicitExampleWeight { + get { return implicitExampleWeight_; } + set { + implicitExampleWeight_ = value; + } + } + + /// Field number for the "mask_head_config" field. + public const int MaskHeadConfigFieldNumber = 25; + private global::Tensorflow.Models.ObjectDetection.Protos.Ssd.Types.MaskHead maskHeadConfig_; + /// + /// Configs for mask head. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Ssd.Types.MaskHead MaskHeadConfig { + get { return maskHeadConfig_; } + set { + maskHeadConfig_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as Ssd); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(Ssd other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (NumClasses != other.NumClasses) return false; + if (!object.Equals(ImageResizer, other.ImageResizer)) return false; + if (!object.Equals(FeatureExtractor, other.FeatureExtractor)) return false; + if (!object.Equals(BoxCoder, other.BoxCoder)) return false; + if (!object.Equals(Matcher, other.Matcher)) return false; + if (!object.Equals(SimilarityCalculator, other.SimilarityCalculator)) return false; + if (EncodeBackgroundAsZeros != other.EncodeBackgroundAsZeros) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(NegativeClassWeight, other.NegativeClassWeight)) return false; + if (!object.Equals(BoxPredictor, other.BoxPredictor)) return false; + if (!object.Equals(AnchorGenerator, other.AnchorGenerator)) return false; + if (!object.Equals(PostProcessing, other.PostProcessing)) return false; + if (NormalizeLossByNumMatches != other.NormalizeLossByNumMatches) return false; + if (NormalizeLocLossByCodesize != other.NormalizeLocLossByCodesize) return false; + if (!object.Equals(Loss, other.Loss)) return false; + if (FreezeBatchnorm != other.FreezeBatchnorm) return false; + if (InplaceBatchnormUpdate != other.InplaceBatchnormUpdate) return false; + if (AddBackgroundClass != other.AddBackgroundClass) return false; + if (ExplicitBackgroundClass != other.ExplicitBackgroundClass) return false; + if (UseConfidencesAsTargets != other.UseConfidencesAsTargets) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(ImplicitExampleWeight, other.ImplicitExampleWeight)) return false; + if (!object.Equals(MaskHeadConfig, other.MaskHeadConfig)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (NumClasses != 0) hash ^= NumClasses.GetHashCode(); + if (imageResizer_ != null) hash ^= ImageResizer.GetHashCode(); + if (featureExtractor_ != null) hash ^= FeatureExtractor.GetHashCode(); + if (boxCoder_ != null) hash ^= BoxCoder.GetHashCode(); + if (matcher_ != null) hash ^= Matcher.GetHashCode(); + if (similarityCalculator_ != null) hash ^= SimilarityCalculator.GetHashCode(); + if (EncodeBackgroundAsZeros != false) hash ^= EncodeBackgroundAsZeros.GetHashCode(); + if (NegativeClassWeight != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(NegativeClassWeight); + if (boxPredictor_ != null) hash ^= BoxPredictor.GetHashCode(); + if (anchorGenerator_ != null) hash ^= AnchorGenerator.GetHashCode(); + if (postProcessing_ != null) hash ^= PostProcessing.GetHashCode(); + if (NormalizeLossByNumMatches != false) hash ^= NormalizeLossByNumMatches.GetHashCode(); + if (NormalizeLocLossByCodesize != false) hash ^= NormalizeLocLossByCodesize.GetHashCode(); + if (loss_ != null) hash ^= Loss.GetHashCode(); + if (FreezeBatchnorm != false) hash ^= FreezeBatchnorm.GetHashCode(); + if 
(InplaceBatchnormUpdate != false) hash ^= InplaceBatchnormUpdate.GetHashCode(); + if (AddBackgroundClass != false) hash ^= AddBackgroundClass.GetHashCode(); + if (ExplicitBackgroundClass != false) hash ^= ExplicitBackgroundClass.GetHashCode(); + if (UseConfidencesAsTargets != false) hash ^= UseConfidencesAsTargets.GetHashCode(); + if (ImplicitExampleWeight != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(ImplicitExampleWeight); + if (maskHeadConfig_ != null) hash ^= MaskHeadConfig.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (NumClasses != 0) { + output.WriteRawTag(8); + output.WriteInt32(NumClasses); + } + if (imageResizer_ != null) { + output.WriteRawTag(18); + output.WriteMessage(ImageResizer); + } + if (featureExtractor_ != null) { + output.WriteRawTag(26); + output.WriteMessage(FeatureExtractor); + } + if (boxCoder_ != null) { + output.WriteRawTag(34); + output.WriteMessage(BoxCoder); + } + if (matcher_ != null) { + output.WriteRawTag(42); + output.WriteMessage(Matcher); + } + if (similarityCalculator_ != null) { + output.WriteRawTag(50); + output.WriteMessage(SimilarityCalculator); + } + if (boxPredictor_ != null) { + output.WriteRawTag(58); + output.WriteMessage(BoxPredictor); + } + if (anchorGenerator_ != null) { + output.WriteRawTag(66); + output.WriteMessage(AnchorGenerator); + } + if (postProcessing_ != null) { + output.WriteRawTag(74); + output.WriteMessage(PostProcessing); + } + if (NormalizeLossByNumMatches != false) { + output.WriteRawTag(80); + output.WriteBool(NormalizeLossByNumMatches); + } + if (loss_ != null) { + output.WriteRawTag(90); + output.WriteMessage(Loss); + } + if (EncodeBackgroundAsZeros != false) { + output.WriteRawTag(96); + output.WriteBool(EncodeBackgroundAsZeros); + } + if (NegativeClassWeight != 0F) { + output.WriteRawTag(109); + output.WriteFloat(NegativeClassWeight); + } + if (NormalizeLocLossByCodesize != false) { + output.WriteRawTag(112); + output.WriteBool(NormalizeLocLossByCodesize); + } + if (InplaceBatchnormUpdate != false) { + output.WriteRawTag(120); + output.WriteBool(InplaceBatchnormUpdate); + } + if (FreezeBatchnorm != false) { + output.WriteRawTag(128, 1); + output.WriteBool(FreezeBatchnorm); + } + if (AddBackgroundClass != false) { + output.WriteRawTag(168, 1); + output.WriteBool(AddBackgroundClass); + } + if (UseConfidencesAsTargets != false) { + output.WriteRawTag(176, 1); + output.WriteBool(UseConfidencesAsTargets); + } + if (ImplicitExampleWeight != 0F) { + output.WriteRawTag(189, 1); + output.WriteFloat(ImplicitExampleWeight); + } + if (ExplicitBackgroundClass != false) { + output.WriteRawTag(192, 1); + output.WriteBool(ExplicitBackgroundClass); + } + if (maskHeadConfig_ != null) { + output.WriteRawTag(202, 1); + output.WriteMessage(MaskHeadConfig); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (NumClasses != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumClasses); + } + if (imageResizer_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ImageResizer); + } + if (featureExtractor_ != 
null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(FeatureExtractor); + } + if (boxCoder_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(BoxCoder); + } + if (matcher_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Matcher); + } + if (similarityCalculator_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(SimilarityCalculator); + } + if (EncodeBackgroundAsZeros != false) { + size += 1 + 1; + } + if (NegativeClassWeight != 0F) { + size += 1 + 4; + } + if (boxPredictor_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(BoxPredictor); + } + if (anchorGenerator_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(AnchorGenerator); + } + if (postProcessing_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(PostProcessing); + } + if (NormalizeLossByNumMatches != false) { + size += 1 + 1; + } + if (NormalizeLocLossByCodesize != false) { + size += 1 + 1; + } + if (loss_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Loss); + } + if (FreezeBatchnorm != false) { + size += 2 + 1; + } + if (InplaceBatchnormUpdate != false) { + size += 1 + 1; + } + if (AddBackgroundClass != false) { + size += 2 + 1; + } + if (ExplicitBackgroundClass != false) { + size += 2 + 1; + } + if (UseConfidencesAsTargets != false) { + size += 2 + 1; + } + if (ImplicitExampleWeight != 0F) { + size += 2 + 4; + } + if (maskHeadConfig_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(MaskHeadConfig); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(Ssd other) { + if (other == null) { + return; + } + if (other.NumClasses != 0) { + NumClasses = other.NumClasses; + } + if (other.imageResizer_ != null) { + if (imageResizer_ == null) { + imageResizer_ = new global::Tensorflow.Models.ObjectDetection.Protos.ImageResizer(); + } + ImageResizer.MergeFrom(other.ImageResizer); + } + if (other.featureExtractor_ != null) { + if (featureExtractor_ == null) { + featureExtractor_ = new global::Tensorflow.Models.ObjectDetection.Protos.SsdFeatureExtractor(); + } + FeatureExtractor.MergeFrom(other.FeatureExtractor); + } + if (other.boxCoder_ != null) { + if (boxCoder_ == null) { + boxCoder_ = new global::Tensorflow.Models.ObjectDetection.Protos.BoxCoder(); + } + BoxCoder.MergeFrom(other.BoxCoder); + } + if (other.matcher_ != null) { + if (matcher_ == null) { + matcher_ = new global::Tensorflow.Models.ObjectDetection.Protos.Matcher(); + } + Matcher.MergeFrom(other.Matcher); + } + if (other.similarityCalculator_ != null) { + if (similarityCalculator_ == null) { + similarityCalculator_ = new global::Tensorflow.Models.ObjectDetection.Protos.RegionSimilarityCalculator(); + } + SimilarityCalculator.MergeFrom(other.SimilarityCalculator); + } + if (other.EncodeBackgroundAsZeros != false) { + EncodeBackgroundAsZeros = other.EncodeBackgroundAsZeros; + } + if (other.NegativeClassWeight != 0F) { + NegativeClassWeight = other.NegativeClassWeight; + } + if (other.boxPredictor_ != null) { + if (boxPredictor_ == null) { + boxPredictor_ = new global::Tensorflow.Models.ObjectDetection.Protos.BoxPredictor(); + } + BoxPredictor.MergeFrom(other.BoxPredictor); + } + if (other.anchorGenerator_ != null) { + if (anchorGenerator_ == null) { + anchorGenerator_ = new global::Tensorflow.Models.ObjectDetection.Protos.AnchorGenerator(); + } + AnchorGenerator.MergeFrom(other.AnchorGenerator); + } + if 
(other.postProcessing_ != null) { + if (postProcessing_ == null) { + postProcessing_ = new global::Tensorflow.Models.ObjectDetection.Protos.PostProcessing(); + } + PostProcessing.MergeFrom(other.PostProcessing); + } + if (other.NormalizeLossByNumMatches != false) { + NormalizeLossByNumMatches = other.NormalizeLossByNumMatches; + } + if (other.NormalizeLocLossByCodesize != false) { + NormalizeLocLossByCodesize = other.NormalizeLocLossByCodesize; + } + if (other.loss_ != null) { + if (loss_ == null) { + loss_ = new global::Tensorflow.Models.ObjectDetection.Protos.Loss(); + } + Loss.MergeFrom(other.Loss); + } + if (other.FreezeBatchnorm != false) { + FreezeBatchnorm = other.FreezeBatchnorm; + } + if (other.InplaceBatchnormUpdate != false) { + InplaceBatchnormUpdate = other.InplaceBatchnormUpdate; + } + if (other.AddBackgroundClass != false) { + AddBackgroundClass = other.AddBackgroundClass; + } + if (other.ExplicitBackgroundClass != false) { + ExplicitBackgroundClass = other.ExplicitBackgroundClass; + } + if (other.UseConfidencesAsTargets != false) { + UseConfidencesAsTargets = other.UseConfidencesAsTargets; + } + if (other.ImplicitExampleWeight != 0F) { + ImplicitExampleWeight = other.ImplicitExampleWeight; + } + if (other.maskHeadConfig_ != null) { + if (maskHeadConfig_ == null) { + maskHeadConfig_ = new global::Tensorflow.Models.ObjectDetection.Protos.Ssd.Types.MaskHead(); + } + MaskHeadConfig.MergeFrom(other.MaskHeadConfig); + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + NumClasses = input.ReadInt32(); + break; + } + case 18: { + if (imageResizer_ == null) { + imageResizer_ = new global::Tensorflow.Models.ObjectDetection.Protos.ImageResizer(); + } + input.ReadMessage(imageResizer_); + break; + } + case 26: { + if (featureExtractor_ == null) { + featureExtractor_ = new global::Tensorflow.Models.ObjectDetection.Protos.SsdFeatureExtractor(); + } + input.ReadMessage(featureExtractor_); + break; + } + case 34: { + if (boxCoder_ == null) { + boxCoder_ = new global::Tensorflow.Models.ObjectDetection.Protos.BoxCoder(); + } + input.ReadMessage(boxCoder_); + break; + } + case 42: { + if (matcher_ == null) { + matcher_ = new global::Tensorflow.Models.ObjectDetection.Protos.Matcher(); + } + input.ReadMessage(matcher_); + break; + } + case 50: { + if (similarityCalculator_ == null) { + similarityCalculator_ = new global::Tensorflow.Models.ObjectDetection.Protos.RegionSimilarityCalculator(); + } + input.ReadMessage(similarityCalculator_); + break; + } + case 58: { + if (boxPredictor_ == null) { + boxPredictor_ = new global::Tensorflow.Models.ObjectDetection.Protos.BoxPredictor(); + } + input.ReadMessage(boxPredictor_); + break; + } + case 66: { + if (anchorGenerator_ == null) { + anchorGenerator_ = new global::Tensorflow.Models.ObjectDetection.Protos.AnchorGenerator(); + } + input.ReadMessage(anchorGenerator_); + break; + } + case 74: { + if (postProcessing_ == null) { + postProcessing_ = new global::Tensorflow.Models.ObjectDetection.Protos.PostProcessing(); + } + input.ReadMessage(postProcessing_); + break; + } + case 80: { + NormalizeLossByNumMatches = input.ReadBool(); + break; + } + case 90: { + if (loss_ == null) { + loss_ = new 
global::Tensorflow.Models.ObjectDetection.Protos.Loss(); + } + input.ReadMessage(loss_); + break; + } + case 96: { + EncodeBackgroundAsZeros = input.ReadBool(); + break; + } + case 109: { + NegativeClassWeight = input.ReadFloat(); + break; + } + case 112: { + NormalizeLocLossByCodesize = input.ReadBool(); + break; + } + case 120: { + InplaceBatchnormUpdate = input.ReadBool(); + break; + } + case 128: { + FreezeBatchnorm = input.ReadBool(); + break; + } + case 168: { + AddBackgroundClass = input.ReadBool(); + break; + } + case 176: { + UseConfidencesAsTargets = input.ReadBool(); + break; + } + case 189: { + ImplicitExampleWeight = input.ReadFloat(); + break; + } + case 192: { + ExplicitBackgroundClass = input.ReadBool(); + break; + } + case 202: { + if (maskHeadConfig_ == null) { + maskHeadConfig_ = new global::Tensorflow.Models.ObjectDetection.Protos.Ssd.Types.MaskHead(); + } + input.ReadMessage(maskHeadConfig_); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the Ssd message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + /// + /// Configuration proto for MaskHead. + /// Next id: 11 + /// + public sealed partial class MaskHead : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MaskHead()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.Ssd.Descriptor.NestedTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MaskHead() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MaskHead(MaskHead other) : this() { + maskHeight_ = other.maskHeight_; + maskWidth_ = other.maskWidth_; + masksAreClassAgnostic_ = other.masksAreClassAgnostic_; + maskPredictionConvDepth_ = other.maskPredictionConvDepth_; + maskPredictionNumConvLayers_ = other.maskPredictionNumConvLayers_; + convolveThenUpsampleMasks_ = other.convolveThenUpsampleMasks_; + maskLossWeight_ = other.maskLossWeight_; + maskLossSampleSize_ = other.maskLossSampleSize_; + convHyperparams_ = other.convHyperparams_ != null ? other.convHyperparams_.Clone() : null; + initialCropSize_ = other.initialCropSize_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MaskHead Clone() { + return new MaskHead(this); + } + + /// Field number for the "mask_height" field. + public const int MaskHeightFieldNumber = 1; + private int maskHeight_; + /// + /// The height and the width of the predicted mask. Only used when + /// predict_instance_masks is true. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaskHeight { + get { return maskHeight_; } + set { + maskHeight_ = value; + } + } + + /// Field number for the "mask_width" field. 
+ public const int MaskWidthFieldNumber = 2; + private int maskWidth_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaskWidth { + get { return maskWidth_; } + set { + maskWidth_ = value; + } + } + + /// Field number for the "masks_are_class_agnostic" field. + public const int MasksAreClassAgnosticFieldNumber = 3; + private bool masksAreClassAgnostic_; + /// + /// Whether to predict class agnostic masks. Only used when + /// predict_instance_masks is true. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool MasksAreClassAgnostic { + get { return masksAreClassAgnostic_; } + set { + masksAreClassAgnostic_ = value; + } + } + + /// Field number for the "mask_prediction_conv_depth" field. + public const int MaskPredictionConvDepthFieldNumber = 4; + private int maskPredictionConvDepth_; + /// + /// The depth for the first conv2d_transpose op applied to the + /// image_features in the mask prediction branch. If set to 0, the value + /// will be set automatically based on the number of channels in the image + /// features and the number of classes. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaskPredictionConvDepth { + get { return maskPredictionConvDepth_; } + set { + maskPredictionConvDepth_ = value; + } + } + + /// Field number for the "mask_prediction_num_conv_layers" field. + public const int MaskPredictionNumConvLayersFieldNumber = 5; + private int maskPredictionNumConvLayers_; + /// + /// The number of convolutions applied to image_features in the mask + /// prediction branch. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaskPredictionNumConvLayers { + get { return maskPredictionNumConvLayers_; } + set { + maskPredictionNumConvLayers_ = value; + } + } + + /// Field number for the "convolve_then_upsample_masks" field. + public const int ConvolveThenUpsampleMasksFieldNumber = 6; + private bool convolveThenUpsampleMasks_; + /// + /// Whether to apply convolutions on mask features before upsampling using + /// nearest neighbor resizing. + /// By default, mask features are resized to [`mask_height`, `mask_width`] + /// before applying convolutions and predicting masks. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ConvolveThenUpsampleMasks { + get { return convolveThenUpsampleMasks_; } + set { + convolveThenUpsampleMasks_ = value; + } + } + + /// Field number for the "mask_loss_weight" field. + public const int MaskLossWeightFieldNumber = 7; + private float maskLossWeight_; + /// + /// Mask loss weight. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaskLossWeight { + get { return maskLossWeight_; } + set { + maskLossWeight_ = value; + } + } + + /// Field number for the "mask_loss_sample_size" field. + public const int MaskLossSampleSizeFieldNumber = 8; + private int maskLossSampleSize_; + /// + /// Number of boxes to be generated at training time for computing mask loss. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaskLossSampleSize { + get { return maskLossSampleSize_; } + set { + maskLossSampleSize_ = value; + } + } + + /// Field number for the "conv_hyperparams" field. + public const int ConvHyperparamsFieldNumber = 9; + private global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams convHyperparams_; + /// + /// Hyperparameters for convolution ops used in the box predictor. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams ConvHyperparams { + get { return convHyperparams_; } + set { + convHyperparams_ = value; + } + } + + /// Field number for the "initial_crop_size" field. + public const int InitialCropSizeFieldNumber = 10; + private int initialCropSize_; + /// + /// Output size (width and height are set to be the same) of the initial + /// bilinear interpolation based cropping during ROI pooling. Only used when + /// we have second stage prediction head enabled (e.g. mask head). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int InitialCropSize { + get { return initialCropSize_; } + set { + initialCropSize_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as MaskHead); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(MaskHead other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (MaskHeight != other.MaskHeight) return false; + if (MaskWidth != other.MaskWidth) return false; + if (MasksAreClassAgnostic != other.MasksAreClassAgnostic) return false; + if (MaskPredictionConvDepth != other.MaskPredictionConvDepth) return false; + if (MaskPredictionNumConvLayers != other.MaskPredictionNumConvLayers) return false; + if (ConvolveThenUpsampleMasks != other.ConvolveThenUpsampleMasks) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaskLossWeight, other.MaskLossWeight)) return false; + if (MaskLossSampleSize != other.MaskLossSampleSize) return false; + if (!object.Equals(ConvHyperparams, other.ConvHyperparams)) return false; + if (InitialCropSize != other.InitialCropSize) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MaskHeight != 0) hash ^= MaskHeight.GetHashCode(); + if (MaskWidth != 0) hash ^= MaskWidth.GetHashCode(); + if (MasksAreClassAgnostic != false) hash ^= MasksAreClassAgnostic.GetHashCode(); + if (MaskPredictionConvDepth != 0) hash ^= MaskPredictionConvDepth.GetHashCode(); + if (MaskPredictionNumConvLayers != 0) hash ^= MaskPredictionNumConvLayers.GetHashCode(); + if (ConvolveThenUpsampleMasks != false) hash ^= ConvolveThenUpsampleMasks.GetHashCode(); + if (MaskLossWeight != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaskLossWeight); + if (MaskLossSampleSize != 0) hash ^= MaskLossSampleSize.GetHashCode(); + if (convHyperparams_ != null) hash ^= ConvHyperparams.GetHashCode(); + if (InitialCropSize != 0) hash ^= InitialCropSize.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MaskHeight != 0) { + output.WriteRawTag(8); + output.WriteInt32(MaskHeight); + } + if (MaskWidth != 0) { + output.WriteRawTag(16); + output.WriteInt32(MaskWidth); + } + if (MasksAreClassAgnostic != false) { + output.WriteRawTag(24); + output.WriteBool(MasksAreClassAgnostic); + } + if (MaskPredictionConvDepth 
!= 0) { + output.WriteRawTag(32); + output.WriteInt32(MaskPredictionConvDepth); + } + if (MaskPredictionNumConvLayers != 0) { + output.WriteRawTag(40); + output.WriteInt32(MaskPredictionNumConvLayers); + } + if (ConvolveThenUpsampleMasks != false) { + output.WriteRawTag(48); + output.WriteBool(ConvolveThenUpsampleMasks); + } + if (MaskLossWeight != 0F) { + output.WriteRawTag(61); + output.WriteFloat(MaskLossWeight); + } + if (MaskLossSampleSize != 0) { + output.WriteRawTag(64); + output.WriteInt32(MaskLossSampleSize); + } + if (convHyperparams_ != null) { + output.WriteRawTag(74); + output.WriteMessage(ConvHyperparams); + } + if (InitialCropSize != 0) { + output.WriteRawTag(80); + output.WriteInt32(InitialCropSize); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MaskHeight != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaskHeight); + } + if (MaskWidth != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaskWidth); + } + if (MasksAreClassAgnostic != false) { + size += 1 + 1; + } + if (MaskPredictionConvDepth != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaskPredictionConvDepth); + } + if (MaskPredictionNumConvLayers != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaskPredictionNumConvLayers); + } + if (ConvolveThenUpsampleMasks != false) { + size += 1 + 1; + } + if (MaskLossWeight != 0F) { + size += 1 + 4; + } + if (MaskLossSampleSize != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaskLossSampleSize); + } + if (convHyperparams_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ConvHyperparams); + } + if (InitialCropSize != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(InitialCropSize); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(MaskHead other) { + if (other == null) { + return; + } + if (other.MaskHeight != 0) { + MaskHeight = other.MaskHeight; + } + if (other.MaskWidth != 0) { + MaskWidth = other.MaskWidth; + } + if (other.MasksAreClassAgnostic != false) { + MasksAreClassAgnostic = other.MasksAreClassAgnostic; + } + if (other.MaskPredictionConvDepth != 0) { + MaskPredictionConvDepth = other.MaskPredictionConvDepth; + } + if (other.MaskPredictionNumConvLayers != 0) { + MaskPredictionNumConvLayers = other.MaskPredictionNumConvLayers; + } + if (other.ConvolveThenUpsampleMasks != false) { + ConvolveThenUpsampleMasks = other.ConvolveThenUpsampleMasks; + } + if (other.MaskLossWeight != 0F) { + MaskLossWeight = other.MaskLossWeight; + } + if (other.MaskLossSampleSize != 0) { + MaskLossSampleSize = other.MaskLossSampleSize; + } + if (other.convHyperparams_ != null) { + if (convHyperparams_ == null) { + convHyperparams_ = new global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams(); + } + ConvHyperparams.MergeFrom(other.ConvHyperparams); + } + if (other.InitialCropSize != 0) { + InitialCropSize = other.InitialCropSize; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + MaskHeight = 
input.ReadInt32(); + break; + } + case 16: { + MaskWidth = input.ReadInt32(); + break; + } + case 24: { + MasksAreClassAgnostic = input.ReadBool(); + break; + } + case 32: { + MaskPredictionConvDepth = input.ReadInt32(); + break; + } + case 40: { + MaskPredictionNumConvLayers = input.ReadInt32(); + break; + } + case 48: { + ConvolveThenUpsampleMasks = input.ReadBool(); + break; + } + case 61: { + MaskLossWeight = input.ReadFloat(); + break; + } + case 64: { + MaskLossSampleSize = input.ReadInt32(); + break; + } + case 74: { + if (convHyperparams_ == null) { + convHyperparams_ = new global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams(); + } + input.ReadMessage(convHyperparams_); + break; + } + case 80: { + InitialCropSize = input.ReadInt32(); + break; + } + } + } + } + + } + + } + #endregion + + } + + public sealed partial class SsdFeatureExtractor : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SsdFeatureExtractor()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.SsdReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SsdFeatureExtractor() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SsdFeatureExtractor(SsdFeatureExtractor other) : this() { + type_ = other.type_; + depthMultiplier_ = other.depthMultiplier_; + minDepth_ = other.minDepth_; + convHyperparams_ = other.convHyperparams_ != null ? other.convHyperparams_.Clone() : null; + overrideBaseFeatureExtractorHyperparams_ = other.overrideBaseFeatureExtractorHyperparams_; + padToMultiple_ = other.padToMultiple_; + useExplicitPadding_ = other.useExplicitPadding_; + useDepthwise_ = other.useDepthwise_; + fpn_ = other.fpn_ != null ? other.fpn_.Clone() : null; + replacePreprocessorWithPlaceholder_ = other.replacePreprocessorWithPlaceholder_; + numLayers_ = other.numLayers_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SsdFeatureExtractor Clone() { + return new SsdFeatureExtractor(this); + } + + /// Field number for the "type" field. + public const int TypeFieldNumber = 1; + private string type_ = ""; + /// + /// Type of ssd feature extractor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Type { + get { return type_; } + set { + type_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "depth_multiplier" field. + public const int DepthMultiplierFieldNumber = 2; + private float depthMultiplier_; + /// + /// The factor to alter the depth of the channels in the feature extractor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float DepthMultiplier { + get { return depthMultiplier_; } + set { + depthMultiplier_ = value; + } + } + + /// Field number for the "min_depth" field. 
+ public const int MinDepthFieldNumber = 3; + private int minDepth_; + /// + /// Minimum number of the channels in the feature extractor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MinDepth { + get { return minDepth_; } + set { + minDepth_ = value; + } + } + + /// Field number for the "conv_hyperparams" field. + public const int ConvHyperparamsFieldNumber = 4; + private global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams convHyperparams_; + /// + /// Hyperparameters that affect the layers of feature extractor added on top + /// of the base feature extractor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams ConvHyperparams { + get { return convHyperparams_; } + set { + convHyperparams_ = value; + } + } + + /// Field number for the "override_base_feature_extractor_hyperparams" field. + public const int OverrideBaseFeatureExtractorHyperparamsFieldNumber = 9; + private bool overrideBaseFeatureExtractorHyperparams_; + /// + /// Normally, SSD feature extractors are constructed by reusing an existing + /// base feature extractor (that has its own hyperparams) and adding new layers + /// on top of it. `conv_hyperparams` above normally applies only to the new + /// layers while base feature extractor uses its own default hyperparams. If + /// this value is set to true, the base feature extractor's hyperparams will be + /// overridden with the `conv_hyperparams`. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool OverrideBaseFeatureExtractorHyperparams { + get { return overrideBaseFeatureExtractorHyperparams_; } + set { + overrideBaseFeatureExtractorHyperparams_ = value; + } + } + + /// Field number for the "pad_to_multiple" field. + public const int PadToMultipleFieldNumber = 5; + private int padToMultiple_; + /// + /// The nearest multiple to zero-pad the input height and width dimensions to. + /// For example, if pad_to_multiple = 2, input dimensions are zero-padded + /// until the resulting dimensions are even. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int PadToMultiple { + get { return padToMultiple_; } + set { + padToMultiple_ = value; + } + } + + /// Field number for the "use_explicit_padding" field. + public const int UseExplicitPaddingFieldNumber = 7; + private bool useExplicitPadding_; + /// + /// Whether to use explicit padding when extracting SSD multiresolution + /// features. This will also apply to the base feature extractor if a MobileNet + /// architecture is used. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseExplicitPadding { + get { return useExplicitPadding_; } + set { + useExplicitPadding_ = value; + } + } + + /// Field number for the "use_depthwise" field. + public const int UseDepthwiseFieldNumber = 8; + private bool useDepthwise_; + /// + /// Whether to use depthwise separable convolutions for to extract additional + /// feature maps added by SSD. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseDepthwise { + get { return useDepthwise_; } + set { + useDepthwise_ = value; + } + } + + /// Field number for the "fpn" field. + public const int FpnFieldNumber = 10; + private global::Tensorflow.Models.ObjectDetection.Protos.FeaturePyramidNetworks fpn_; + /// + /// Feature Pyramid Networks config. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.FeaturePyramidNetworks Fpn { + get { return fpn_; } + set { + fpn_ = value; + } + } + + /// Field number for the "replace_preprocessor_with_placeholder" field. + public const int ReplacePreprocessorWithPlaceholderFieldNumber = 11; + private bool replacePreprocessorWithPlaceholder_; + /// + /// If true, replace preprocess function of feature extractor with a + /// placeholder. This should only be used if all the image preprocessing steps + /// happen outside the graph. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ReplacePreprocessorWithPlaceholder { + get { return replacePreprocessorWithPlaceholder_; } + set { + replacePreprocessorWithPlaceholder_ = value; + } + } + + /// Field number for the "num_layers" field. + public const int NumLayersFieldNumber = 12; + private int numLayers_; + /// + /// The number of SSD layers. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumLayers { + get { return numLayers_; } + set { + numLayers_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SsdFeatureExtractor); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SsdFeatureExtractor other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Type != other.Type) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(DepthMultiplier, other.DepthMultiplier)) return false; + if (MinDepth != other.MinDepth) return false; + if (!object.Equals(ConvHyperparams, other.ConvHyperparams)) return false; + if (OverrideBaseFeatureExtractorHyperparams != other.OverrideBaseFeatureExtractorHyperparams) return false; + if (PadToMultiple != other.PadToMultiple) return false; + if (UseExplicitPadding != other.UseExplicitPadding) return false; + if (UseDepthwise != other.UseDepthwise) return false; + if (!object.Equals(Fpn, other.Fpn)) return false; + if (ReplacePreprocessorWithPlaceholder != other.ReplacePreprocessorWithPlaceholder) return false; + if (NumLayers != other.NumLayers) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Type.Length != 0) hash ^= Type.GetHashCode(); + if (DepthMultiplier != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(DepthMultiplier); + if (MinDepth != 0) hash ^= MinDepth.GetHashCode(); + if (convHyperparams_ != null) hash ^= ConvHyperparams.GetHashCode(); + if (OverrideBaseFeatureExtractorHyperparams != false) hash ^= OverrideBaseFeatureExtractorHyperparams.GetHashCode(); + if (PadToMultiple != 0) hash ^= PadToMultiple.GetHashCode(); + if (UseExplicitPadding != false) hash ^= UseExplicitPadding.GetHashCode(); + if (UseDepthwise != false) hash ^= UseDepthwise.GetHashCode(); + if (fpn_ != null) hash ^= Fpn.GetHashCode(); + if (ReplacePreprocessorWithPlaceholder != false) hash ^= ReplacePreprocessorWithPlaceholder.GetHashCode(); + if (NumLayers != 0) hash ^= NumLayers.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return 
pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Type.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Type); + } + if (DepthMultiplier != 0F) { + output.WriteRawTag(21); + output.WriteFloat(DepthMultiplier); + } + if (MinDepth != 0) { + output.WriteRawTag(24); + output.WriteInt32(MinDepth); + } + if (convHyperparams_ != null) { + output.WriteRawTag(34); + output.WriteMessage(ConvHyperparams); + } + if (PadToMultiple != 0) { + output.WriteRawTag(40); + output.WriteInt32(PadToMultiple); + } + if (UseExplicitPadding != false) { + output.WriteRawTag(56); + output.WriteBool(UseExplicitPadding); + } + if (UseDepthwise != false) { + output.WriteRawTag(64); + output.WriteBool(UseDepthwise); + } + if (OverrideBaseFeatureExtractorHyperparams != false) { + output.WriteRawTag(72); + output.WriteBool(OverrideBaseFeatureExtractorHyperparams); + } + if (fpn_ != null) { + output.WriteRawTag(82); + output.WriteMessage(Fpn); + } + if (ReplacePreprocessorWithPlaceholder != false) { + output.WriteRawTag(88); + output.WriteBool(ReplacePreprocessorWithPlaceholder); + } + if (NumLayers != 0) { + output.WriteRawTag(96); + output.WriteInt32(NumLayers); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Type.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Type); + } + if (DepthMultiplier != 0F) { + size += 1 + 4; + } + if (MinDepth != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MinDepth); + } + if (convHyperparams_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ConvHyperparams); + } + if (OverrideBaseFeatureExtractorHyperparams != false) { + size += 1 + 1; + } + if (PadToMultiple != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(PadToMultiple); + } + if (UseExplicitPadding != false) { + size += 1 + 1; + } + if (UseDepthwise != false) { + size += 1 + 1; + } + if (fpn_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Fpn); + } + if (ReplacePreprocessorWithPlaceholder != false) { + size += 1 + 1; + } + if (NumLayers != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumLayers); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SsdFeatureExtractor other) { + if (other == null) { + return; + } + if (other.Type.Length != 0) { + Type = other.Type; + } + if (other.DepthMultiplier != 0F) { + DepthMultiplier = other.DepthMultiplier; + } + if (other.MinDepth != 0) { + MinDepth = other.MinDepth; + } + if (other.convHyperparams_ != null) { + if (convHyperparams_ == null) { + convHyperparams_ = new global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams(); + } + ConvHyperparams.MergeFrom(other.ConvHyperparams); + } + if (other.OverrideBaseFeatureExtractorHyperparams != false) { + OverrideBaseFeatureExtractorHyperparams = other.OverrideBaseFeatureExtractorHyperparams; + } + if (other.PadToMultiple != 0) { + PadToMultiple = other.PadToMultiple; + } + if (other.UseExplicitPadding != false) { + UseExplicitPadding = other.UseExplicitPadding; + } + if (other.UseDepthwise != false) { + UseDepthwise = other.UseDepthwise; + } + if (other.fpn_ != null) { + if (fpn_ == null) { + fpn_ = new 
global::Tensorflow.Models.ObjectDetection.Protos.FeaturePyramidNetworks(); + } + Fpn.MergeFrom(other.Fpn); + } + if (other.ReplacePreprocessorWithPlaceholder != false) { + ReplacePreprocessorWithPlaceholder = other.ReplacePreprocessorWithPlaceholder; + } + if (other.NumLayers != 0) { + NumLayers = other.NumLayers; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Type = input.ReadString(); + break; + } + case 21: { + DepthMultiplier = input.ReadFloat(); + break; + } + case 24: { + MinDepth = input.ReadInt32(); + break; + } + case 34: { + if (convHyperparams_ == null) { + convHyperparams_ = new global::Tensorflow.Models.ObjectDetection.Protos.Hyperparams(); + } + input.ReadMessage(convHyperparams_); + break; + } + case 40: { + PadToMultiple = input.ReadInt32(); + break; + } + case 56: { + UseExplicitPadding = input.ReadBool(); + break; + } + case 64: { + UseDepthwise = input.ReadBool(); + break; + } + case 72: { + OverrideBaseFeatureExtractorHyperparams = input.ReadBool(); + break; + } + case 82: { + if (fpn_ == null) { + fpn_ = new global::Tensorflow.Models.ObjectDetection.Protos.FeaturePyramidNetworks(); + } + input.ReadMessage(fpn_); + break; + } + case 88: { + ReplacePreprocessorWithPlaceholder = input.ReadBool(); + break; + } + case 96: { + NumLayers = input.ReadInt32(); + break; + } + } + } + } + + } + + /// + /// Configuration for Feature Pyramid Networks. + /// + public sealed partial class FeaturePyramidNetworks : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FeaturePyramidNetworks()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.SsdReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FeaturePyramidNetworks() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FeaturePyramidNetworks(FeaturePyramidNetworks other) : this() { + minLevel_ = other.minLevel_; + maxLevel_ = other.maxLevel_; + additionalLayerDepth_ = other.additionalLayerDepth_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FeaturePyramidNetworks Clone() { + return new FeaturePyramidNetworks(this); + } + + /// Field number for the "min_level" field. + public const int MinLevelFieldNumber = 1; + private int minLevel_; + /// + /// minimum level in feature pyramid + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MinLevel { + get { return minLevel_; } + set { + minLevel_ = value; + } + } + + /// Field number for the "max_level" field. 
+ public const int MaxLevelFieldNumber = 2; + private int maxLevel_; + /// + /// maximum level in feature pyramid + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxLevel { + get { return maxLevel_; } + set { + maxLevel_ = value; + } + } + + /// Field number for the "additional_layer_depth" field. + public const int AdditionalLayerDepthFieldNumber = 3; + private int additionalLayerDepth_; + /// + /// channel depth for additional coarse feature layers. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int AdditionalLayerDepth { + get { return additionalLayerDepth_; } + set { + additionalLayerDepth_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as FeaturePyramidNetworks); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(FeaturePyramidNetworks other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (MinLevel != other.MinLevel) return false; + if (MaxLevel != other.MaxLevel) return false; + if (AdditionalLayerDepth != other.AdditionalLayerDepth) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MinLevel != 0) hash ^= MinLevel.GetHashCode(); + if (MaxLevel != 0) hash ^= MaxLevel.GetHashCode(); + if (AdditionalLayerDepth != 0) hash ^= AdditionalLayerDepth.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MinLevel != 0) { + output.WriteRawTag(8); + output.WriteInt32(MinLevel); + } + if (MaxLevel != 0) { + output.WriteRawTag(16); + output.WriteInt32(MaxLevel); + } + if (AdditionalLayerDepth != 0) { + output.WriteRawTag(24); + output.WriteInt32(AdditionalLayerDepth); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MinLevel != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MinLevel); + } + if (MaxLevel != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(MaxLevel); + } + if (AdditionalLayerDepth != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(AdditionalLayerDepth); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(FeaturePyramidNetworks other) { + if (other == null) { + return; + } + if (other.MinLevel != 0) { + MinLevel = other.MinLevel; + } + if (other.MaxLevel != 0) { + MaxLevel = other.MaxLevel; + } + if (other.AdditionalLayerDepth != 0) { + AdditionalLayerDepth = other.AdditionalLayerDepth; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, 
input); + break; + case 8: { + MinLevel = input.ReadInt32(); + break; + } + case 16: { + MaxLevel = input.ReadInt32(); + break; + } + case 24: { + AdditionalLayerDepth = input.ReadInt32(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/SsdAnchorGenerator.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/SsdAnchorGenerator.cs new file mode 100644 index 00000000..e2a4a457 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/SsdAnchorGenerator.cs @@ -0,0 +1,526 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: object_detection/protos/ssd_anchor_generator.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/ssd_anchor_generator.proto + public static partial class SsdAnchorGeneratorReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/ssd_anchor_generator.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static SsdAnchorGeneratorReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjJvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9zc2RfYW5jaG9yX2dlbmVyYXRv", + "ci5wcm90bxIXb2JqZWN0X2RldGVjdGlvbi5wcm90b3Mi1QIKElNzZEFuY2hv", + "ckdlbmVyYXRvchISCgpudW1fbGF5ZXJzGAEgASgFEhEKCW1pbl9zY2FsZRgC", + "IAEoAhIRCgltYXhfc2NhbGUYAyABKAISDgoGc2NhbGVzGAwgAygCEhUKDWFz", + "cGVjdF9yYXRpb3MYBCADKAISJwofaW50ZXJwb2xhdGVkX3NjYWxlX2FzcGVj", + "dF9yYXRpbxgNIAEoAhIkChxyZWR1Y2VfYm94ZXNfaW5fbG93ZXN0X2xheWVy", + "GAUgASgIEhoKEmJhc2VfYW5jaG9yX2hlaWdodBgGIAEoAhIZChFiYXNlX2Fu", + "Y2hvcl93aWR0aBgHIAEoAhIVCg1oZWlnaHRfc3RyaWRlGAggAygFEhQKDHdp", + "ZHRoX3N0cmlkZRgJIAMoBRIVCg1oZWlnaHRfb2Zmc2V0GAogAygFEhQKDHdp", + "ZHRoX29mZnNldBgLIAMoBWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.SsdAnchorGenerator), global::Tensorflow.Models.ObjectDetection.Protos.SsdAnchorGenerator.Parser, new[]{ "NumLayers", "MinScale", "MaxScale", "Scales", "AspectRatios", "InterpolatedScaleAspectRatio", "ReduceBoxesInLowestLayer", "BaseAnchorHeight", "BaseAnchorWidth", "HeightStride", "WidthStride", "HeightOffset", "WidthOffset" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Configuration proto for SSD anchor generator described in + /// https://arxiv.org/abs/1512.02325. See + /// anchor_generators/multiple_grid_anchor_generator.py for details. 
+ /// + public sealed partial class SsdAnchorGenerator : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new SsdAnchorGenerator()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.SsdAnchorGeneratorReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SsdAnchorGenerator() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SsdAnchorGenerator(SsdAnchorGenerator other) : this() { + numLayers_ = other.numLayers_; + minScale_ = other.minScale_; + maxScale_ = other.maxScale_; + scales_ = other.scales_.Clone(); + aspectRatios_ = other.aspectRatios_.Clone(); + interpolatedScaleAspectRatio_ = other.interpolatedScaleAspectRatio_; + reduceBoxesInLowestLayer_ = other.reduceBoxesInLowestLayer_; + baseAnchorHeight_ = other.baseAnchorHeight_; + baseAnchorWidth_ = other.baseAnchorWidth_; + heightStride_ = other.heightStride_.Clone(); + widthStride_ = other.widthStride_.Clone(); + heightOffset_ = other.heightOffset_.Clone(); + widthOffset_ = other.widthOffset_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SsdAnchorGenerator Clone() { + return new SsdAnchorGenerator(this); + } + + /// Field number for the "num_layers" field. + public const int NumLayersFieldNumber = 1; + private int numLayers_; + /// + /// Number of grid layers to create anchors for. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumLayers { + get { return numLayers_; } + set { + numLayers_ = value; + } + } + + /// Field number for the "min_scale" field. + public const int MinScaleFieldNumber = 2; + private float minScale_; + /// + /// Scale of anchors corresponding to finest resolution. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MinScale { + get { return minScale_; } + set { + minScale_ = value; + } + } + + /// Field number for the "max_scale" field. + public const int MaxScaleFieldNumber = 3; + private float maxScale_; + /// + /// Scale of anchors corresponding to coarsest resolution + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float MaxScale { + get { return maxScale_; } + set { + maxScale_ = value; + } + } + + /// Field number for the "scales" field. + public const int ScalesFieldNumber = 12; + private static readonly pb::FieldCodec _repeated_scales_codec + = pb::FieldCodec.ForFloat(98); + private readonly pbc::RepeatedField scales_ = new pbc::RepeatedField(); + /// + /// Can be used to override min_scale->max_scale, with an explicitly defined + /// set of scales. If empty, then min_scale->max_scale is used. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Scales { + get { return scales_; } + } + + /// Field number for the "aspect_ratios" field. 
+ public const int AspectRatiosFieldNumber = 4; + private static readonly pb::FieldCodec _repeated_aspectRatios_codec + = pb::FieldCodec.ForFloat(34); + private readonly pbc::RepeatedField aspectRatios_ = new pbc::RepeatedField(); + /// + /// Aspect ratios for anchors at each grid point. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField AspectRatios { + get { return aspectRatios_; } + } + + /// Field number for the "interpolated_scale_aspect_ratio" field. + public const int InterpolatedScaleAspectRatioFieldNumber = 13; + private float interpolatedScaleAspectRatio_; + /// + /// When this aspect ratio is greater than 0, then an additional + /// anchor, with an interpolated scale is added with this aspect ratio. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float InterpolatedScaleAspectRatio { + get { return interpolatedScaleAspectRatio_; } + set { + interpolatedScaleAspectRatio_ = value; + } + } + + /// Field number for the "reduce_boxes_in_lowest_layer" field. + public const int ReduceBoxesInLowestLayerFieldNumber = 5; + private bool reduceBoxesInLowestLayer_; + /// + /// Whether to use the following aspect ratio and scale combination for the + /// layer with the finest resolution : (scale=0.1, aspect_ratio=1.0), + /// (scale=min_scale, aspect_ration=2.0), (scale=min_scale, aspect_ratio=0.5). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool ReduceBoxesInLowestLayer { + get { return reduceBoxesInLowestLayer_; } + set { + reduceBoxesInLowestLayer_ = value; + } + } + + /// Field number for the "base_anchor_height" field. + public const int BaseAnchorHeightFieldNumber = 6; + private float baseAnchorHeight_; + /// + /// The base anchor size in height dimension. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float BaseAnchorHeight { + get { return baseAnchorHeight_; } + set { + baseAnchorHeight_ = value; + } + } + + /// Field number for the "base_anchor_width" field. + public const int BaseAnchorWidthFieldNumber = 7; + private float baseAnchorWidth_; + /// + /// The base anchor size in width dimension. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float BaseAnchorWidth { + get { return baseAnchorWidth_; } + set { + baseAnchorWidth_ = value; + } + } + + /// Field number for the "height_stride" field. + public const int HeightStrideFieldNumber = 8; + private static readonly pb::FieldCodec _repeated_heightStride_codec + = pb::FieldCodec.ForInt32(66); + private readonly pbc::RepeatedField heightStride_ = new pbc::RepeatedField(); + /// + /// Anchor stride in height dimension in pixels for each layer. The length of + /// this field is expected to be equal to the value of num_layers. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField HeightStride { + get { return heightStride_; } + } + + /// Field number for the "width_stride" field. + public const int WidthStrideFieldNumber = 9; + private static readonly pb::FieldCodec _repeated_widthStride_codec + = pb::FieldCodec.ForInt32(74); + private readonly pbc::RepeatedField widthStride_ = new pbc::RepeatedField(); + /// + /// Anchor stride in width dimension in pixels for each layer. The length of + /// this field is expected to be equal to the value of num_layers. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField WidthStride { + get { return widthStride_; } + } + + /// Field number for the "height_offset" field. 
+ public const int HeightOffsetFieldNumber = 10; + private static readonly pb::FieldCodec _repeated_heightOffset_codec + = pb::FieldCodec.ForInt32(82); + private readonly pbc::RepeatedField heightOffset_ = new pbc::RepeatedField(); + /// + /// Anchor height offset in pixels for each layer. The length of this field is + /// expected to be equal to the value of num_layers. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField HeightOffset { + get { return heightOffset_; } + } + + /// Field number for the "width_offset" field. + public const int WidthOffsetFieldNumber = 11; + private static readonly pb::FieldCodec _repeated_widthOffset_codec + = pb::FieldCodec.ForInt32(90); + private readonly pbc::RepeatedField widthOffset_ = new pbc::RepeatedField(); + /// + /// Anchor width offset in pixels for each layer. The length of this field is + /// expected to be equal to the value of num_layers. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField WidthOffset { + get { return widthOffset_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as SsdAnchorGenerator); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(SsdAnchorGenerator other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (NumLayers != other.NumLayers) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MinScale, other.MinScale)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(MaxScale, other.MaxScale)) return false; + if(!scales_.Equals(other.scales_)) return false; + if(!aspectRatios_.Equals(other.aspectRatios_)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(InterpolatedScaleAspectRatio, other.InterpolatedScaleAspectRatio)) return false; + if (ReduceBoxesInLowestLayer != other.ReduceBoxesInLowestLayer) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(BaseAnchorHeight, other.BaseAnchorHeight)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(BaseAnchorWidth, other.BaseAnchorWidth)) return false; + if(!heightStride_.Equals(other.heightStride_)) return false; + if(!widthStride_.Equals(other.widthStride_)) return false; + if(!heightOffset_.Equals(other.heightOffset_)) return false; + if(!widthOffset_.Equals(other.widthOffset_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (NumLayers != 0) hash ^= NumLayers.GetHashCode(); + if (MinScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MinScale); + if (MaxScale != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(MaxScale); + hash ^= scales_.GetHashCode(); + hash ^= aspectRatios_.GetHashCode(); + if (InterpolatedScaleAspectRatio != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(InterpolatedScaleAspectRatio); + if (ReduceBoxesInLowestLayer != false) hash ^= ReduceBoxesInLowestLayer.GetHashCode(); + if (BaseAnchorHeight != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(BaseAnchorHeight); + if (BaseAnchorWidth != 0F) hash ^= 
pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(BaseAnchorWidth); + hash ^= heightStride_.GetHashCode(); + hash ^= widthStride_.GetHashCode(); + hash ^= heightOffset_.GetHashCode(); + hash ^= widthOffset_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (NumLayers != 0) { + output.WriteRawTag(8); + output.WriteInt32(NumLayers); + } + if (MinScale != 0F) { + output.WriteRawTag(21); + output.WriteFloat(MinScale); + } + if (MaxScale != 0F) { + output.WriteRawTag(29); + output.WriteFloat(MaxScale); + } + aspectRatios_.WriteTo(output, _repeated_aspectRatios_codec); + if (ReduceBoxesInLowestLayer != false) { + output.WriteRawTag(40); + output.WriteBool(ReduceBoxesInLowestLayer); + } + if (BaseAnchorHeight != 0F) { + output.WriteRawTag(53); + output.WriteFloat(BaseAnchorHeight); + } + if (BaseAnchorWidth != 0F) { + output.WriteRawTag(61); + output.WriteFloat(BaseAnchorWidth); + } + heightStride_.WriteTo(output, _repeated_heightStride_codec); + widthStride_.WriteTo(output, _repeated_widthStride_codec); + heightOffset_.WriteTo(output, _repeated_heightOffset_codec); + widthOffset_.WriteTo(output, _repeated_widthOffset_codec); + scales_.WriteTo(output, _repeated_scales_codec); + if (InterpolatedScaleAspectRatio != 0F) { + output.WriteRawTag(109); + output.WriteFloat(InterpolatedScaleAspectRatio); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (NumLayers != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumLayers); + } + if (MinScale != 0F) { + size += 1 + 4; + } + if (MaxScale != 0F) { + size += 1 + 4; + } + size += scales_.CalculateSize(_repeated_scales_codec); + size += aspectRatios_.CalculateSize(_repeated_aspectRatios_codec); + if (InterpolatedScaleAspectRatio != 0F) { + size += 1 + 4; + } + if (ReduceBoxesInLowestLayer != false) { + size += 1 + 1; + } + if (BaseAnchorHeight != 0F) { + size += 1 + 4; + } + if (BaseAnchorWidth != 0F) { + size += 1 + 4; + } + size += heightStride_.CalculateSize(_repeated_heightStride_codec); + size += widthStride_.CalculateSize(_repeated_widthStride_codec); + size += heightOffset_.CalculateSize(_repeated_heightOffset_codec); + size += widthOffset_.CalculateSize(_repeated_widthOffset_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(SsdAnchorGenerator other) { + if (other == null) { + return; + } + if (other.NumLayers != 0) { + NumLayers = other.NumLayers; + } + if (other.MinScale != 0F) { + MinScale = other.MinScale; + } + if (other.MaxScale != 0F) { + MaxScale = other.MaxScale; + } + scales_.Add(other.scales_); + aspectRatios_.Add(other.aspectRatios_); + if (other.InterpolatedScaleAspectRatio != 0F) { + InterpolatedScaleAspectRatio = other.InterpolatedScaleAspectRatio; + } + if (other.ReduceBoxesInLowestLayer != false) { + ReduceBoxesInLowestLayer = other.ReduceBoxesInLowestLayer; + } + if (other.BaseAnchorHeight != 0F) { + BaseAnchorHeight = other.BaseAnchorHeight; + } + if (other.BaseAnchorWidth != 0F) { + 
BaseAnchorWidth = other.BaseAnchorWidth; + } + heightStride_.Add(other.heightStride_); + widthStride_.Add(other.widthStride_); + heightOffset_.Add(other.heightOffset_); + widthOffset_.Add(other.widthOffset_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + NumLayers = input.ReadInt32(); + break; + } + case 21: { + MinScale = input.ReadFloat(); + break; + } + case 29: { + MaxScale = input.ReadFloat(); + break; + } + case 34: + case 37: { + aspectRatios_.AddEntriesFrom(input, _repeated_aspectRatios_codec); + break; + } + case 40: { + ReduceBoxesInLowestLayer = input.ReadBool(); + break; + } + case 53: { + BaseAnchorHeight = input.ReadFloat(); + break; + } + case 61: { + BaseAnchorWidth = input.ReadFloat(); + break; + } + case 66: + case 64: { + heightStride_.AddEntriesFrom(input, _repeated_heightStride_codec); + break; + } + case 74: + case 72: { + widthStride_.AddEntriesFrom(input, _repeated_widthStride_codec); + break; + } + case 82: + case 80: { + heightOffset_.AddEntriesFrom(input, _repeated_heightOffset_codec); + break; + } + case 90: + case 88: { + widthOffset_.AddEntriesFrom(input, _repeated_widthOffset_codec); + break; + } + case 98: + case 101: { + scales_.AddEntriesFrom(input, _repeated_scales_codec); + break; + } + case 109: { + InterpolatedScaleAspectRatio = input.ReadFloat(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/StringIntLabelMap.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/StringIntLabelMap.cs new file mode 100644 index 00000000..dcb2ae8b --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/StringIntLabelMap.cs @@ -0,0 +1,365 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! 
+// source: object_detection/protos/string_int_label_map.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/string_int_label_map.proto + public static partial class StringIntLabelMapReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/string_int_label_map.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static StringIntLabelMapReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjJvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy9zdHJpbmdfaW50X2xhYmVsX21h", + "cC5wcm90bxIXb2JqZWN0X2RldGVjdGlvbi5wcm90b3MiRwoVU3RyaW5nSW50", + "TGFiZWxNYXBJdGVtEgwKBG5hbWUYASABKAkSCgoCaWQYAiABKAUSFAoMZGlz", + "cGxheV9uYW1lGAMgASgJIlEKEVN0cmluZ0ludExhYmVsTWFwEjwKBGl0ZW0Y", + "ASADKAsyLi5vYmplY3RfZGV0ZWN0aW9uLnByb3Rvcy5TdHJpbmdJbnRMYWJl", + "bE1hcEl0ZW1iBnByb3RvMw==")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.StringIntLabelMapItem), global::Tensorflow.Models.ObjectDetection.Protos.StringIntLabelMapItem.Parser, new[]{ "Name", "Id", "DisplayName" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.StringIntLabelMap), global::Tensorflow.Models.ObjectDetection.Protos.StringIntLabelMap.Parser, new[]{ "Item" }, null, null, null) + })); + } + #endregion + + } + #region Messages + public sealed partial class StringIntLabelMapItem : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new StringIntLabelMapItem()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.StringIntLabelMapReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public StringIntLabelMapItem() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public StringIntLabelMapItem(StringIntLabelMapItem other) : this() { + name_ = other.name_; + id_ = other.id_; + displayName_ = other.displayName_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public StringIntLabelMapItem Clone() { + return new StringIntLabelMapItem(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + /// + /// String name. The most common practice is to set this to a MID or synsets + /// id. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "id" field. + public const int IdFieldNumber = 2; + private int id_; + /// + /// Integer id that maps to the string name above. Label ids should start from + /// 1. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Id { + get { return id_; } + set { + id_ = value; + } + } + + /// Field number for the "display_name" field. + public const int DisplayNameFieldNumber = 3; + private string displayName_ = ""; + /// + /// Human readable string label. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string DisplayName { + get { return displayName_; } + set { + displayName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as StringIntLabelMapItem); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(StringIntLabelMapItem other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (Id != other.Id) return false; + if (DisplayName != other.DisplayName) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (Id != 0) hash ^= Id.GetHashCode(); + if (DisplayName.Length != 0) hash ^= DisplayName.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (Id != 0) { + output.WriteRawTag(16); + output.WriteInt32(Id); + } + if (DisplayName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(DisplayName); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (Id != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id); + } + if (DisplayName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(DisplayName); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(StringIntLabelMapItem other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.Id != 0) { + Id = other.Id; + } + if (other.DisplayName.Length != 0) { + DisplayName = other.DisplayName; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + 
+          default:
+            _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
+            break;
+          case 10: {
+            Name = input.ReadString();
+            break;
+          }
+          case 16: {
+            Id = input.ReadInt32();
+            break;
+          }
+          case 26: {
+            DisplayName = input.ReadString();
+            break;
+          }
+        }
+      }
+    }
+
+  }
+
+  public sealed partial class StringIntLabelMap : pb::IMessage<StringIntLabelMap> {
+    private static readonly pb::MessageParser<StringIntLabelMap> _parser = new pb::MessageParser<StringIntLabelMap>(() => new StringIntLabelMap());
+    private pb::UnknownFieldSet _unknownFields;
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public static pb::MessageParser<StringIntLabelMap> Parser { get { return _parser; } }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public static pbr::MessageDescriptor Descriptor {
+      get { return global::Tensorflow.Models.ObjectDetection.Protos.StringIntLabelMapReflection.Descriptor.MessageTypes[1]; }
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    pbr::MessageDescriptor pb::IMessage.Descriptor {
+      get { return Descriptor; }
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public StringIntLabelMap() {
+      OnConstruction();
+    }
+
+    partial void OnConstruction();
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public StringIntLabelMap(StringIntLabelMap other) : this() {
+      item_ = other.item_.Clone();
+      _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public StringIntLabelMap Clone() {
+      return new StringIntLabelMap(this);
+    }
+
+    /// <summary>Field number for the "item" field.</summary>
+    public const int ItemFieldNumber = 1;
+    private static readonly pb::FieldCodec<global::Tensorflow.Models.ObjectDetection.Protos.StringIntLabelMapItem> _repeated_item_codec
+        = pb::FieldCodec.ForMessage(10, global::Tensorflow.Models.ObjectDetection.Protos.StringIntLabelMapItem.Parser);
+    private readonly pbc::RepeatedField<global::Tensorflow.Models.ObjectDetection.Protos.StringIntLabelMapItem> item_ = new pbc::RepeatedField<global::Tensorflow.Models.ObjectDetection.Protos.StringIntLabelMapItem>();
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public pbc::RepeatedField<global::Tensorflow.Models.ObjectDetection.Protos.StringIntLabelMapItem> Item {
+      get { return item_; }
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public override bool Equals(object other) {
+      return Equals(other as StringIntLabelMap);
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public bool Equals(StringIntLabelMap other) {
+      if (ReferenceEquals(other, null)) {
+        return false;
+      }
+      if (ReferenceEquals(other, this)) {
+        return true;
+      }
+      if(!item_.Equals(other.item_)) return false;
+      return Equals(_unknownFields, other._unknownFields);
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public override int GetHashCode() {
+      int hash = 1;
+      hash ^= item_.GetHashCode();
+      if (_unknownFields != null) {
+        hash ^= _unknownFields.GetHashCode();
+      }
+      return hash;
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public override string ToString() {
+      return pb::JsonFormatter.ToDiagnosticString(this);
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public void WriteTo(pb::CodedOutputStream output) {
+      item_.WriteTo(output, _repeated_item_codec);
+      if (_unknownFields != null) {
+        _unknownFields.WriteTo(output);
+      }
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public int CalculateSize() {
+      int size = 0;
+      size += item_.CalculateSize(_repeated_item_codec);
+      if (_unknownFields != null) {
+        size += _unknownFields.CalculateSize();
+      }
+      return size;
+    }
+
+    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+    public void MergeFrom(StringIntLabelMap other) {
+      if (other == null) {
+        return;
+      }
+      item_.Add(other.item_);
+
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + item_.AddEntriesFrom(input, _repeated_item_codec); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Protos/Train.cs b/src/TensorFlowNET.Models/ObjectDetection/Protos/Train.cs new file mode 100644 index 00000000..44318fb5 --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Protos/Train.cs @@ -0,0 +1,1020 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: object_detection/protos/train.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow.Models.ObjectDetection.Protos { + + /// Holder for reflection information generated from object_detection/protos/train.proto + public static partial class TrainReflection { + + #region Descriptor + /// File descriptor for object_detection/protos/train.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static TrainReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CiNvYmplY3RfZGV0ZWN0aW9uL3Byb3Rvcy90cmFpbi5wcm90bxIXb2JqZWN0", + "X2RldGVjdGlvbi5wcm90b3MaJ29iamVjdF9kZXRlY3Rpb24vcHJvdG9zL29w", + "dGltaXplci5wcm90bxoqb2JqZWN0X2RldGVjdGlvbi9wcm90b3MvcHJlcHJv", + "Y2Vzc29yLnByb3RvIpoHCgtUcmFpbkNvbmZpZxISCgpiYXRjaF9zaXplGAEg", + "ASgNEk0KGWRhdGFfYXVnbWVudGF0aW9uX29wdGlvbnMYAiADKAsyKi5vYmpl", + "Y3RfZGV0ZWN0aW9uLnByb3Rvcy5QcmVwcm9jZXNzaW5nU3RlcBIVCg1zeW5j", + "X3JlcGxpY2FzGAMgASgIEiUKHWtlZXBfY2hlY2twb2ludF9ldmVyeV9uX2hv", + "dXJzGAQgASgCEjUKCW9wdGltaXplchgFIAEoCzIiLm9iamVjdF9kZXRlY3Rp", + "b24ucHJvdG9zLk9wdGltaXplchIhChlncmFkaWVudF9jbGlwcGluZ19ieV9u", + "b3JtGAYgASgCEhwKFGZpbmVfdHVuZV9jaGVja3BvaW50GAcgASgJEiEKGWZp", + "bmVfdHVuZV9jaGVja3BvaW50X3R5cGUYFiABKAkSIQoZZnJvbV9kZXRlY3Rp", + "b25fY2hlY2twb2ludBgIIAEoCBIqCiJsb2FkX2FsbF9kZXRlY3Rpb25fY2hl", + "Y2twb2ludF92YXJzGBMgASgIEhEKCW51bV9zdGVwcxgJIAEoDRIbChNzdGFy", + "dHVwX2RlbGF5X3N0ZXBzGAogASgCEhwKFGJpYXNfZ3JhZF9tdWx0aXBsaWVy", + "GAsgASgCEiIKGnVwZGF0ZV90cmFpbmFibGVfdmFyaWFibGVzGBkgAygJEhgK", + "EGZyZWV6ZV92YXJpYWJsZXMYDCADKAkSHQoVcmVwbGljYXNfdG9fYWdncmVn", + "YXRlGA0gASgFEhwKFGJhdGNoX3F1ZXVlX2NhcGFjaXR5GA4gASgFEh8KF251", + "bV9iYXRjaF9xdWV1ZV90aHJlYWRzGA8gASgFEh8KF3ByZWZldGNoX3F1ZXVl", + "X2NhcGFjaXR5GBAgASgFEiIKGm1lcmdlX211bHRpcGxlX2xhYmVsX2JveGVz", + "GBEgASgIEh0KFXVzZV9tdWx0aWNsYXNzX3Njb3JlcxgYIAEoCBIfChdhZGRf", + "cmVndWxhcml6YXRpb25fbG9zcxgSIAEoCBIbChNtYXhfbnVtYmVyX29mX2Jv", + "eGVzGBQgASgFEiEKGXVucGFkX2dyb3VuZHRydXRoX3RlbnNvcnMYFSABKAgS", + "HgoWcmV0YWluX29yaWdpbmFsX2ltYWdlcxgXIAEoCBIUCgx1c2VfYmZsb2F0", + "MTYYGiABKAgSGwoTc3VtbWFyaXplX2dyYWRpZW50cxgbIAEoCGIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { global::Tensorflow.Models.ObjectDetection.Protos.OptimizerReflection.Descriptor, global::Tensorflow.Models.ObjectDetection.Protos.PreprocessorReflection.Descriptor, 
}, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.Models.ObjectDetection.Protos.TrainConfig), global::Tensorflow.Models.ObjectDetection.Protos.TrainConfig.Parser, new[]{ "BatchSize", "DataAugmentationOptions", "SyncReplicas", "KeepCheckpointEveryNHours", "Optimizer", "GradientClippingByNorm", "FineTuneCheckpoint", "FineTuneCheckpointType", "FromDetectionCheckpoint", "LoadAllDetectionCheckpointVars", "NumSteps", "StartupDelaySteps", "BiasGradMultiplier", "UpdateTrainableVariables", "FreezeVariables", "ReplicasToAggregate", "BatchQueueCapacity", "NumBatchQueueThreads", "PrefetchQueueCapacity", "MergeMultipleLabelBoxes", "UseMulticlassScores", "AddRegularizationLoss", "MaxNumberOfBoxes", "UnpadGroundtruthTensors", "RetainOriginalImages", "UseBfloat16", "SummarizeGradients" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Message for configuring DetectionModel training jobs (train.py). + /// Next id: 28 + /// + public sealed partial class TrainConfig : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new TrainConfig()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.Models.ObjectDetection.Protos.TrainReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainConfig() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainConfig(TrainConfig other) : this() { + batchSize_ = other.batchSize_; + dataAugmentationOptions_ = other.dataAugmentationOptions_.Clone(); + syncReplicas_ = other.syncReplicas_; + keepCheckpointEveryNHours_ = other.keepCheckpointEveryNHours_; + optimizer_ = other.optimizer_ != null ? 
other.optimizer_.Clone() : null; + gradientClippingByNorm_ = other.gradientClippingByNorm_; + fineTuneCheckpoint_ = other.fineTuneCheckpoint_; + fineTuneCheckpointType_ = other.fineTuneCheckpointType_; + fromDetectionCheckpoint_ = other.fromDetectionCheckpoint_; + loadAllDetectionCheckpointVars_ = other.loadAllDetectionCheckpointVars_; + numSteps_ = other.numSteps_; + startupDelaySteps_ = other.startupDelaySteps_; + biasGradMultiplier_ = other.biasGradMultiplier_; + updateTrainableVariables_ = other.updateTrainableVariables_.Clone(); + freezeVariables_ = other.freezeVariables_.Clone(); + replicasToAggregate_ = other.replicasToAggregate_; + batchQueueCapacity_ = other.batchQueueCapacity_; + numBatchQueueThreads_ = other.numBatchQueueThreads_; + prefetchQueueCapacity_ = other.prefetchQueueCapacity_; + mergeMultipleLabelBoxes_ = other.mergeMultipleLabelBoxes_; + useMulticlassScores_ = other.useMulticlassScores_; + addRegularizationLoss_ = other.addRegularizationLoss_; + maxNumberOfBoxes_ = other.maxNumberOfBoxes_; + unpadGroundtruthTensors_ = other.unpadGroundtruthTensors_; + retainOriginalImages_ = other.retainOriginalImages_; + useBfloat16_ = other.useBfloat16_; + summarizeGradients_ = other.summarizeGradients_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrainConfig Clone() { + return new TrainConfig(this); + } + + /// Field number for the "batch_size" field. + public const int BatchSizeFieldNumber = 1; + private uint batchSize_; + /// + /// Effective batch size to use for training. + /// For TPU (or sync SGD jobs), the batch size per core (or GPU) is going to be + /// `batch_size` / number of cores (or `batch_size` / number of GPUs). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint BatchSize { + get { return batchSize_; } + set { + batchSize_ = value; + } + } + + /// Field number for the "data_augmentation_options" field. + public const int DataAugmentationOptionsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_dataAugmentationOptions_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.Models.ObjectDetection.Protos.PreprocessingStep.Parser); + private readonly pbc::RepeatedField dataAugmentationOptions_ = new pbc::RepeatedField(); + /// + /// Data augmentation options. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField DataAugmentationOptions { + get { return dataAugmentationOptions_; } + } + + /// Field number for the "sync_replicas" field. + public const int SyncReplicasFieldNumber = 3; + private bool syncReplicas_; + /// + /// Whether to synchronize replicas during training. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool SyncReplicas { + get { return syncReplicas_; } + set { + syncReplicas_ = value; + } + } + + /// Field number for the "keep_checkpoint_every_n_hours" field. + public const int KeepCheckpointEveryNHoursFieldNumber = 4; + private float keepCheckpointEveryNHours_; + /// + /// How frequently to keep checkpoints. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float KeepCheckpointEveryNHours { + get { return keepCheckpointEveryNHours_; } + set { + keepCheckpointEveryNHours_ = value; + } + } + + /// Field number for the "optimizer" field. + public const int OptimizerFieldNumber = 5; + private global::Tensorflow.Models.ObjectDetection.Protos.Optimizer optimizer_; + /// + /// Optimizer used to train the DetectionModel. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.Models.ObjectDetection.Protos.Optimizer Optimizer { + get { return optimizer_; } + set { + optimizer_ = value; + } + } + + /// Field number for the "gradient_clipping_by_norm" field. + public const int GradientClippingByNormFieldNumber = 6; + private float gradientClippingByNorm_; + /// + /// If greater than 0, clips gradients by this value. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float GradientClippingByNorm { + get { return gradientClippingByNorm_; } + set { + gradientClippingByNorm_ = value; + } + } + + /// Field number for the "fine_tune_checkpoint" field. + public const int FineTuneCheckpointFieldNumber = 7; + private string fineTuneCheckpoint_ = ""; + /// + /// Checkpoint to restore variables from. Typically used to load feature + /// extractor variables trained outside of object detection. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string FineTuneCheckpoint { + get { return fineTuneCheckpoint_; } + set { + fineTuneCheckpoint_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "fine_tune_checkpoint_type" field. + public const int FineTuneCheckpointTypeFieldNumber = 22; + private string fineTuneCheckpointType_ = ""; + /// + /// Type of checkpoint to restore variables from, e.g. 'classification' or + /// 'detection'. Provides extensibility to from_detection_checkpoint. + /// Typically used to load feature extractor variables from trained models. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string FineTuneCheckpointType { + get { return fineTuneCheckpointType_; } + set { + fineTuneCheckpointType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "from_detection_checkpoint" field. + public const int FromDetectionCheckpointFieldNumber = 8; + private bool fromDetectionCheckpoint_; + /// + /// [Deprecated]: use fine_tune_checkpoint_type instead. + /// Specifies if the finetune checkpoint is from an object detection model. + /// If from an object detection model, the model being trained should have + /// the same parameters with the exception of the num_classes parameter. + /// If false, it assumes the checkpoint was a object classification model. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool FromDetectionCheckpoint { + get { return fromDetectionCheckpoint_; } + set { + fromDetectionCheckpoint_ = value; + } + } + + /// Field number for the "load_all_detection_checkpoint_vars" field. + public const int LoadAllDetectionCheckpointVarsFieldNumber = 19; + private bool loadAllDetectionCheckpointVars_; + /// + /// Whether to load all checkpoint vars that match model variable names and + /// sizes. This option is only available if `from_detection_checkpoint` is + /// True. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool LoadAllDetectionCheckpointVars { + get { return loadAllDetectionCheckpointVars_; } + set { + loadAllDetectionCheckpointVars_ = value; + } + } + + /// Field number for the "num_steps" field. + public const int NumStepsFieldNumber = 9; + private uint numSteps_; + /// + /// Number of steps to train the DetectionModel for. If 0, will train the model + /// indefinitely. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public uint NumSteps { + get { return numSteps_; } + set { + numSteps_ = value; + } + } + + /// Field number for the "startup_delay_steps" field. + public const int StartupDelayStepsFieldNumber = 10; + private float startupDelaySteps_; + /// + /// Number of training steps between replica startup. + /// This flag must be set to 0 if sync_replicas is set to true. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float StartupDelaySteps { + get { return startupDelaySteps_; } + set { + startupDelaySteps_ = value; + } + } + + /// Field number for the "bias_grad_multiplier" field. + public const int BiasGradMultiplierFieldNumber = 11; + private float biasGradMultiplier_; + /// + /// If greater than 0, multiplies the gradient of bias variables by this + /// amount. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public float BiasGradMultiplier { + get { return biasGradMultiplier_; } + set { + biasGradMultiplier_ = value; + } + } + + /// Field number for the "update_trainable_variables" field. + public const int UpdateTrainableVariablesFieldNumber = 25; + private static readonly pb::FieldCodec _repeated_updateTrainableVariables_codec + = pb::FieldCodec.ForString(202); + private readonly pbc::RepeatedField updateTrainableVariables_ = new pbc::RepeatedField(); + /// + /// Variables that should be updated during training. Note that variables which + /// also match the patterns in freeze_variables will be excluded. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField UpdateTrainableVariables { + get { return updateTrainableVariables_; } + } + + /// Field number for the "freeze_variables" field. + public const int FreezeVariablesFieldNumber = 12; + private static readonly pb::FieldCodec _repeated_freezeVariables_codec + = pb::FieldCodec.ForString(98); + private readonly pbc::RepeatedField freezeVariables_ = new pbc::RepeatedField(); + /// + /// Variables that should not be updated during training. If + /// update_trainable_variables is not empty, only eliminates the included + /// variables according to freeze_variables patterns. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField FreezeVariables { + get { return freezeVariables_; } + } + + /// Field number for the "replicas_to_aggregate" field. + public const int ReplicasToAggregateFieldNumber = 13; + private int replicasToAggregate_; + /// + /// Number of replicas to aggregate before making parameter updates. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int ReplicasToAggregate { + get { return replicasToAggregate_; } + set { + replicasToAggregate_ = value; + } + } + + /// Field number for the "batch_queue_capacity" field. + public const int BatchQueueCapacityFieldNumber = 14; + private int batchQueueCapacity_; + /// + /// Maximum number of elements to store within a queue. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int BatchQueueCapacity { + get { return batchQueueCapacity_; } + set { + batchQueueCapacity_ = value; + } + } + + /// Field number for the "num_batch_queue_threads" field. + public const int NumBatchQueueThreadsFieldNumber = 15; + private int numBatchQueueThreads_; + /// + /// Number of threads to use for batching. 
+ /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumBatchQueueThreads { + get { return numBatchQueueThreads_; } + set { + numBatchQueueThreads_ = value; + } + } + + /// Field number for the "prefetch_queue_capacity" field. + public const int PrefetchQueueCapacityFieldNumber = 16; + private int prefetchQueueCapacity_; + /// + /// Maximum capacity of the queue used to prefetch assembled batches. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int PrefetchQueueCapacity { + get { return prefetchQueueCapacity_; } + set { + prefetchQueueCapacity_ = value; + } + } + + /// Field number for the "merge_multiple_label_boxes" field. + public const int MergeMultipleLabelBoxesFieldNumber = 17; + private bool mergeMultipleLabelBoxes_; + /// + /// If true, boxes with the same coordinates will be merged together. + /// This is useful when each box can have multiple labels. + /// Note that only Sigmoid classification losses should be used. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool MergeMultipleLabelBoxes { + get { return mergeMultipleLabelBoxes_; } + set { + mergeMultipleLabelBoxes_ = value; + } + } + + /// Field number for the "use_multiclass_scores" field. + public const int UseMulticlassScoresFieldNumber = 24; + private bool useMulticlassScores_; + /// + /// If true, will use multiclass scores from object annotations as ground + /// truth. Currently only compatible with annotated image inputs. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseMulticlassScores { + get { return useMulticlassScores_; } + set { + useMulticlassScores_ = value; + } + } + + /// Field number for the "add_regularization_loss" field. + public const int AddRegularizationLossFieldNumber = 18; + private bool addRegularizationLoss_; + /// + /// Whether to add regularization loss to `total_loss`. This is true by + /// default and adds all regularization losses defined in the model to + /// `total_loss`. + /// Setting this option to false is very useful while debugging the model and + /// losses. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool AddRegularizationLoss { + get { return addRegularizationLoss_; } + set { + addRegularizationLoss_ = value; + } + } + + /// Field number for the "max_number_of_boxes" field. + public const int MaxNumberOfBoxesFieldNumber = 20; + private int maxNumberOfBoxes_; + /// + /// Maximum number of boxes used during training. + /// Set this to at least the maximum amount of boxes in the input data. + /// Otherwise, it may cause "Data loss: Attempted to pad to a smaller size + /// than the input element" errors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MaxNumberOfBoxes { + get { return maxNumberOfBoxes_; } + set { + maxNumberOfBoxes_ = value; + } + } + + /// Field number for the "unpad_groundtruth_tensors" field. + public const int UnpadGroundtruthTensorsFieldNumber = 21; + private bool unpadGroundtruthTensors_; + /// + /// Whether to remove padding along `num_boxes` dimension of the groundtruth + /// tensors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UnpadGroundtruthTensors { + get { return unpadGroundtruthTensors_; } + set { + unpadGroundtruthTensors_ = value; + } + } + + /// Field number for the "retain_original_images" field. + public const int RetainOriginalImagesFieldNumber = 23; + private bool retainOriginalImages_; + /// + /// Whether to retain original images (i.e. 
not pre-processed) in the tensor + /// dictionary, so that they can be displayed in Tensorboard. Note that this + /// will lead to a larger memory footprint. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool RetainOriginalImages { + get { return retainOriginalImages_; } + set { + retainOriginalImages_ = value; + } + } + + /// Field number for the "use_bfloat16" field. + public const int UseBfloat16FieldNumber = 26; + private bool useBfloat16_; + /// + /// Whether to use bfloat16 for training. This is currently only supported for + /// TPUs. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseBfloat16 { + get { return useBfloat16_; } + set { + useBfloat16_ = value; + } + } + + /// Field number for the "summarize_gradients" field. + public const int SummarizeGradientsFieldNumber = 27; + private bool summarizeGradients_; + /// + /// Whether to summarize gradients. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool SummarizeGradients { + get { return summarizeGradients_; } + set { + summarizeGradients_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as TrainConfig); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(TrainConfig other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (BatchSize != other.BatchSize) return false; + if(!dataAugmentationOptions_.Equals(other.dataAugmentationOptions_)) return false; + if (SyncReplicas != other.SyncReplicas) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(KeepCheckpointEveryNHours, other.KeepCheckpointEveryNHours)) return false; + if (!object.Equals(Optimizer, other.Optimizer)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(GradientClippingByNorm, other.GradientClippingByNorm)) return false; + if (FineTuneCheckpoint != other.FineTuneCheckpoint) return false; + if (FineTuneCheckpointType != other.FineTuneCheckpointType) return false; + if (FromDetectionCheckpoint != other.FromDetectionCheckpoint) return false; + if (LoadAllDetectionCheckpointVars != other.LoadAllDetectionCheckpointVars) return false; + if (NumSteps != other.NumSteps) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(StartupDelaySteps, other.StartupDelaySteps)) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(BiasGradMultiplier, other.BiasGradMultiplier)) return false; + if(!updateTrainableVariables_.Equals(other.updateTrainableVariables_)) return false; + if(!freezeVariables_.Equals(other.freezeVariables_)) return false; + if (ReplicasToAggregate != other.ReplicasToAggregate) return false; + if (BatchQueueCapacity != other.BatchQueueCapacity) return false; + if (NumBatchQueueThreads != other.NumBatchQueueThreads) return false; + if (PrefetchQueueCapacity != other.PrefetchQueueCapacity) return false; + if (MergeMultipleLabelBoxes != other.MergeMultipleLabelBoxes) return false; + if (UseMulticlassScores != other.UseMulticlassScores) return false; + if (AddRegularizationLoss != other.AddRegularizationLoss) return false; + if (MaxNumberOfBoxes != other.MaxNumberOfBoxes) return false; + if (UnpadGroundtruthTensors != other.UnpadGroundtruthTensors) return false; + if (RetainOriginalImages != other.RetainOriginalImages) return false; + if 
(UseBfloat16 != other.UseBfloat16) return false; + if (SummarizeGradients != other.SummarizeGradients) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (BatchSize != 0) hash ^= BatchSize.GetHashCode(); + hash ^= dataAugmentationOptions_.GetHashCode(); + if (SyncReplicas != false) hash ^= SyncReplicas.GetHashCode(); + if (KeepCheckpointEveryNHours != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(KeepCheckpointEveryNHours); + if (optimizer_ != null) hash ^= Optimizer.GetHashCode(); + if (GradientClippingByNorm != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(GradientClippingByNorm); + if (FineTuneCheckpoint.Length != 0) hash ^= FineTuneCheckpoint.GetHashCode(); + if (FineTuneCheckpointType.Length != 0) hash ^= FineTuneCheckpointType.GetHashCode(); + if (FromDetectionCheckpoint != false) hash ^= FromDetectionCheckpoint.GetHashCode(); + if (LoadAllDetectionCheckpointVars != false) hash ^= LoadAllDetectionCheckpointVars.GetHashCode(); + if (NumSteps != 0) hash ^= NumSteps.GetHashCode(); + if (StartupDelaySteps != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(StartupDelaySteps); + if (BiasGradMultiplier != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(BiasGradMultiplier); + hash ^= updateTrainableVariables_.GetHashCode(); + hash ^= freezeVariables_.GetHashCode(); + if (ReplicasToAggregate != 0) hash ^= ReplicasToAggregate.GetHashCode(); + if (BatchQueueCapacity != 0) hash ^= BatchQueueCapacity.GetHashCode(); + if (NumBatchQueueThreads != 0) hash ^= NumBatchQueueThreads.GetHashCode(); + if (PrefetchQueueCapacity != 0) hash ^= PrefetchQueueCapacity.GetHashCode(); + if (MergeMultipleLabelBoxes != false) hash ^= MergeMultipleLabelBoxes.GetHashCode(); + if (UseMulticlassScores != false) hash ^= UseMulticlassScores.GetHashCode(); + if (AddRegularizationLoss != false) hash ^= AddRegularizationLoss.GetHashCode(); + if (MaxNumberOfBoxes != 0) hash ^= MaxNumberOfBoxes.GetHashCode(); + if (UnpadGroundtruthTensors != false) hash ^= UnpadGroundtruthTensors.GetHashCode(); + if (RetainOriginalImages != false) hash ^= RetainOriginalImages.GetHashCode(); + if (UseBfloat16 != false) hash ^= UseBfloat16.GetHashCode(); + if (SummarizeGradients != false) hash ^= SummarizeGradients.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (BatchSize != 0) { + output.WriteRawTag(8); + output.WriteUInt32(BatchSize); + } + dataAugmentationOptions_.WriteTo(output, _repeated_dataAugmentationOptions_codec); + if (SyncReplicas != false) { + output.WriteRawTag(24); + output.WriteBool(SyncReplicas); + } + if (KeepCheckpointEveryNHours != 0F) { + output.WriteRawTag(37); + output.WriteFloat(KeepCheckpointEveryNHours); + } + if (optimizer_ != null) { + output.WriteRawTag(42); + output.WriteMessage(Optimizer); + } + if (GradientClippingByNorm != 0F) { + output.WriteRawTag(53); + output.WriteFloat(GradientClippingByNorm); + } + if (FineTuneCheckpoint.Length != 0) { + output.WriteRawTag(58); + 
output.WriteString(FineTuneCheckpoint); + } + if (FromDetectionCheckpoint != false) { + output.WriteRawTag(64); + output.WriteBool(FromDetectionCheckpoint); + } + if (NumSteps != 0) { + output.WriteRawTag(72); + output.WriteUInt32(NumSteps); + } + if (StartupDelaySteps != 0F) { + output.WriteRawTag(85); + output.WriteFloat(StartupDelaySteps); + } + if (BiasGradMultiplier != 0F) { + output.WriteRawTag(93); + output.WriteFloat(BiasGradMultiplier); + } + freezeVariables_.WriteTo(output, _repeated_freezeVariables_codec); + if (ReplicasToAggregate != 0) { + output.WriteRawTag(104); + output.WriteInt32(ReplicasToAggregate); + } + if (BatchQueueCapacity != 0) { + output.WriteRawTag(112); + output.WriteInt32(BatchQueueCapacity); + } + if (NumBatchQueueThreads != 0) { + output.WriteRawTag(120); + output.WriteInt32(NumBatchQueueThreads); + } + if (PrefetchQueueCapacity != 0) { + output.WriteRawTag(128, 1); + output.WriteInt32(PrefetchQueueCapacity); + } + if (MergeMultipleLabelBoxes != false) { + output.WriteRawTag(136, 1); + output.WriteBool(MergeMultipleLabelBoxes); + } + if (AddRegularizationLoss != false) { + output.WriteRawTag(144, 1); + output.WriteBool(AddRegularizationLoss); + } + if (LoadAllDetectionCheckpointVars != false) { + output.WriteRawTag(152, 1); + output.WriteBool(LoadAllDetectionCheckpointVars); + } + if (MaxNumberOfBoxes != 0) { + output.WriteRawTag(160, 1); + output.WriteInt32(MaxNumberOfBoxes); + } + if (UnpadGroundtruthTensors != false) { + output.WriteRawTag(168, 1); + output.WriteBool(UnpadGroundtruthTensors); + } + if (FineTuneCheckpointType.Length != 0) { + output.WriteRawTag(178, 1); + output.WriteString(FineTuneCheckpointType); + } + if (RetainOriginalImages != false) { + output.WriteRawTag(184, 1); + output.WriteBool(RetainOriginalImages); + } + if (UseMulticlassScores != false) { + output.WriteRawTag(192, 1); + output.WriteBool(UseMulticlassScores); + } + updateTrainableVariables_.WriteTo(output, _repeated_updateTrainableVariables_codec); + if (UseBfloat16 != false) { + output.WriteRawTag(208, 1); + output.WriteBool(UseBfloat16); + } + if (SummarizeGradients != false) { + output.WriteRawTag(216, 1); + output.WriteBool(SummarizeGradients); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (BatchSize != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(BatchSize); + } + size += dataAugmentationOptions_.CalculateSize(_repeated_dataAugmentationOptions_codec); + if (SyncReplicas != false) { + size += 1 + 1; + } + if (KeepCheckpointEveryNHours != 0F) { + size += 1 + 4; + } + if (optimizer_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Optimizer); + } + if (GradientClippingByNorm != 0F) { + size += 1 + 4; + } + if (FineTuneCheckpoint.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(FineTuneCheckpoint); + } + if (FineTuneCheckpointType.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeStringSize(FineTuneCheckpointType); + } + if (FromDetectionCheckpoint != false) { + size += 1 + 1; + } + if (LoadAllDetectionCheckpointVars != false) { + size += 2 + 1; + } + if (NumSteps != 0) { + size += 1 + pb::CodedOutputStream.ComputeUInt32Size(NumSteps); + } + if (StartupDelaySteps != 0F) { + size += 1 + 4; + } + if (BiasGradMultiplier != 0F) { + size += 1 + 4; + } + size += updateTrainableVariables_.CalculateSize(_repeated_updateTrainableVariables_codec); + size += 
freezeVariables_.CalculateSize(_repeated_freezeVariables_codec); + if (ReplicasToAggregate != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ReplicasToAggregate); + } + if (BatchQueueCapacity != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(BatchQueueCapacity); + } + if (NumBatchQueueThreads != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumBatchQueueThreads); + } + if (PrefetchQueueCapacity != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(PrefetchQueueCapacity); + } + if (MergeMultipleLabelBoxes != false) { + size += 2 + 1; + } + if (UseMulticlassScores != false) { + size += 2 + 1; + } + if (AddRegularizationLoss != false) { + size += 2 + 1; + } + if (MaxNumberOfBoxes != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(MaxNumberOfBoxes); + } + if (UnpadGroundtruthTensors != false) { + size += 2 + 1; + } + if (RetainOriginalImages != false) { + size += 2 + 1; + } + if (UseBfloat16 != false) { + size += 2 + 1; + } + if (SummarizeGradients != false) { + size += 2 + 1; + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(TrainConfig other) { + if (other == null) { + return; + } + if (other.BatchSize != 0) { + BatchSize = other.BatchSize; + } + dataAugmentationOptions_.Add(other.dataAugmentationOptions_); + if (other.SyncReplicas != false) { + SyncReplicas = other.SyncReplicas; + } + if (other.KeepCheckpointEveryNHours != 0F) { + KeepCheckpointEveryNHours = other.KeepCheckpointEveryNHours; + } + if (other.optimizer_ != null) { + if (optimizer_ == null) { + optimizer_ = new global::Tensorflow.Models.ObjectDetection.Protos.Optimizer(); + } + Optimizer.MergeFrom(other.Optimizer); + } + if (other.GradientClippingByNorm != 0F) { + GradientClippingByNorm = other.GradientClippingByNorm; + } + if (other.FineTuneCheckpoint.Length != 0) { + FineTuneCheckpoint = other.FineTuneCheckpoint; + } + if (other.FineTuneCheckpointType.Length != 0) { + FineTuneCheckpointType = other.FineTuneCheckpointType; + } + if (other.FromDetectionCheckpoint != false) { + FromDetectionCheckpoint = other.FromDetectionCheckpoint; + } + if (other.LoadAllDetectionCheckpointVars != false) { + LoadAllDetectionCheckpointVars = other.LoadAllDetectionCheckpointVars; + } + if (other.NumSteps != 0) { + NumSteps = other.NumSteps; + } + if (other.StartupDelaySteps != 0F) { + StartupDelaySteps = other.StartupDelaySteps; + } + if (other.BiasGradMultiplier != 0F) { + BiasGradMultiplier = other.BiasGradMultiplier; + } + updateTrainableVariables_.Add(other.updateTrainableVariables_); + freezeVariables_.Add(other.freezeVariables_); + if (other.ReplicasToAggregate != 0) { + ReplicasToAggregate = other.ReplicasToAggregate; + } + if (other.BatchQueueCapacity != 0) { + BatchQueueCapacity = other.BatchQueueCapacity; + } + if (other.NumBatchQueueThreads != 0) { + NumBatchQueueThreads = other.NumBatchQueueThreads; + } + if (other.PrefetchQueueCapacity != 0) { + PrefetchQueueCapacity = other.PrefetchQueueCapacity; + } + if (other.MergeMultipleLabelBoxes != false) { + MergeMultipleLabelBoxes = other.MergeMultipleLabelBoxes; + } + if (other.UseMulticlassScores != false) { + UseMulticlassScores = other.UseMulticlassScores; + } + if (other.AddRegularizationLoss != false) { + AddRegularizationLoss = other.AddRegularizationLoss; + } + if (other.MaxNumberOfBoxes != 0) { + MaxNumberOfBoxes = other.MaxNumberOfBoxes; + } + if (other.UnpadGroundtruthTensors != false) { + 
UnpadGroundtruthTensors = other.UnpadGroundtruthTensors; + } + if (other.RetainOriginalImages != false) { + RetainOriginalImages = other.RetainOriginalImages; + } + if (other.UseBfloat16 != false) { + UseBfloat16 = other.UseBfloat16; + } + if (other.SummarizeGradients != false) { + SummarizeGradients = other.SummarizeGradients; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + BatchSize = input.ReadUInt32(); + break; + } + case 18: { + dataAugmentationOptions_.AddEntriesFrom(input, _repeated_dataAugmentationOptions_codec); + break; + } + case 24: { + SyncReplicas = input.ReadBool(); + break; + } + case 37: { + KeepCheckpointEveryNHours = input.ReadFloat(); + break; + } + case 42: { + if (optimizer_ == null) { + optimizer_ = new global::Tensorflow.Models.ObjectDetection.Protos.Optimizer(); + } + input.ReadMessage(optimizer_); + break; + } + case 53: { + GradientClippingByNorm = input.ReadFloat(); + break; + } + case 58: { + FineTuneCheckpoint = input.ReadString(); + break; + } + case 64: { + FromDetectionCheckpoint = input.ReadBool(); + break; + } + case 72: { + NumSteps = input.ReadUInt32(); + break; + } + case 85: { + StartupDelaySteps = input.ReadFloat(); + break; + } + case 93: { + BiasGradMultiplier = input.ReadFloat(); + break; + } + case 98: { + freezeVariables_.AddEntriesFrom(input, _repeated_freezeVariables_codec); + break; + } + case 104: { + ReplicasToAggregate = input.ReadInt32(); + break; + } + case 112: { + BatchQueueCapacity = input.ReadInt32(); + break; + } + case 120: { + NumBatchQueueThreads = input.ReadInt32(); + break; + } + case 128: { + PrefetchQueueCapacity = input.ReadInt32(); + break; + } + case 136: { + MergeMultipleLabelBoxes = input.ReadBool(); + break; + } + case 144: { + AddRegularizationLoss = input.ReadBool(); + break; + } + case 152: { + LoadAllDetectionCheckpointVars = input.ReadBool(); + break; + } + case 160: { + MaxNumberOfBoxes = input.ReadInt32(); + break; + } + case 168: { + UnpadGroundtruthTensors = input.ReadBool(); + break; + } + case 178: { + FineTuneCheckpointType = input.ReadString(); + break; + } + case 184: { + RetainOriginalImages = input.ReadBool(); + break; + } + case 192: { + UseMulticlassScores = input.ReadBool(); + break; + } + case 202: { + updateTrainableVariables_.AddEntriesFrom(input, _repeated_updateTrainableVariables_codec); + break; + } + case 208: { + UseBfloat16 = input.ReadBool(); + break; + } + case 216: { + SummarizeGradients = input.ReadBool(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Models/ObjectDetection/Utils/ConfigUtil.cs b/src/TensorFlowNET.Models/ObjectDetection/Utils/ConfigUtil.cs new file mode 100644 index 00000000..2a6a672e --- /dev/null +++ b/src/TensorFlowNET.Models/ObjectDetection/Utils/ConfigUtil.cs @@ -0,0 +1,101 @@ +using Protobuf.Text; +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using Tensorflow.Models.ObjectDetection.Protos; + +namespace Tensorflow.Models.ObjectDetection.Utils +{ + public class ConfigUtil + { + public static TrainEvalPipelineConfig get_configs_from_pipeline_file(string pipeline_config_path) + { + var config = 
File.ReadAllText(pipeline_config_path); + var pipeline_config = TrainEvalPipelineConfig.Parser.ParseText(config); + + return pipeline_config; + } + + public static ImageResizer get_image_resizer_config(DetectionModel model_config) + { + var meta_architecture = model_config.ModelCase; + + if (meta_architecture == DetectionModel.ModelOneofCase.FasterRcnn) + return model_config.FasterRcnn.ImageResizer; + else if (meta_architecture == DetectionModel.ModelOneofCase.Ssd) + return model_config.Ssd.ImageResizer; + + throw new Exception($"Unknown model type: {meta_architecture}"); + } + + public static (int, int) get_spatial_image_size(ImageResizer image_resizer_config) + { + if (image_resizer_config.ImageResizerOneofCase == ImageResizer.ImageResizerOneofOneofCase.FixedShapeResizer) + return (image_resizer_config.FixedShapeResizer.Height, image_resizer_config.FixedShapeResizer.Width); + else if (image_resizer_config.ImageResizerOneofCase == ImageResizer.ImageResizerOneofOneofCase.KeepAspectRatioResizer) + { + if (image_resizer_config.KeepAspectRatioResizer.PadToMaxDimension) + return (image_resizer_config.KeepAspectRatioResizer.MaxDimension, image_resizer_config.KeepAspectRatioResizer.MaxDimension); + else + return (-1, -1); + } + else if (image_resizer_config.ImageResizerOneofCase == ImageResizer.ImageResizerOneofOneofCase.IdentityResizer + || image_resizer_config.ImageResizerOneofCase == ImageResizer.ImageResizerOneofOneofCase.ConditionalShapeResizer) + { + return (-1, -1); + } + + throw new Exception("Unknown image resizer type."); + } + + public static Dictionary<string, object> create_configs_from_pipeline_proto(TrainEvalPipelineConfig pipeline_config) + { + var configs = new Dictionary<string, object>(StringComparer.OrdinalIgnoreCase); + + configs["model"] = pipeline_config.Model; + configs["train_config"] = pipeline_config.TrainConfig; + configs["train_input_config"] = pipeline_config.TrainInputReader; + configs["eval_config"] = pipeline_config.EvalConfig; + configs["eval_input_configs"] = pipeline_config.EvalInputReader; + + //# Keeps eval_input_config only for backwards compatibility. All clients should + //# read eval_input_configs instead.
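+ // Illustrative usage of the resulting dictionary (sketch only; "pipeline.config" is a placeholder path, and the cast assumes the "model" entry holds the DetectionModel message): + //   var configs = ConfigUtil.create_configs_from_pipeline_proto(ConfigUtil.get_configs_from_pipeline_file("pipeline.config")); + //   var numClasses = ConfigUtil.get_number_of_classes((DetectionModel)configs["model"]);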
+ if (pipeline_config.EvalInputReader != null && pipeline_config.EvalInputReader.Count > 0) + configs["eval_input_config"] = pipeline_config.EvalInputReader[0]; + + if (pipeline_config.GraphRewriter != null) + configs["graph_rewriter_config"] = pipeline_config.GraphRewriter; + + return configs; + } + + public static GraphRewriter get_graph_rewriter_config_from_file(string graph_rewriter_config_file) + { + throw new NotImplementedException(); + } + + public static int get_number_of_classes(DetectionModel model_config) + { + var meta_architecture = model_config.ModelCase; + + if (meta_architecture == DetectionModel.ModelOneofCase.FasterRcnn) + return model_config.FasterRcnn.NumClasses; + + if (meta_architecture == DetectionModel.ModelOneofCase.Ssd) + return model_config.Ssd.NumClasses; + + throw new Exception("Expected the model to be one of 'faster_rcnn' or 'ssd'."); + } + + public static Protos.Optimizer.OptimizerOneofCase get_optimizer_type(TrainConfig train_config) + { + return train_config.Optimizer.OptimizerCase; + } + + public static string get_learning_rate_type(Optimizer optimizer_config) + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Models/Slim/Nets/ResNetV1.cs b/src/TensorFlowNET.Models/Slim/Nets/ResNetV1.cs new file mode 100644 index 00000000..1d949434 --- /dev/null +++ b/src/TensorFlowNET.Models/Slim/Nets/ResNetV1.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Models.Slim.Nets +{ + public class ResNetV1 + { + public static void resnet_v1_101() + { + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Models/TensorFlowNET.Models.csproj b/src/TensorFlowNET.Models/TensorFlowNET.Models.csproj new file mode 100644 index 00000000..aae55be9 --- /dev/null +++ b/src/TensorFlowNET.Models/TensorFlowNET.Models.csproj @@ -0,0 +1,37 @@ + + + + netcoreapp2.2 + TensorFlow.Models + Tensorflow.Models + 0.0.1 + Haiping Chen + SciSharp STACK + https://github.com/SciSharp/TensorFlow.NET + https://github.com/SciSharp/TensorFlow.NET + git + TensorFlow + Models and examples built with TensorFlow. 
+ true + Apache2 + + + + + + + + + PreserveNewest + + + + + + + + + + + + diff --git a/src/TensorFlowText/README.md b/src/TensorFlowNET.Text/README.md similarity index 100% rename from src/TensorFlowText/README.md rename to src/TensorFlowNET.Text/README.md diff --git a/src/TensorFlowText/TensorFlowText.csproj b/src/TensorFlowNET.Text/TensorFlowNET.Text.csproj similarity index 91% rename from src/TensorFlowText/TensorFlowText.csproj rename to src/TensorFlowNET.Text/TensorFlowNET.Text.csproj index 92fee8a8..89ec8dd4 100644 --- a/src/TensorFlowText/TensorFlowText.csproj +++ b/src/TensorFlowNET.Text/TensorFlowNET.Text.csproj @@ -15,6 +15,8 @@ TensorFlow, SciSharp git https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 + TensorFlow.Text + TensorFlow.Text diff --git a/src/TensorFlowText/Tokenizer.cs b/src/TensorFlowNET.Text/Tokenizer.cs similarity index 100% rename from src/TensorFlowText/Tokenizer.cs rename to src/TensorFlowNET.Text/Tokenizer.cs diff --git a/src/TensorFlowNET.Visualization/Controllers/ValuesController.cs b/src/TensorFlowNET.Visualization/Controllers/ValuesController.cs deleted file mode 100644 index 37089ab5..00000000 --- a/src/TensorFlowNET.Visualization/Controllers/ValuesController.cs +++ /dev/null @@ -1,45 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Mvc; - -namespace TensorFlowNET.Visualization.Controllers -{ - [Route("api/[controller]")] - [ApiController] - public class ValuesController : ControllerBase - { - // GET api/values - [HttpGet] - public ActionResult> Get() - { - return new string[] { "value1", "value2" }; - } - - // GET api/values/5 - [HttpGet("{id}")] - public ActionResult Get(int id) - { - return "value"; - } - - // POST api/values - [HttpPost] - public void Post([FromBody] string value) - { - } - - // PUT api/values/5 - [HttpPut("{id}")] - public void Put(int id, [FromBody] string value) - { - } - - // DELETE api/values/5 - [HttpDelete("{id}")] - public void Delete(int id) - { - } - } -} diff --git a/src/TensorFlowNET.Visualization/Program.cs b/src/TensorFlowNET.Visualization/Program.cs deleted file mode 100644 index de1def2e..00000000 --- a/src/TensorFlowNET.Visualization/Program.cs +++ /dev/null @@ -1,24 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Threading.Tasks; -using Microsoft.AspNetCore; -using Microsoft.AspNetCore.Hosting; -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Logging; - -namespace TensorFlowNET.Visualization -{ - public class Program - { - public static void Main(string[] args) - { - CreateWebHostBuilder(args).Build().Run(); - } - - public static IWebHostBuilder CreateWebHostBuilder(string[] args) => - WebHost.CreateDefaultBuilder(args) - .UseStartup(); - } -} diff --git a/src/TensorFlowNET.Visualization/Startup.cs b/src/TensorFlowNET.Visualization/Startup.cs deleted file mode 100644 index c668b596..00000000 --- a/src/TensorFlowNET.Visualization/Startup.cs +++ /dev/null @@ -1,41 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.Hosting; -using Microsoft.AspNetCore.Mvc; -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Options; - -namespace TensorFlowNET.Visualization -{ - public class Startup - { - public Startup(IConfiguration 
configuration) - { - Configuration = configuration; - } - - public IConfiguration Configuration { get; } - - // This method gets called by the runtime. Use this method to add services to the container. - public void ConfigureServices(IServiceCollection services) - { - services.AddMvc().SetCompatibilityVersion(CompatibilityVersion.Version_2_2); - } - - // This method gets called by the runtime. Use this method to configure the HTTP request pipeline. - public void Configure(IApplicationBuilder app, IHostingEnvironment env) - { - if (env.IsDevelopment()) - { - app.UseDeveloperExceptionPage(); - } - - app.UseMvc(); - } - } -} diff --git a/src/TensorFlowNET.Visualization/TensorFlowNET.Visualization.csproj b/src/TensorFlowNET.Visualization/TensorFlowNET.Visualization.csproj deleted file mode 100644 index 81860dad..00000000 --- a/src/TensorFlowNET.Visualization/TensorFlowNET.Visualization.csproj +++ /dev/null @@ -1,17 +0,0 @@ - - - - netcoreapp2.1 - InProcess - - - - - - - - - - - - diff --git a/src/TensorFlowNET.Visualization/appsettings.Development.json b/src/TensorFlowNET.Visualization/appsettings.Development.json deleted file mode 100644 index e203e940..00000000 --- a/src/TensorFlowNET.Visualization/appsettings.Development.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "Logging": { - "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" - } - } -} diff --git a/src/TensorFlowNET.Visualization/appsettings.json b/src/TensorFlowNET.Visualization/appsettings.json deleted file mode 100644 index def9159a..00000000 --- a/src/TensorFlowNET.Visualization/appsettings.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "Logging": { - "LogLevel": { - "Default": "Warning" - } - }, - "AllowedHosts": "*" -} diff --git a/src/TensorFlowNet.Benchmarks/Program.cs b/src/TensorFlowNet.Benchmarks/Program.cs deleted file mode 100644 index ea7c2bde..00000000 --- a/src/TensorFlowNet.Benchmarks/Program.cs +++ /dev/null @@ -1,29 +0,0 @@ -using System; -using System.Reflection; -using BenchmarkDotNet.Configs; -using BenchmarkDotNet.Running; - -namespace TensorFlowBenchmark -{ - class Program - { - static void Main(string[] args) - { - if (args?.Length > 0) - { - for (int i = 0; i < args.Length; i++) - { - string name = $"TensorFlowBenchmark.{args[i]}"; - var type = Type.GetType(name); - BenchmarkRunner.Run(type); - } - } - else - { - BenchmarkSwitcher.FromAssembly(Assembly.GetExecutingAssembly()).Run(args, ManualConfig.Create(DefaultConfig.Instance).With(ConfigOptions.DisableOptimizationsValidator)); - } - - Console.ReadLine(); - } - } -} diff --git a/src/TensorFlowNet.Benchmarks/TensorBenchmark.cs b/src/TensorFlowNet.Benchmarks/TensorBenchmark.cs deleted file mode 100644 index d9386b99..00000000 --- a/src/TensorFlowNet.Benchmarks/TensorBenchmark.cs +++ /dev/null @@ -1,88 +0,0 @@ -using System; -using BenchmarkDotNet.Attributes; -using NumSharp; -using Tensorflow; -using static Tensorflow.Binding; - -namespace TensorFlowBenchmark -{ - [SimpleJob(launchCount: 1, warmupCount: 2, targetCount: 10)] - [MinColumn, MaxColumn, MeanColumn, MedianColumn] - public class TensorBenchmark - { - private double[] data; - - [GlobalSetup] - public void Setup() - { - data = new double[100]; - } - - [Benchmark] - public void ScalarTensor() - { - var g = new Graph(); - for (int i = 0; i < 100; i++) - { - using (var tensor = new Tensor(17.0)) - { - - } - } - } - - [Benchmark] - public unsafe void TensorFromFixedPtr() - { - var g = new Graph(); - for (int i = 0; i < 100; i++) - { - fixed (double* ptr = &data[0]) - { - using 
(var t = new Tensor((IntPtr)ptr, new long[] { data.Length }, tf.float64, 8 * data.Length)) - { - } - } - } - } - - [Benchmark] - public void TensorFromArray() - { - var g=new Graph(); - for (int i = 0; i < 100; i++) - { - using (var tensor = new Tensor(data)) - { - - } - } - } - - - [Benchmark] - public void TensorFromNDArray() - { - var g = new Graph(); - for (int i = 0; i < 1000; i++) - { - using (var tensor = new Tensor(new NDArray(data))) - { - - } - } - } - - //[Benchmark] - //public void Constant() - //{ - // for (int i = 0; i < 100; i++) - // { - // //var tensor = new Tensor(new NDArray(data)); - // var c = tf.constant(42.0); - // } - //} - - } -} - diff --git a/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj b/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj deleted file mode 100644 index 4618f06b..00000000 --- a/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj +++ /dev/null @@ -1,33 +0,0 @@ - - - - Exe - netcoreapp2.2 - true - TensorFlowBenchmark - TensorFlowBenchmark - 7.3 - - - - true - - - - true - - - - - - - - - - - - - - - - diff --git a/src/TensorFlowNet.Benchmarks/TensorFlowNET.Benchmark.csproj b/src/TensorFlowNet.Benchmarks/TensorFlowNET.Benchmark.csproj deleted file mode 100644 index a0af6db4..00000000 --- a/src/TensorFlowNet.Benchmarks/TensorFlowNET.Benchmark.csproj +++ /dev/null @@ -1,20 +0,0 @@ - - - - Exe - netcoreapp2.2 - - - - - - - - - - - - - - - diff --git a/src/TensorFlowNet.Benchmarks/Unmanaged/StructCastBenchmark.cs b/src/TensorFlowNet.Benchmarks/Unmanaged/StructCastBenchmark.cs deleted file mode 100644 index 5b3a0cd3..00000000 --- a/src/TensorFlowNet.Benchmarks/Unmanaged/StructCastBenchmark.cs +++ /dev/null @@ -1,76 +0,0 @@ -using System; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; -using BenchmarkDotNet.Attributes; -using Google.Protobuf.WellKnownTypes; -using NumSharp; -using Tensorflow; -using static Tensorflow.Binding; - -namespace TensorFlowBenchmark.Unmanaged -{ - public struct UnmanagedStruct - { - public int a; - public long b; - public UnmanagedStruct(int _) - { - a = 2; - b = 3; - } - } - - [SimpleJob(launchCount: 1, warmupCount: 2, targetCount: 10)] - [MinColumn, MaxColumn, MeanColumn, MedianColumn] - public unsafe class StructCastBenchmark - { - private static void EnsureIsUnmanaged(T _) where T : unmanaged - { } - - static StructCastBenchmark() //if UnmanagedStruct is not unmanaged struct then this will fail to compile. 
- => EnsureIsUnmanaged(new UnmanagedStruct()); - - private IntPtr data; - private void* dataptr; - - [GlobalSetup] - public void Setup() - { - data = Marshal.AllocHGlobal(Marshal.SizeOf()); - dataptr = data.ToPointer(); - } - - [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] - public void Marshal_PtrToStructure() - { - UnmanagedStruct _; - for (int i = 0; i < 10000; i++) - { - _ = Marshal.PtrToStructure(data); - } - } - - [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] - public void PointerCast() - { - var dptr = dataptr; - UnmanagedStruct _; - for (int i = 0; i < 10000; i++) - { - _ = *(UnmanagedStruct*) dptr; - } - } - - [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] - public void Unsafe_Read() - { - var dptr = dataptr; - UnmanagedStruct _; - for (int i = 0; i < 10000; i++) - { - _ = Unsafe.Read(dptr); - } - } - - } -} \ No newline at end of file diff --git a/tensorflowlib/runtimes/win-x64/native/tensorflow.zip b/tensorflowlib/runtimes/win-x64/native/tensorflow.zip deleted file mode 100644 index add0a05c..00000000 Binary files a/tensorflowlib/runtimes/win-x64/native/tensorflow.zip and /dev/null differ diff --git a/test/TensorFlowNET.Examples.FSharp/FunctionApproximation.fs b/test/TensorFlowNET.Examples.FSharp/FunctionApproximation.fs deleted file mode 100644 index 9b32dc63..00000000 --- a/test/TensorFlowNET.Examples.FSharp/FunctionApproximation.fs +++ /dev/null @@ -1,104 +0,0 @@ -module FunctionApproximation - -//reduced example from https://github.com/tirthajyoti/Machine-Learning-with-Python/blob/master/Function%20Approximation%20by%20Neural%20Network/Function%20approximation%20by%20linear%20model%20and%20deep%20network.ipynb - -open NumSharp -open Tensorflow -open System - -let run()= - - let N_points = 75 // Number of points for constructing function - let x_min = 1.0 // Min of the range of x (feature) - let x_max = 15.0 // Max of the range of x (feature) - let noise_mean = 0.0 // Mean of the Gaussian noise adder - let noise_sd = 10.0 // Std.Dev of the Gaussian noise adder - - let linspace points = [| for i in 0 .. 
(points - 1) -> x_min + (x_max - x_min)/(float)points * (float)i |] - - let func_trans(xAr:float []) = - xAr - |>Array.map (fun (x:float) -> (20.0 * x+3.0 * System.Math.Pow(x,2.0)+0.1 * System.Math.Pow(x,3.0))*sin(x)*exp(-0.1*x)) - - let X_raw = linspace N_points - let Y_raw = func_trans(X_raw) - let X_mtr = Array2D.init X_raw.Length 1 (fun i j -> X_raw.[i]) - let X = np.array(X_mtr) - - let noise_x = np.random.normal(noise_mean,noise_sd,N_points) - let y = np.array(Y_raw)+noise_x - - let X_train = X - let y_train = y - - let learning_rate = 0.00001 - let training_epochs = 35000 - - let n_input = 1 // Number of features - let n_output = 1 // Regression output is a number only - let n_hidden_layer_1 = 25 // Hidden layer 1 - let n_hidden_layer_2 = 25 // Hidden layer 2 - - let tf = Binding.New() - let x = tf.placeholder(tf.float64, new TensorShape(N_points,n_input)) - let y = tf.placeholder(tf.float64, new TensorShape(n_output)) - - - let weights = dict[ - "hidden_layer_1", tf.Variable(tf.random_normal([|n_input; n_hidden_layer_1|],dtype=tf.float64)) - "hidden_layer_2", tf.Variable(tf.random_normal([|n_hidden_layer_1; n_hidden_layer_2|],dtype=tf.float64)) - "out", tf.Variable(tf.random_normal([|n_hidden_layer_2; n_output|],dtype=tf.float64)) - ] - let biases = dict[ - "hidden_layer_1", tf.Variable(tf.random_normal([|n_hidden_layer_1|],dtype=tf.float64)) - "hidden_layer_2", tf.Variable(tf.random_normal([|n_hidden_layer_2|],dtype=tf.float64)) - "out", tf.Variable(tf.random_normal([|n_output|],dtype=tf.float64)) - ] - - - // Hidden layer with RELU activation - - let layer_1 = tf.add(tf.matmul(x, weights.["hidden_layer_1"]._AsTensor()),biases.["hidden_layer_1"]) - let layer_1 = tf.nn.relu(layer_1) - - let layer_2 = tf.add(tf.matmul(layer_1, weights.["hidden_layer_2"]._AsTensor()),biases.["hidden_layer_2"]) - let layer_2 = tf.nn.relu(layer_2) - - // Output layer with linear activation - let ops = tf.add(tf.matmul(layer_2, weights.["out"]._AsTensor()), biases.["out"]) - - // Define loss and optimizer - let cost = tf.reduce_mean(tf.square(tf.squeeze(ops)-y)) - - let gs = tf.Variable(1, trainable= false, name= "global_step") - - let optimizer = tf.train.GradientDescentOptimizer(learning_rate=(float32)learning_rate).minimize(cost,global_step = gs) - - let init = tf.global_variables_initializer() - - - Tensorflow.Binding.``tf_with``(tf.Session(), fun (sess:Session) -> - sess.run(init) |> ignore - // Loop over epochs - for epoch in [0..training_epochs] do - // Run optimization process (backprop) and cost function (to get loss value) - - let result=sess.run([|optimizer:>ITensorOrOperation; gs._AsTensor():>ITensorOrOperation; cost:>ITensorOrOperation|], new FeedItem(x, X_train), new FeedItem(y, y_train)) - - - let loss_value = (double) result.[2]; - - let step = (int) result.[1]; - - if epoch % 1000 = 0 then - sprintf "Step %d loss: %f" step loss_value |> Console.WriteLine - let w=sess.run(weights |> Array.ofSeq |> Array.map (fun pair -> pair.Value)) - let b = sess.run(biases |> Array.ofSeq |> Array.map (fun pair -> pair.Value)) - let yhat=sess.run([|ops:>ITensorOrOperation|],new FeedItem(x,X_train)) - for i in [0..(N_points-1)] do - sprintf "pred %f real: %f" ((double)(yhat.[0].[i].[0])) ((double)Y_raw.[i]) |> Console.WriteLine - ) - - - - diff --git a/test/TensorFlowNET.Examples.FSharp/Program.fs b/test/TensorFlowNET.Examples.FSharp/Program.fs deleted file mode 100644 index 3cbe7ea9..00000000 --- a/test/TensorFlowNET.Examples.FSharp/Program.fs +++ /dev/null @@ -1,8 +0,0 @@ -// Learn more about F# at 
http://fsharp.org - -open System - -[] -let main argv = - FunctionApproximation.run() - 0 // return an integer exit code diff --git a/test/TensorFlowNET.Examples.FSharp/TensorFlowNET.Examples.FSharp.fsproj b/test/TensorFlowNET.Examples.FSharp/TensorFlowNET.Examples.FSharp.fsproj deleted file mode 100644 index b509f678..00000000 --- a/test/TensorFlowNET.Examples.FSharp/TensorFlowNET.Examples.FSharp.fsproj +++ /dev/null @@ -1,21 +0,0 @@ - - - - Exe - netcoreapp2.2 - - - - - - - - - - - - - - - - diff --git a/test/TensorFlowNET.Examples/AudioProcessing/README.md b/test/TensorFlowNET.Examples/AudioProcessing/README.md deleted file mode 100644 index 5f282702..00000000 --- a/test/TensorFlowNET.Examples/AudioProcessing/README.md +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/test/TensorFlowNET.Examples/BasicEagerApi.cs b/test/TensorFlowNET.Examples/BasicEagerApi.cs deleted file mode 100644 index ac6e6e82..00000000 --- a/test/TensorFlowNET.Examples/BasicEagerApi.cs +++ /dev/null @@ -1,73 +0,0 @@ -using System; -using Tensorflow; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - /// - /// Basic introduction to TensorFlow's Eager API. - /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/1_Introduction/basic_eager_api.py - /// - public class BasicEagerApi : IExample - { - public bool Enabled { get; set; } = false; - public string Name => "Basic Eager"; - public bool IsImportingGraph { get; set; } = false; - - private Tensor a, b, c, d; - - public bool Run() - { - // Set Eager API - Console.WriteLine("Setting Eager mode..."); - tf.enable_eager_execution(); - - // Define constant tensors - Console.WriteLine("Define constant tensors"); - a = tf.constant(2); - Console.WriteLine($"a = {a}"); - b = tf.constant(3); - Console.WriteLine($"b = {b}"); - - // Run the operation without the need for tf.Session - Console.WriteLine("Running operations, without tf.Session"); - c = a + b; - Console.WriteLine($"a + b = {c}"); - d = a * b; - Console.WriteLine($"a * b = {d}"); - - // Full compatibility with Numpy - - return true; - } - - public void PrepareData() - { - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs b/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs deleted file mode 100644 index ecd9e27d..00000000 --- a/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs +++ /dev/null @@ -1,185 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-******************************************************************************/ - -using NumSharp; -using System; -using System.Diagnostics; -using Tensorflow; -using Tensorflow.Hub; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - /// - /// Implement K-Means algorithm with TensorFlow.NET, and apply it to classify - /// handwritten digit images. - /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/kmeans.py - /// - public class KMeansClustering : IExample - { - public bool Enabled { get; set; } = false; - public string Name => "K-means Clustering"; - public bool IsImportingGraph { get; set; } = true; - - public int? train_size = null; - public int validation_size = 5000; - public int? test_size = null; - public int batch_size = 1024; // The number of samples per batch - - Datasets mnist; - NDArray full_data_x; - int num_steps = 20; // Total steps to train - int k = 25; // The number of clusters - int num_classes = 10; // The 10 digits - int num_features = 784; // Each image is 28x28 pixels - - float accuray_test = 0f; - - public bool Run() - { - PrepareData(); - var graph = ImportGraph(); - using (var sess = tf.Session(graph)) - { - Train(sess); - } - - return accuray_test > 0.70; - } - - public void PrepareData() - { - var loader = new MnistModelLoader(); - - var setting = new ModelLoadSetting - { - TrainDir = ".resources/mnist", - OneHot = true, - TrainSize = train_size, - ValidationSize = validation_size, - TestSize = test_size, - ShowProgressInConsole = true - }; - - mnist = loader.LoadAsync(setting).Result; - - full_data_x = mnist.Train.Data; - - // download graph meta data - string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/kmeans.meta"; - loader.DownloadAsync(url, ".resources/graph", "kmeans.meta").Wait(); - } - - public Graph ImportGraph() - { - var graph = tf.Graph().as_default(); - - tf.train.import_meta_graph(".resources/graph/kmeans.meta"); - - return graph; - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - var graph = tf.Graph(); - - // Input images - Tensor X = graph.get_operation_by_name("Placeholder"); // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features)); - // Labels (for assigning a label to a centroid and testing) - Tensor Y = graph.get_operation_by_name("Placeholder_1"); // tf.placeholder(tf.float32, shape: new TensorShape(-1, num_classes)); - - // K-Means Parameters - //var kmeans = new KMeans(X, k, distance_metric: KMeans.COSINE_DISTANCE, use_mini_batch: true); - - // Build KMeans graph - //var training_graph = kmeans.training_graph(); - - var init_vars = tf.global_variables_initializer(); - Tensor init_op = graph.get_operation_by_name("cond/Merge"); - var train_op = graph.get_operation_by_name("group_deps"); - Tensor avg_distance = graph.get_operation_by_name("Mean"); - Tensor cluster_idx = graph.get_operation_by_name("Squeeze_1"); - NDArray result = null; - - sess.run(init_vars, new FeedItem(X, full_data_x)); - sess.run(init_op, new FeedItem(X, full_data_x)); - - // Training - var sw = new Stopwatch(); - - foreach (var i in range(1, num_steps + 1)) - { - sw.Restart(); - result = sess.run(new ITensorOrOperation[] { train_op, avg_distance, cluster_idx }, new FeedItem(X, full_data_x)); - sw.Stop(); - - if (i % 4 == 0 || i == 1) - print($"Step {i}, Avg Distance: {result[1]} Elapse: {sw.ElapsedMilliseconds}ms"); - } - - var idx = result[2].Data(); - - // Assign a label to each 
centroid - // Count total number of labels per centroid, using the label of each training - // sample to their closest centroid (given by 'idx') - var counts = np.zeros((k, num_classes), np.float32); - - sw.Start(); - foreach (var i in range(idx.Count)) - { - var x = mnist.Train.Labels[i]; - counts[idx[i]] += x; - } - - sw.Stop(); - print($"Assign a label to each centroid took {sw.ElapsedMilliseconds}ms"); - - // Assign the most frequent label to the centroid - var labels_map_array = np.argmax(counts, 1); - var labels_map = tf.convert_to_tensor(labels_map_array); - - // Evaluation ops - // Lookup: centroid_id -> label - var cluster_label = tf.nn.embedding_lookup(labels_map, cluster_idx); - - // Compute accuracy - var correct_prediction = tf.equal(cluster_label, tf.cast(tf.argmax(Y, 1), tf.int32)); - var cast = tf.cast(correct_prediction, tf.float32); - var accuracy_op = tf.reduce_mean(cast); - - // Test Model - var (test_x, test_y) = (mnist.Test.Data, mnist.Test.Labels); - result = sess.run(accuracy_op, new FeedItem(X, test_x), new FeedItem(Y, test_y)); - accuray_test = result; - print($"Test Accuracy: {accuray_test}"); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs b/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs deleted file mode 100644 index 9b33b28f..00000000 --- a/test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs +++ /dev/null @@ -1,145 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using NumSharp; -using System; -using Tensorflow; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - /// - /// A linear regression learning algorithm example using TensorFlow library. 
- /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/linear_regression.py - /// - public class LinearRegression : IExample - { - public bool Enabled { get; set; } = true; - public string Name => "Linear Regression"; - public bool IsImportingGraph { get; set; } = false; - - public int training_epochs = 1000; - - // Parameters - float learning_rate = 0.01f; - int display_step = 50; - - NumPyRandom rng = np.random; - NDArray train_X, train_Y; - int n_samples; - - public bool Run() - { - // Training Data - PrepareData(); - - // tf Graph Input - var X = tf.placeholder(tf.float32); - var Y = tf.placeholder(tf.float32); - - // Set model weights - // We can set a fixed init value in order to debug - // var rnd1 = rng.randn(); - // var rnd2 = rng.randn(); - var W = tf.Variable(-0.06f, name: "weight"); - var b = tf.Variable(-0.73f, name: "bias"); - - // Construct a linear model - var pred = tf.add(tf.multiply(X, W), b); - - // Mean squared error - var cost = tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * n_samples); - - // Gradient descent - // Note, minimize() knows to modify W and b because Variable objects are trainable=True by default - var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); - - // Initialize the variables (i.e. assign their default value) - var init = tf.global_variables_initializer(); - - // Start training - using (var sess = tf.Session()) - { - // Run the initializer - sess.run(init); - - // Fit all training data - for (int epoch = 0; epoch < training_epochs; epoch++) - { - foreach (var (x, y) in zip(train_X, train_Y)) - sess.run(optimizer, (X, x), (Y, y)); - - // Display logs per epoch step - if ((epoch + 1) % display_step == 0) - { - var c = sess.run(cost, (X, train_X), (Y, train_Y)); - Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); - } - } - - Console.WriteLine("Optimization Finished!"); - var training_cost = sess.run(cost, (X, train_X), (Y, train_Y)); - Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); - - // Testing example - var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f); - var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); - Console.WriteLine("Testing... 
(Mean square loss Comparison)"); - var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), - (X, test_X), (Y, test_Y)); - Console.WriteLine($"Testing cost={testing_cost}"); - var diff = Math.Abs((float)training_cost - (float)testing_cost); - Console.WriteLine($"Absolute mean square loss difference: {diff}"); - - return diff < 0.01; - } - } - - public void PrepareData() - { - train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f, - 7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f); - train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f, - 2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f); - n_samples = train_X.shape[0]; - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs deleted file mode 100644 index 3116e6f4..00000000 --- a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs +++ /dev/null @@ -1,190 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using NumSharp; -using System; -using System.Diagnostics; -using System.IO; -using Tensorflow; -using Tensorflow.Hub; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - /// - /// A logistic regression learning algorithm example using TensorFlow library. - /// This example is using the MNIST database of handwritten digits - /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/logistic_regression.py - /// - public class LogisticRegression : IExample - { - public bool Enabled { get; set; } = true; - public string Name => "Logistic Regression"; - public bool IsImportingGraph { get; set; } = false; - - - public int training_epochs = 10; - public int? train_size = null; - public int validation_size = 5000; - public int? 
test_size = null; - public int batch_size = 100; - - private float learning_rate = 0.01f; - private int display_step = 1; - - Datasets mnist; - - public bool Run() - { - PrepareData(); - - // tf Graph Input - var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784 - var y = tf.placeholder(tf.float32, new TensorShape(-1, 10)); // 0-9 digits recognition => 10 classes - - // Set model weights - var W = tf.Variable(tf.zeros(new Shape(784, 10))); - var b = tf.Variable(tf.zeros(new Shape(10))); - - // Construct model - var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax - - // Minimize error using cross entropy - var cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices: 1)); - - // Gradient Descent - var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); - - // Initialize the variables (i.e. assign their default value) - var init = tf.global_variables_initializer(); - - var sw = new Stopwatch(); - - using (var sess = tf.Session()) - { - // Run the initializer - sess.run(init); - - // Training cycle - foreach (var epoch in range(training_epochs)) - { - sw.Start(); - - var avg_cost = 0.0f; - var total_batch = mnist.Train.NumOfExamples / batch_size; - // Loop over all batches - foreach (var i in range(total_batch)) - { - var (batch_xs, batch_ys) = mnist.Train.GetNextBatch(batch_size); - // Run optimization op (backprop) and cost op (to get loss value) - (_, float c) = sess.run((optimizer, cost), - (x, batch_xs), - (y, batch_ys)); - - // Compute average loss - avg_cost += c / total_batch; - } - - sw.Stop(); - - // Display logs per epoch step - if ((epoch + 1) % display_step == 0) - print($"Epoch: {(epoch + 1):D4} Cost: {avg_cost:G9} Elapse: {sw.ElapsedMilliseconds}ms"); - - sw.Reset(); - } - - print("Optimization Finished!"); - // SaveModel(sess); - - // Test model - var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); - // Calculate accuracy - var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); - float acc = accuracy.eval(sess, (x, mnist.Test.Data), (y, mnist.Test.Labels)); - print($"Accuracy: {acc:F4}"); - - return acc > 0.9; - } - } - - public void PrepareData() - { - mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size, showProgressInConsole: true).Result; - } - - public void SaveModel(Session sess) - { - var saver = tf.train.Saver(); - var save_path = saver.save(sess, ".resources/logistic_regression/model.ckpt"); - tf.train.write_graph(sess.graph, ".resources/logistic_regression", "model.pbtxt", as_text: true); - - FreezeGraph.freeze_graph(input_graph: ".resources/logistic_regression/model.pbtxt", - input_saver: "", - input_binary: false, - input_checkpoint: ".resources/logistic_regression/model.ckpt", - output_node_names: "Softmax", - restore_op_name: "save/restore_all", - filename_tensor_name: "save/Const:0", - output_graph: ".resources/logistic_regression/model.pb", - clear_devices: true, - initializer_nodes: ""); - } - - public void Predict(Session sess) - { - var graph = new Graph().as_default(); - graph.Import(Path.Join(".resources/logistic_regression", "model.pb")); - - // restoring the model - // var saver = tf.train.import_meta_graph("logistic_regression/tensorflowModel.ckpt.meta"); - // saver.restore(sess, tf.train.latest_checkpoint('logistic_regression')); - var pred = graph.OperationByName("Softmax"); - var output = pred.outputs[0]; - var x = 
graph.OperationByName("Placeholder"); - var input = x.outputs[0]; - - // predict - var (batch_xs, batch_ys) = mnist.Train.GetNextBatch(10); - var results = sess.run(output, new FeedItem(input, batch_xs[np.arange(1)])); - - if (results[0].argmax() == (batch_ys[0] as NDArray).argmax()) - print("predicted OK!"); - else - throw new ValueError("predict error, should be 90% accuracy"); - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs b/test/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs deleted file mode 100644 index fde1653d..00000000 --- a/test/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs +++ /dev/null @@ -1,221 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using System; -using System.Collections.Generic; -using Tensorflow; -using NumSharp; -using static Tensorflow.Binding; -using System.IO; -using TensorFlowNET.Examples.Utility; - -namespace TensorFlowNET.Examples -{ - /// - /// https://github.com/nicolov/naive_bayes_tensorflow - /// - public class NaiveBayesClassifier : IExample - { - public bool Enabled { get; set; } = true; - public string Name => "Naive Bayes Classifier"; - public bool IsImportingGraph { get; set; } = false; - - public NDArray X, y; - public Normal dist { get; set; } - public bool Run() - { - PrepareData(); - - fit(X, y); - - // Create a regular grid and classify each point - float x_min = X.amin(0).Data()[0] - 0.5f; - float y_min = X.amin(0).Data()[1] - 0.5f; - float x_max = X.amax(0).Data()[1] + 0.5f; - float y_max = X.amax(0).Data()[1] + 0.5f; - - var (xx, yy) = np.meshgrid(np.linspace(x_min, x_max, 30), np.linspace(y_min, y_max, 30)); - using (var sess = tf.Session()) - { - //var samples = np.vstack(xx.ravel(), yy.ravel()); - //samples = np.transpose(samples); - var array = np.Load(Path.Join("nb", "nb_example.npy")); - var samples = np.array(array).astype(np.float32); - var Z = sess.run(predict(samples)); - } - - return true; - } - - public void fit(NDArray X, NDArray y) - { - var unique_y = np.unique(y); - - var dic = new Dictionary>>(); - // Init uy in dic - foreach (int uy in unique_y.Data()) - { - dic.Add(uy, new List>()); - } - // Separate training points by class - // Shape : nb_classes * nb_samples * nb_features - int maxCount = 0; - for (int i = 0; i < y.size; i++) - { - var curClass = y[i]; - var l = dic[curClass]; - var pair = new List(); - pair.Add(X[i,0]); - pair.Add(X[i, 1]); - l.Add(pair); - if (l.Count > maxCount) - { - maxCount = l.Count; - } - 
dic[curClass] = l; - } - float[,,] points = new float[dic.Count, maxCount, X.shape[1]]; - foreach (KeyValuePair>> kv in dic) - { - int j = (int) kv.Key; - for (int i = 0; i < maxCount; i++) - { - for (int k = 0; k < X.shape[1]; k++) - { - points[j, i, k] = kv.Value[i][k]; - } - } - - } - var points_by_class = np.array(points); - // estimate mean and variance for each class / feature - // shape : nb_classes * nb_features - var cons = tf.constant(points_by_class); - var tup = tf.nn.moments(cons, new int[]{1}); - var mean = tup.Item1; - var variance = tup.Item2; - - // Create a 3x2 univariate normal distribution with the - // Known mean and variance - var dist = tf.distributions.Normal(mean, tf.sqrt(variance)); - this.dist = dist; - } - - public Tensor predict(NDArray X) - { - if (dist == null) - { - throw new ArgumentNullException("cant not find the model (normal distribution)!"); - } - int nb_classes = (int) dist.scale().shape[0]; - int nb_features = (int)dist.scale().shape[1]; - - // Conditional probabilities log P(x|c) with shape - // (nb_samples, nb_classes) - var t1= ops.convert_to_tensor(X, TF_DataType.TF_FLOAT); - var t2 = ops.convert_to_tensor(new int[] { 1, nb_classes }); - Tensor tile = tf.tile(t1, t2); - var t3 = ops.convert_to_tensor(new int[] { -1, nb_classes, nb_features }); - Tensor r = tf.reshape(tile, t3); - var cond_probs = tf.reduce_sum(dist.log_prob(r), 2); - // uniform priors - float[] tem = new float[nb_classes]; - for (int i = 0; i < tem.Length; i++) - { - tem[i] = 1.0f / nb_classes; - } - var priors = np.log(np.array(tem)); - - // posterior log probability, log P(c) + log P(x|c) - var joint_likelihood = tf.add(ops.convert_to_tensor(priors, TF_DataType.TF_FLOAT), cond_probs); - // normalize to get (log)-probabilities - - var norm_factor = tf.reduce_logsumexp(joint_likelihood, new int[] { 1 }, keepdims: true); - var log_prob = joint_likelihood - norm_factor; - // exp to get the actual probabilities - return tf.exp(log_prob); - } - - public void PrepareData() - { - #region Training data - X = np.array(new float[,] { - {5.1f, 3.5f}, {4.9f, 3.0f}, {4.7f, 3.2f}, {4.6f, 3.1f}, {5.0f, 3.6f}, {5.4f, 3.9f}, - {4.6f, 3.4f}, {5.0f, 3.4f}, {4.4f, 2.9f}, {4.9f, 3.1f}, {5.4f, 3.7f}, {4.8f, 3.4f}, - {4.8f, 3.0f}, {4.3f, 3.0f}, {5.8f, 4.0f}, {5.7f, 4.4f}, {5.4f, 3.9f}, {5.1f, 3.5f}, - {5.7f, 3.8f}, {5.1f, 3.8f}, {5.4f, 3.4f}, {5.1f, 3.7f}, {5.1f, 3.3f}, {4.8f, 3.4f}, - {5.0f, 3.0f}, {5.0f, 3.4f}, {5.2f, 3.5f}, {5.2f, 3.4f}, {4.7f, 3.2f}, {4.8f, 3.1f}, - {5.4f, 3.4f}, {5.2f, 4.1f}, {5.5f, 4.2f}, {4.9f, 3.1f}, {5.0f, 3.2f}, {5.5f, 3.5f}, - {4.9f, 3.6f}, {4.4f, 3.0f}, {5.1f, 3.4f}, {5.0f, 3.5f}, {4.5f, 2.3f}, {4.4f, 3.2f}, - {5.0f, 3.5f}, {5.1f, 3.8f}, {4.8f, 3.0f}, {5.1f, 3.8f}, {4.6f, 3.2f}, {5.3f, 3.7f}, - {5.0f, 3.3f}, {7.0f, 3.2f}, {6.4f, 3.2f}, {6.9f, 3.1f}, {5.5f, 2.3f}, {6.5f, 2.8f}, - {5.7f, 2.8f}, {6.3f, 3.3f}, {4.9f, 2.4f}, {6.6f, 2.9f}, {5.2f, 2.7f}, {5.0f, 2.0f}, - {5.9f, 3.0f}, {6.0f, 2.2f}, {6.1f, 2.9f}, {5.6f, 2.9f}, {6.7f, 3.1f}, {5.6f, 3.0f}, - {5.8f, 2.7f}, {6.2f, 2.2f}, {5.6f, 2.5f}, {5.9f, 3.0f}, {6.1f, 2.8f}, {6.3f, 2.5f}, - {6.1f, 2.8f}, {6.4f, 2.9f}, {6.6f, 3.0f}, {6.8f, 2.8f}, {6.7f, 3.0f}, {6.0f, 2.9f}, - {5.7f, 2.6f}, {5.5f, 2.4f}, {5.5f, 2.4f}, {5.8f, 2.7f}, {6.0f, 2.7f}, {5.4f, 3.0f}, - {6.0f, 3.4f}, {6.7f, 3.1f}, {6.3f, 2.3f}, {5.6f, 3.0f}, {5.5f, 2.5f}, {5.5f, 2.6f}, - {6.1f, 3.0f}, {5.8f, 2.6f}, {5.0f, 2.3f}, {5.6f, 2.7f}, {5.7f, 3.0f}, {5.7f, 2.9f}, - {6.2f, 2.9f}, {5.1f, 2.5f}, {5.7f, 2.8f}, {6.3f, 3.3f}, {5.8f, 2.7f}, {7.1f, 3.0f}, - {6.3f, 2.9f}, 
{6.5f, 3.0f}, {7.6f, 3.0f}, {4.9f, 2.5f}, {7.3f, 2.9f}, {6.7f, 2.5f}, - {7.2f, 3.6f}, {6.5f, 3.2f}, {6.4f, 2.7f}, {6.8f, 3.0f}, {5.7f, 2.5f}, {5.8f, 2.8f}, - {6.4f, 3.2f}, {6.5f, 3.0f}, {7.7f, 3.8f}, {7.7f, 2.6f}, {6.0f, 2.2f}, {6.9f, 3.2f}, - {5.6f, 2.8f}, {7.7f, 2.8f}, {6.3f, 2.7f}, {6.7f, 3.3f}, {7.2f, 3.2f}, {6.2f, 2.8f}, - {6.1f, 3.0f}, {6.4f, 2.8f}, {7.2f, 3.0f}, {7.4f, 2.8f}, {7.9f, 3.8f}, {6.4f, 2.8f}, - {6.3f, 2.8f}, {6.1f, 2.6f}, {7.7f, 3.0f}, {6.3f, 3.4f}, {6.4f, 3.1f}, {6.0f, 3.0f}, - {6.9f, 3.1f}, {6.7f, 3.1f}, {6.9f, 3.1f}, {5.8f, 2.7f}, {6.8f, 3.2f}, {6.7f, 3.3f}, - {6.7f, 3.0f}, {6.3f, 2.5f}, {6.5f, 3.0f}, {6.2f, 3.4f}, {5.9f, 3.0f}, {5.8f, 3.0f}}); - - y = np.array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2); - - - string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/nb_example.npy"; - Web.Download(url, "nb", "nb_example.npy"); - #endregion - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs b/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs deleted file mode 100644 index 22607c3d..00000000 --- a/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs +++ /dev/null @@ -1,118 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using NumSharp; -using System; -using Tensorflow; -using Tensorflow.Hub; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - /// - /// A nearest neighbor learning algorithm example - /// This example is using the MNIST database of handwritten digits - /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/nearest_neighbor.py - /// - public class NearestNeighbor : IExample - { - public bool Enabled { get; set; } = true; - public string Name => "Nearest Neighbor"; - Datasets mnist; - NDArray Xtr, Ytr, Xte, Yte; - public int? TrainSize = null; - public int ValidationSize = 5000; - public int? 
TestSize = null; - public bool IsImportingGraph { get; set; } = false; - - - public bool Run() - { - // tf Graph Input - var xtr = tf.placeholder(tf.float32, new TensorShape(-1, 784)); - var xte = tf.placeholder(tf.float32, new TensorShape(784)); - - // Nearest Neighbor calculation using L1 Distance - // Calculate L1 Distance - var distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices: 1); - // Prediction: Get min distance index (Nearest neighbor) - var pred = tf.arg_min(distance, 0); - - float accuracy = 0f; - // Initialize the variables (i.e. assign their default value) - var init = tf.global_variables_initializer(); - using (var sess = tf.Session()) - { - // Run the initializer - sess.run(init); - - PrepareData(); - - foreach(int i in range(Xte.shape[0])) - { - // Get nearest neighbor - long nn_index = sess.run(pred, (xtr, Xtr), (xte, Xte[i])); - // Get nearest neighbor class label and compare it to its true label - int index = (int)nn_index; - - if (i % 10 == 0 || i == 0) - print($"Test {i} Prediction: {np.argmax(Ytr[index])} True Class: {np.argmax(Yte[i])}"); - - // Calculate accuracy - if (np.argmax(Ytr[index]) == np.argmax(Yte[i])) - accuracy += 1f/ Xte.shape[0]; - } - - print($"Accuracy: {accuracy}"); - } - - return accuracy > 0.8; - } - - public void PrepareData() - { - mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize, showProgressInConsole: true).Result; - // In this example, we limit mnist data - (Xtr, Ytr) = mnist.Train.GetNextBatch(TrainSize == null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) - (Xte, Yte) = mnist.Test.GetNextBatch(TestSize == null ? 200 : TestSize.Value / 100); // 200 for testing - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs b/test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs deleted file mode 100644 index e9bcf0cb..00000000 --- a/test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs +++ /dev/null @@ -1,190 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-******************************************************************************/ - -using System; -using NumSharp; -using Tensorflow; -using TensorFlowNET.Examples.Utility; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - /// - /// Simple vanilla neural net solving the famous XOR problem - /// https://github.com/amygdala/tensorflow-workshop/blob/master/workshop_sections/getting_started/xor/README.md - /// - public class NeuralNetXor : IExample - { - public bool Enabled { get; set; } = true; - public string Name => "NN XOR"; - public bool IsImportingGraph { get; set; } = false; - - public int num_steps = 10000; - - private NDArray data; - - private (Operation, Tensor, Tensor) make_graph(Tensor features,Tensor labels, int num_hidden = 8) - { - var stddev = 1 / Math.Sqrt(2); - var hidden_weights = tf.Variable(tf.truncated_normal(new int []{2, num_hidden}, seed:1, stddev: (float) stddev )); - - // Shape [4, num_hidden] - var hidden_activations = tf.nn.relu(tf.matmul(features, hidden_weights)); - - var output_weights = tf.Variable(tf.truncated_normal( - new[] {num_hidden, 1}, - seed: 17, - stddev: (float) (1 / Math.Sqrt(num_hidden)) - )); - - // Shape [4, 1] - var logits = tf.matmul(hidden_activations, output_weights); - - // Shape [4] - var predictions = tf.sigmoid(tf.squeeze(logits)); - var loss = tf.reduce_mean(tf.square(predictions - tf.cast(labels, tf.float32)), name:"loss"); - - var gs = tf.Variable(0, trainable: false, name: "global_step"); - var train_op = tf.train.GradientDescentOptimizer(0.2f).minimize(loss, global_step: gs); - - return (train_op, loss, gs); - } - - public bool Run() - { - PrepareData(); - float loss_value = 0; - if (IsImportingGraph) - loss_value = RunWithImportedGraph(); - else - loss_value = RunWithBuiltGraph(); - - return loss_value < 0.0628; - } - - private float RunWithImportedGraph() - { - var graph = tf.Graph().as_default(); - - tf.train.import_meta_graph("graph/xor.meta"); - - Tensor features = graph.get_operation_by_name("Placeholder"); - Tensor labels = graph.get_operation_by_name("Placeholder_1"); - Tensor loss = graph.get_operation_by_name("loss"); - Tensor train_op = graph.get_operation_by_name("train_op"); - Tensor global_step = graph.get_operation_by_name("global_step"); - - var init = tf.global_variables_initializer(); - float loss_value = 0; - // Start tf session - using (var sess = tf.Session(graph)) - { - sess.run(init); - var step = 0; - - var y_ = np.array(new int[] { 1, 0, 0, 1 }, dtype: np.int32); - while (step < num_steps) - { - // original python: - //_, step, loss_value = sess.run( - // [train_op, gs, loss], - // feed_dict={features: xy, labels: y_} - // ) - (_, step, loss_value) = sess.run((train_op, global_step, loss), (features, data), (labels, y_)); - if (step == 1 || step % 1000 == 0) - Console.WriteLine($"Step {step} loss: {loss_value}"); - } - Console.WriteLine($"Final loss: {loss_value}"); - } - - return loss_value; - } - - private float RunWithBuiltGraph() - { - var graph = tf.Graph().as_default(); - - var features = tf.placeholder(tf.float32, new TensorShape(4, 2)); - var labels = tf.placeholder(tf.int32, new TensorShape(4)); - - var (train_op, loss, gs) = make_graph(features, labels); - - var init = tf.global_variables_initializer(); - - float loss_value = 0; - // Start tf session - using (var sess = tf.Session(graph)) - { - sess.run(init); - var step = 0; - - var y_ = np.array(new int[] { 1, 0, 0, 1 }, dtype: np.int32); - while (step < num_steps) - { - (_, step, loss_value) = sess.run((train_op, gs, 
loss), (features, data), (labels, y_)); - if (step == 1 || step % 1000 == 0) - Console.WriteLine($"Step {step} loss: {loss_value}"); - } - Console.WriteLine($"Final loss: {loss_value}"); - } - - return loss_value; - } - - public void PrepareData() - { - data = new float[,] - { - {1, 0 }, - {1, 1 }, - {0, 0 }, - {0, 1 } - }; - - if (IsImportingGraph) - { - // download graph meta data - string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/xor.meta"; - Web.Download(url, "graph", "xor.meta"); - } - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/BasicOperations.cs b/test/TensorFlowNET.Examples/BasicOperations.cs deleted file mode 100644 index ff382009..00000000 --- a/test/TensorFlowNET.Examples/BasicOperations.cs +++ /dev/null @@ -1,187 +0,0 @@ -using NumSharp; -using System; -using Tensorflow; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - /// - /// Basic Operations example using TensorFlow library. - /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/1_Introduction/basic_operations.py - /// - public class BasicOperations : IExample - { - public bool Enabled { get; set; } = true; - public string Name => "Basic Operations"; - public bool IsImportingGraph { get; set; } = false; - - private Session sess; - - public bool Run() - { - // Basic constant operations - // The value returned by the constructor represents the output - // of the Constant op. - var a = tf.constant(2); - var b = tf.constant(3); - - // Launch the default graph. - using (sess = tf.Session()) - { - Console.WriteLine("a=2, b=3"); - Console.WriteLine($"Addition with constants: {sess.run(a + b)}"); - Console.WriteLine($"Multiplication with constants: {sess.run(a * b)}"); - } - - // Basic Operations with variable as graph input - // The value returned by the constructor represents the output - // of the Variable op. (define as input when running session) - // tf Graph input - a = tf.placeholder(tf.int16); - b = tf.placeholder(tf.int16); - - // Define some operations - var add = tf.add(a, b); - var mul = tf.multiply(a, b); - - // Launch the default graph. - using(sess = tf.Session()) - { - var feed_dict = new FeedItem[] - { - new FeedItem(a, (short)2), - new FeedItem(b, (short)3) - }; - // Run every operation with variable input - Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)}"); - Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)}"); - } - - // ---------------- - // More in details: - // Matrix Multiplication from TensorFlow official tutorial - - // Create a Constant op that produces a 1x2 matrix. The op is - // added as a node to the default graph. - // - // The value returned by the constructor represents the output - // of the Constant op. - var nd1 = np.array(3, 3).reshape(1, 2); - var matrix1 = tf.constant(nd1); - - // Create another Constant that produces a 2x1 matrix. - var nd2 = np.array(2, 2).reshape(2, 1); - var matrix2 = tf.constant(nd2); - - // Create a Matmul op that takes 'matrix1' and 'matrix2' as inputs. 
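The constant matrix multiplication walked through above condenses to a short, self-contained program. The following is an illustrative sketch rather than part of the deleted file, and it assumes the same pre-1.0 TensorFlow.NET graph API and NumSharp binding that the example uses.

```cs
// Illustrative sketch only; mirrors the deleted BasicOperations example and
// assumes the pre-1.0 TensorFlow.NET graph API referenced throughout this diff.
using System;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;

class MatMulSketch
{
    static void Main()
    {
        // 1x2 and 2x1 constant matrices, exactly as in the example above.
        var matrix1 = tf.constant(np.array(3, 3).reshape(1, 2));
        var matrix2 = tf.constant(np.array(2, 2).reshape(2, 1));

        // The matmul node multiplies them; running it yields a 1x1 result.
        var product = tf.matmul(matrix1, matrix2);

        using (var sess = tf.Session())
        {
            var result = sess.run(product);
            Console.WriteLine(result.ToString()); // ==> [[12]]
        }
    }
}
```

The same build-a-graph-then-run-a-session pattern underlies every example removed in this diff.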
- // The returned value, 'product', represents the result of the matrix - // multiplication. - var product = tf.matmul(matrix1, matrix2); - - // To run the matmul op we call the session 'run()' method, passing 'product' - // which represents the output of the matmul op. This indicates to the call - // that we want to get the output of the matmul op back. - // - // All inputs needed by the op are run automatically by the session. They - // typically are run in parallel. - // - // The call 'run(product)' thus causes the execution of threes ops in the - // graph: the two constants and matmul. - // - // The output of the op is returned in 'result' as a numpy `ndarray` object. - using (sess = tf.Session()) - { - var result = sess.run(product); - Console.WriteLine(result.ToString()); // ==> [[ 12.]] - }; - - // `BatchMatMul` is actually embedded into the `MatMul` operation on the tf.dll side. Every time we ask - // for a multiplication between matrices with rank > 2, the first rank - 2 dimensions are checked to be consistent - // across the two matrices and a common matrix multiplication is done on the residual 2 dimensions. - // - // np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9]).reshape(3, 3, 3) - // array([[[1, 2, 3], - // [4, 5, 6], - // [7, 8, 9]], - // - // [[1, 2, 3], - // [4, 5, 6], - // [7, 8, 9]], - // - // [[1, 2, 3], - // [4, 5, 6], - // [7, 8, 9]]]) - var firstTensor = tf.convert_to_tensor( - np.reshape( - np.array(1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9), - 3, 3, 3)); - // - // np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0]).reshape(3,3,2) - // array([[[0, 1], - // [0, 1], - // [0, 1]], - // - // [[0, 1], - // [0, 0], - // [1, 0]], - // - // [[1, 0], - // [1, 0], - // [1, 0]]]) - var secondTensor = tf.convert_to_tensor( - np.reshape( - np.array(0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0), - 3, 3, 2)); - var batchMul = tf.batch_matmul(firstTensor, secondTensor); - var checkTensor = np.array(0, 6, 0, 15, 0, 24, 3, 1, 6, 4, 9, 7, 6, 0, 15, 0, 24, 0); - using (var sess = tf.Session()) - { - var result = sess.run(batchMul); - Console.WriteLine(result.ToString()); - // - // ==> array([[[0, 6], - // [0, 15], - // [0, 24]], - // - // [[ 3, 1], - // [ 6, 4], - // [ 9, 7]], - // - // [[ 6, 0], - // [15, 0], - // [24, 0]]]) - return np.reshape(result, 18) - .array_equal(checkTensor); - } - } - - public void PrepareData() - { - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/HelloWorld.cs b/test/TensorFlowNET.Examples/HelloWorld.cs deleted file mode 100644 index 28c4b093..00000000 --- a/test/TensorFlowNET.Examples/HelloWorld.cs +++ /dev/null @@ -1,66 +0,0 @@ -using System; -using Tensorflow; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - /// - /// Simple hello world using TensorFlow - /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/1_Introduction/helloworld.py - /// - public class HelloWorld : IExample - { - public bool Enabled { get; set; } = true; - public string Name => "Hello World"; - public bool IsImportingGraph { get; 
set; } = false; - - public bool Run() - { - /* Create a Constant op - The op is added as a node to the default graph. - - The value returned by the constructor represents the output - of the Constant op. */ - var str = "Hello, TensorFlow.NET!"; - var hello = tf.constant(str); - - // Start tf session - using (var sess = tf.Session()) - { - // Run the op - var result = sess.run(hello); - Console.WriteLine(result.ToString()); - return result.ToString().Equals(str); - } - } - - public void PrepareData() - { - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/IExample.cs b/test/TensorFlowNET.Examples/IExample.cs deleted file mode 100644 index f8cda9b9..00000000 --- a/test/TensorFlowNET.Examples/IExample.cs +++ /dev/null @@ -1,59 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using Tensorflow; - -namespace TensorFlowNET.Examples -{ - /// - /// Interface of Example project - /// All example should implement IExample so the entry program will find it. - /// - public interface IExample - { - /// - /// True to run example - /// - bool Enabled { get; set; } - - /// - /// Set true to import the computation graph instead of building it. - /// - bool IsImportingGraph { get; set; } - - string Name { get; } - - bool Run(); - - /// - /// Build dataflow graph, train and predict - /// - /// - void Train(Session sess); - void Test(Session sess); - - void Predict(Session sess); - - Graph ImportGraph(); - - Graph BuildGraph(); - - /// - /// Prepare dataset - /// - void PrepareData(); - } -} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/CIFAR10-CNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/CIFAR10-CNN.cs deleted file mode 100644 index 477ce85e..00000000 --- a/test/TensorFlowNET.Examples/ImageProcessing/CIFAR10-CNN.cs +++ /dev/null @@ -1,73 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using System; -using System.Collections.Generic; -using System.Text; -using Tensorflow; -using TensorFlowDatasets; - -namespace TensorFlowNET.Examples -{ - /// - /// https://www.tensorflow.org/tutorials/images/deep_cnn - /// - public class CIFAR10_CNN : IExample - { - public bool Enabled { get; set; } = true; - public bool IsImportingGraph { get; set; } = false; - - public string Name => "CIFAR-10 CNN"; - - public bool Run() - { - PrepareData(); - - return true; - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void PrepareData() - { - var tfds = new DatasetBuilder(); - tfds.download_and_prepare(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs deleted file mode 100644 index ab387fa5..00000000 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs +++ /dev/null @@ -1,338 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using NumSharp; -using System; -using System.Diagnostics; -using Tensorflow; -using Tensorflow.Hub; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - /// - /// Convolutional Neural Network classifier for Hand Written Digits - /// CNN architecture with two convolutional layers, followed by two fully-connected layers at the end. - /// Use Stochastic Gradient Descent (SGD) optimizer. - /// http://www.easy-tf.com/tf-tutorials/convolutional-neural-nets-cnns/cnn1 - /// - public class DigitRecognitionCNN : IExample - { - public bool Enabled { get; set; } = true; - public bool IsImportingGraph { get; set; } = false; - - public string Name => "MNIST CNN"; - - string logs_path = "logs"; - - const int img_h = 28, img_w = 28; // MNIST images are 28x28 - int n_classes = 10; // Number of classes, one class per digit - int n_channels = 1; - - // Hyper-parameters - int epochs = 5; // accuracy > 98% - int batch_size = 100; - float learning_rate = 0.001f; - Datasets mnist; - - // Network configuration - // 1st Convolutional Layer - int filter_size1 = 5; // Convolution filters are 5 x 5 pixels. - int num_filters1 = 16; // There are 16 of these filters. 
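For readers skimming these hyper-parameters: with 28x28 single-channel inputs and the two SAME-padded 2x2, stride-2 max-pool stages defined further down, the spatial size halves twice before the flatten step. The arithmetic below is an added illustration with placeholder variable names; it only restates values already declared in this deleted file.

```cs
using System;

// Illustrative arithmetic only; the names are placeholders that restate the
// hyper-parameters of the deleted DigitRecognitionCNN example.
class CnnShapeSketch
{
    static void Main()
    {
        int imgH = 28, imgW = 28;   // MNIST images are 28x28
        int numFilters2 = 32;       // filters in the 2nd convolutional layer

        // Each 2x2 max-pool with stride 2 and SAME padding halves the spatial size:
        // 28x28 -> 14x14 -> 7x7.
        int pooledH = imgH / 2 / 2, pooledW = imgW / 2 / 2;

        // flatten_layer therefore hands 7 * 7 * 32 = 1568 features to the FC1 layer.
        Console.WriteLine(pooledH * pooledW * numFilters2); // prints 1568
    }
}
```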
- int stride1 = 1; // The stride of the sliding window - - // 2nd Convolutional Layer - int filter_size2 = 5; // Convolution filters are 5 x 5 pixels. - int num_filters2 = 32;// There are 32 of these filters. - int stride2 = 1; // The stride of the sliding window - - // Fully-connected layer. - int h1 = 128; // Number of neurons in fully-connected layer. - - Tensor x, y; - Tensor loss, accuracy, cls_prediction; - Operation optimizer; - - int display_freq = 100; - float accuracy_test = 0f; - float loss_test = 1f; - - NDArray x_train, y_train; - NDArray x_valid, y_valid; - NDArray x_test, y_test; - - public bool Run() - { - PrepareData(); - BuildGraph(); - - using (var sess = tf.Session()) - { - Train(sess); - Test(sess); - } - - return loss_test < 0.05 && accuracy_test > 0.98; - } - - public Graph BuildGraph() - { - var graph = new Graph().as_default(); - - tf_with(tf.name_scope("Input"), delegate - { - // Placeholders for inputs (x) and outputs(y) - x = tf.placeholder(tf.float32, shape: (-1, img_h, img_w, n_channels), name: "X"); - y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y"); - }); - - var conv1 = conv_layer(x, filter_size1, num_filters1, stride1, name: "conv1"); - var pool1 = max_pool(conv1, ksize: 2, stride: 2, name: "pool1"); - var conv2 = conv_layer(pool1, filter_size2, num_filters2, stride2, name: "conv2"); - var pool2 = max_pool(conv2, ksize: 2, stride: 2, name: "pool2"); - var layer_flat = flatten_layer(pool2); - var fc1 = fc_layer(layer_flat, h1, "FC1", use_relu: true); - var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false); - - tf_with(tf.variable_scope("Train"), delegate - { - tf_with(tf.variable_scope("Loss"), delegate - { - loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits), name: "loss"); - }); - - tf_with(tf.variable_scope("Optimizer"), delegate - { - optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss); - }); - - tf_with(tf.variable_scope("Accuracy"), delegate - { - var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred"); - accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy"); - }); - - tf_with(tf.variable_scope("Prediction"), delegate - { - cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions"); - }); - }); - - return graph; - } - - public void Train(Session sess) - { - // Number of training iterations in each epoch - var num_tr_iter = y_train.shape[0] / batch_size; - - var init = tf.global_variables_initializer(); - sess.run(init); - - float loss_val = 100.0f; - float accuracy_val = 0f; - - var sw = new Stopwatch(); - sw.Start(); - foreach (var epoch in range(epochs)) - { - print($"Training epoch: {epoch + 1}"); - // Randomly shuffle the training data at the beginning of each epoch - (x_train, y_train) = mnist.Randomize(x_train, y_train); - - foreach (var iteration in range(num_tr_iter)) - { - var start = iteration * batch_size; - var end = (iteration + 1) * batch_size; - var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end); - - // Run optimization op (backprop) - sess.run(optimizer, (x, x_batch), (y, y_batch)); - - if (iteration % display_freq == 0) - { - // Calculate and display the batch loss and accuracy - var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); - loss_val = result[0]; - accuracy_val = result[1]; - print($"iter {iteration.ToString("000")}: 
Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")} {sw.ElapsedMilliseconds}ms"); - sw.Restart(); - } - } - - // Run validation after every epoch - (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_valid), (y, y_valid)); - print("---------------------------------------------------------"); - print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}"); - print("---------------------------------------------------------"); - } - } - - public void Test(Session sess) - { - (loss_test, accuracy_test) = sess.run((loss, accuracy), (x, x_test), (y, y_test)); - print("---------------------------------------------------------"); - print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}"); - print("---------------------------------------------------------"); - } - - /// - /// Create a 2D convolution layer - /// - /// input from previous layer - /// size of each filter - /// number of filters(or output feature maps) - /// filter stride - /// layer name - /// The output array - private Tensor conv_layer(Tensor x, int filter_size, int num_filters, int stride, string name) - { - return tf_with(tf.variable_scope(name), delegate { - - var num_in_channel = x.shape[x.NDims - 1]; - var shape = new[] { filter_size, filter_size, num_in_channel, num_filters }; - var W = weight_variable("W", shape); - // var tf.summary.histogram("weight", W); - var b = bias_variable("b", new[] { num_filters }); - // tf.summary.histogram("bias", b); - var layer = tf.nn.conv2d(x, W, - strides: new[] { 1, stride, stride, 1 }, - padding: "SAME"); - layer += b; - return tf.nn.relu(layer); - }); - } - - /// - /// Create a max pooling layer - /// - /// input to max-pooling layer - /// size of the max-pooling filter - /// stride of the max-pooling filter - /// layer name - /// The output array - private Tensor max_pool(Tensor x, int ksize, int stride, string name) - { - return tf.nn.max_pool(x, - ksize: new[] { 1, ksize, ksize, 1 }, - strides: new[] { 1, stride, stride, 1 }, - padding: "SAME", - name: name); - } - - /// - /// Flattens the output of the convolutional layer to be fed into fully-connected layer - /// - /// input array - /// flattened array - private Tensor flatten_layer(Tensor layer) - { - return tf_with(tf.variable_scope("Flatten_layer"), delegate - { - var layer_shape = layer.TensorShape; - var num_features = layer_shape[new Slice(1, 4)].size; - var layer_flat = tf.reshape(layer, new[] { -1, num_features }); - - return layer_flat; - }); - } - - /// - /// Create a weight variable with appropriate initialization - /// - /// - /// - /// - private RefVariable weight_variable(string name, int[] shape) - { - var initer = tf.truncated_normal_initializer(stddev: 0.01f); - return tf.get_variable(name, - dtype: tf.float32, - shape: shape, - initializer: initer); - } - - /// - /// Create a bias variable with appropriate initialization - /// - /// - /// - /// - private RefVariable bias_variable(string name, int[] shape) - { - var initial = tf.constant(0f, shape: shape, dtype: tf.float32); - return tf.get_variable(name, - dtype: tf.float32, - initializer: initial); - } - - /// - /// Create a fully-connected layer - /// - /// input from previous layer - /// number of hidden units in the fully-connected layer - /// layer name - /// boolean to add ReLU non-linearity (or not) - /// The output array - private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true) - { - 
return tf_with(tf.variable_scope(name), delegate - { - var in_dim = x.shape[1]; - - var W = weight_variable("W_" + name, shape: new[] { in_dim, num_units }); - var b = bias_variable("b_" + name, new[] { num_units }); - - var layer = tf.matmul(x, W) + b; - if (use_relu) - layer = tf.nn.relu(layer); - - return layer; - }); - } - - public void PrepareData() - { - mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result; - (x_train, y_train) = Reformat(mnist.Train.Data, mnist.Train.Labels); - (x_valid, y_valid) = Reformat(mnist.Validation.Data, mnist.Validation.Labels); - (x_test, y_test) = Reformat(mnist.Test.Data, mnist.Test.Labels); - - print("Size of:"); - print($"- Training-set:\t\t{len(mnist.Train.Data)}"); - print($"- Validation-set:\t{len(mnist.Validation.Data)}"); - } - - /// - /// Reformats the data to the format acceptable for convolutional layers - /// - /// - /// - /// - private (NDArray, NDArray) Reformat(NDArray x, NDArray y) - { - var (img_size, num_ch, num_class) = (np.sqrt(x.shape[1]).astype(np.int32), 1, len(np.unique(np.argmax(y, 1)))); - var dataset = x.reshape(x.shape[0], img_size, img_size, num_ch).astype(np.float32); - //y[0] = np.arange(num_class) == y[0]; - //var labels = (np.arange(num_class) == y.reshape(y.shape[0], 1, y.shape[1])).astype(np.float32); - return (dataset, y); - } - - public Graph ImportGraph() => throw new NotImplementedException(); - - public void Predict(Session sess) => throw new NotImplementedException(); - } -} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs deleted file mode 100644 index a68b58ef..00000000 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs +++ /dev/null @@ -1,177 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using NumSharp; -using System; -using Tensorflow; -using Tensorflow.Hub; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - /// - /// Neural Network classifier for Hand Written Digits - /// Sample Neural Network architecture with two layers implemented for classifying MNIST digits. - /// Use Stochastic Gradient Descent (SGD) optimizer. 
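Both deleted MNIST classifiers (the CNN above and this fully-connected network) minimize the same objective, the mean softmax cross-entropy produced by `tf.nn.softmax_cross_entropy_with_logits` followed by `tf.reduce_mean`. Written out for an N-sample batch over the ten digit classes:

$$\mathcal{L} \;=\; -\frac{1}{N}\sum_{i=1}^{N}\sum_{c=0}^{9} y_{ic}\,\log\!\left(\frac{e^{z_{ic}}}{\sum_{c'=0}^{9} e^{z_{ic'}}}\right)$$

where $z_{ic}$ are the output logits and $y_{ic}$ the one-hot labels.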
- /// http://www.easy-tf.com/tf-tutorials/neural-networks - /// - public class DigitRecognitionNN : IExample - { - public bool Enabled { get; set; } = true; - public bool IsImportingGraph { get; set; } = false; - - public string Name => "Digits Recognition Neural Network"; - - const int img_h = 28; - const int img_w = 28; - int img_size_flat = img_h * img_w; // 784, the total number of pixels - int n_classes = 10; // Number of classes, one class per digit - // Hyper-parameters - int epochs = 10; - int batch_size = 100; - float learning_rate = 0.001f; - int h1 = 200; // number of nodes in the 1st hidden layer - Datasets mnist; - - Tensor x, y; - Tensor loss, accuracy; - Operation optimizer; - - int display_freq = 100; - float accuracy_test = 0f; - float loss_test = 1f; - - public bool Run() - { - PrepareData(); - BuildGraph(); - - using (var sess = tf.Session()) - { - Train(sess); - Test(sess); - }; - - return loss_test < 0.09 && accuracy_test > 0.95; - } - - public Graph BuildGraph() - { - var graph = new Graph().as_default(); - - // Placeholders for inputs (x) and outputs(y) - x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: "X"); - y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y"); - - // Create a fully-connected layer with h1 nodes as hidden layer - var fc1 = fc_layer(x, h1, "FC1", use_relu: true); - // Create a fully-connected layer with n_classes nodes as output layer - var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false); - // Define the loss function, optimizer, and accuracy - var logits = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits); - loss = tf.reduce_mean(logits, name: "loss"); - optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss); - var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred"); - accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy"); - - // Network predictions - var cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions"); - - return graph; - } - - private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true) - { - var in_dim = x.shape[1]; - - var initer = tf.truncated_normal_initializer(stddev: 0.01f); - var W = tf.get_variable("W_" + name, - dtype: tf.float32, - shape: (in_dim, num_units), - initializer: initer); - - var initial = tf.constant(0f, num_units); - var b = tf.get_variable("b_" + name, - dtype: tf.float32, - initializer: initial); - - var layer = tf.matmul(x, W) + b; - if (use_relu) - layer = tf.nn.relu(layer); - - return layer; - } - - public Graph ImportGraph() => throw new NotImplementedException(); - - public void Predict(Session sess) => throw new NotImplementedException(); - - public void PrepareData() - { - mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result; - } - - public void Train(Session sess) - { - // Number of training iterations in each epoch - var num_tr_iter = mnist.Train.Labels.shape[0] / batch_size; - - var init = tf.global_variables_initializer(); - sess.run(init); - - float loss_val = 100.0f; - float accuracy_val = 0f; - - foreach (var epoch in range(epochs)) - { - print($"Training epoch: {epoch + 1}"); - // Randomly shuffle the training data at the beginning of each epoch - var (x_train, y_train) = mnist.Randomize(mnist.Train.Data, mnist.Train.Labels); - - foreach (var iteration in range(num_tr_iter)) - { - var start = iteration * batch_size; - var 
end = (iteration + 1) * batch_size; - var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end); - - // Run optimization op (backprop) - sess.run(optimizer, (x, x_batch), (y, y_batch)); - - if (iteration % display_freq == 0) - { - // Calculate and display the batch loss and accuracy - (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_batch), (y, y_batch)); - print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}"); - } - } - - // Run validation after every epoch - (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, mnist.Validation.Data), (y, mnist.Validation.Labels)); - print("---------------------------------------------------------"); - print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}"); - print("---------------------------------------------------------"); - } - } - - public void Test(Session sess) - { - (loss_test, accuracy_test) = sess.run((loss, accuracy), (x, mnist.Test.Data), (y, mnist.Test.Labels)); - print("---------------------------------------------------------"); - print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}"); - print("---------------------------------------------------------"); - } - } -} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs deleted file mode 100644 index 76d3d63f..00000000 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs +++ /dev/null @@ -1,161 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using NumSharp; -using System; -using Tensorflow; -using Tensorflow.Hub; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - /// - /// Recurrent Neural Network for handwritten digits MNIST. 
- /// https://medium.com/machine-learning-algorithms/mnist-using-recurrent-neural-network-2d070a5915a2 - /// - public class DigitRecognitionRNN : IExample - { - public bool Enabled { get; set; } = false; - public bool IsImportingGraph { get; set; } = false; - - public string Name => "MNIST RNN"; - - string logs_path = "logs"; - - // Hyper-parameters - int n_neurons = 128; - float learning_rate = 0.001f; - int batch_size = 128; - int epochs = 10; - - int n_steps = 28; - int n_inputs = 28; - int n_outputs = 10; - - Datasets mnist; - - Tensor x, y; - Tensor loss, accuracy, cls_prediction; - Operation optimizer; - - int display_freq = 100; - float accuracy_test = 0f; - float loss_test = 1f; - - NDArray x_train, y_train; - NDArray x_valid, y_valid; - NDArray x_test, y_test; - - public bool Run() - { - PrepareData(); - BuildGraph(); - - using (var sess = tf.Session()) - { - Train(sess); - Test(sess); - } - - return loss_test < 0.09 && accuracy_test > 0.95; - } - - public Graph BuildGraph() - { - var graph = new Graph().as_default(); - - var X = tf.placeholder(tf.float32, new[] { -1, n_steps, n_inputs }); - var y = tf.placeholder(tf.int32, new[] { -1 }); - var cell = tf.nn.rnn_cell.BasicRNNCell(num_units: n_neurons); - var (output, state) = tf.nn.dynamic_rnn(cell, X, dtype: tf.float32); - - return graph; - } - - public void Train(Session sess) - { - // Number of training iterations in each epoch - var num_tr_iter = y_train.shape[0] / batch_size; - - var init = tf.global_variables_initializer(); - sess.run(init); - - float loss_val = 100.0f; - float accuracy_val = 0f; - - foreach (var epoch in range(epochs)) - { - print($"Training epoch: {epoch + 1}"); - // Randomly shuffle the training data at the beginning of each epoch - (x_train, y_train) = mnist.Randomize(x_train, y_train); - - foreach (var iteration in range(num_tr_iter)) - { - var start = iteration * batch_size; - var end = (iteration + 1) * batch_size; - var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end); - - // Run optimization op (backprop) - sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); - - if (iteration % display_freq == 0) - { - // Calculate and display the batch loss and accuracy - var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); - loss_val = result[0]; - accuracy_val = result[1]; - print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}"); - } - } - - // Run validation after every epoch - var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_valid), new FeedItem(y, y_valid)); - loss_val = results1[0]; - accuracy_val = results1[1]; - print("---------------------------------------------------------"); - print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}"); - print("---------------------------------------------------------"); - } - } - - public void Test(Session sess) - { - var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_test), new FeedItem(y, y_test)); - loss_test = result[0]; - accuracy_test = result[1]; - print("---------------------------------------------------------"); - print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}"); - print("---------------------------------------------------------"); - } - - public void PrepareData() - { - mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, 
showProgressInConsole: true).Result; - (x_train, y_train) = (mnist.Train.Data, mnist.Train.Labels); - (x_valid, y_valid) = (mnist.Validation.Data, mnist.Validation.Labels); - (x_test, y_test) = (mnist.Test.Data, mnist.Test.Labels); - - print("Size of:"); - print($"- Training-set:\t\t{len(mnist.Train.Data)}"); - print($"- Validation-set:\t{len(mnist.Validation.Data)}"); - print($"- Test-set:\t\t{len(mnist.Test.Data)}"); - } - - public Graph ImportGraph() => throw new NotImplementedException(); - - public void Predict(Session sess) => throw new NotImplementedException(); - } -} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/ImageBackgroundRemoval.cs b/test/TensorFlowNET.Examples/ImageProcessing/ImageBackgroundRemoval.cs deleted file mode 100644 index 1bc9781f..00000000 --- a/test/TensorFlowNET.Examples/ImageProcessing/ImageBackgroundRemoval.cs +++ /dev/null @@ -1,84 +0,0 @@ -using System; -using System.IO; -using Tensorflow; -using TensorFlowNET.Examples.Utility; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - /// - /// This example removes the background from an input image. - /// - /// https://github.com/susheelsk/image-background-removal - /// - public class ImageBackgroundRemoval : IExample - { - public bool Enabled { get; set; } = false; - public bool IsImportingGraph { get; set; } = true; - - public string Name => "Image Background Removal"; - - string dataDir = "deeplabv3"; - string modelDir = "deeplabv3_mnv2_pascal_train_aug"; - string modelName = "frozen_inference_graph.pb"; - - public bool Run() - { - PrepareData(); - - // import GraphDef from pb file - var graph = new Graph().as_default(); - graph.Import(Path.Join(dataDir, modelDir, modelName)); - - Tensor output = graph.OperationByName("SemanticPredictions"); - - using (var sess = tf.Session(graph)) - { - // Runs inference on a single image. 
- sess.run(output, new FeedItem(output, "[np.asarray(resized_image)]")); - } - - return false; - } - - public void PrepareData() - { - // get mobile_net_model file - string fileName = "deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz"; - string url = $"http://download.tensorflow.org/models/{fileName}"; - Web.Download(url, dataDir, fileName); - Compress.ExtractTGZ(Path.Join(dataDir, fileName), dataDir); - - // xception_model, better accuracy - /*fileName = "deeplabv3_pascal_train_aug_2018_01_04.tar.gz"; - url = $"http://download.tensorflow.org/models/{fileName}"; - Web.Download(url, modelDir, fileName); - Compress.ExtractTGZ(Path.Join(modelDir, fileName), modelDir);*/ - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs b/test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs deleted file mode 100644 index 0414d68d..00000000 --- a/test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs +++ /dev/null @@ -1,140 +0,0 @@ -using NumSharp; -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using Console = Colorful.Console; -using Tensorflow; -using System.Drawing; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - /// - /// Inception v3 is a widely-used image recognition model - /// that has been shown to attain greater than 78.1% accuracy on the ImageNet dataset. - /// The model is the culmination of many ideas developed by multiple researchers over the years. 
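Several of the deleted image examples share one pattern: import a frozen `GraphDef`, look up the input and output operations by name, and run a session over a preprocessed image tensor. Below is a minimal sketch of that pattern, with placeholder file names and an image tensor `nd` supplied by the caller; the original files build `nd` with `tf.image.decode_jpeg` and `tf.image.resize_bilinear`.

```cs
// Sketch of the frozen-graph inference pattern used by the deleted
// ImageRecognitionInception example. "model.pb", "labels.txt" and `nd` are
// placeholders; the individual calls mirror those in the original file.
using System;
using System.IO;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;

class FrozenGraphSketch
{
    static void Classify(NDArray nd)
    {
        var graph = new Graph();
        graph.Import("model.pb");                        // frozen GraphDef

        var input_operation = graph.OperationByName("input");
        var output_operation = graph.OperationByName("output");
        var labels = File.ReadAllLines("labels.txt");

        using (var sess = tf.Session(graph))
        {
            var results = sess.run(output_operation.outputs[0],
                                   (input_operation.outputs[0], nd));
            results = np.squeeze(results);
            int idx = np.argmax(results);                // best-scoring class
            Console.WriteLine($"{labels[idx]} {results[idx]}");
        }
    }
}
```

The Inception, GoogLeNet, object-detection and background-removal examples in this diff all specialize this skeleton with their own tensor names and post-processing.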
- /// - public class ImageRecognitionInception : IExample - { - public bool Enabled { get; set; } = true; - public string Name => "Image Recognition Inception"; - public bool IsImportingGraph { get; set; } = false; - - - string dir = "ImageRecognitionInception"; - string pbFile = "tensorflow_inception_graph.pb"; - string labelFile = "imagenet_comp_graph_label_strings.txt"; - List file_ndarrays = new List(); - - public bool Run() - { - PrepareData(); - - var graph = new Graph(); - //import GraphDef from pb file - graph.Import(Path.Join(dir, pbFile)); - - var input_name = "input"; - var output_name = "output"; - - var input_operation = graph.OperationByName(input_name); - var output_operation = graph.OperationByName(output_name); - - var labels = File.ReadAllLines(Path.Join(dir, labelFile)); - var result_labels = new List(); - var sw = new Stopwatch(); - - using (var sess = tf.Session(graph)) - { - foreach (var nd in file_ndarrays) - { - sw.Restart(); - - var results = sess.run(output_operation.outputs[0], (input_operation.outputs[0], nd)); - results = np.squeeze(results); - int idx = np.argmax(results); - - Console.WriteLine($"{labels[idx]} {results[idx]} in {sw.ElapsedMilliseconds}ms", Color.Tan); - result_labels.Add(labels[idx]); - } - } - - return result_labels.Contains("military uniform"); - } - - private NDArray ReadTensorFromImageFile(string file_name, - int input_height = 224, - int input_width = 224, - int input_mean = 117, - int input_std = 1) - { - var graph = tf.Graph().as_default(); - - var file_reader = tf.read_file(file_name, "file_reader"); - var decodeJpeg = tf.image.decode_jpeg(file_reader, channels: 3, name: "DecodeJpeg"); - var cast = tf.cast(decodeJpeg, tf.float32); - var dims_expander = tf.expand_dims(cast, 0); - var resize = tf.constant(new int[] { input_height, input_width }); - var bilinear = tf.image.resize_bilinear(dims_expander, resize); - var sub = tf.subtract(bilinear, new float[] { input_mean }); - var normalized = tf.divide(sub, new float[] { input_std }); - - using (var sess = tf.Session(graph)) - return sess.run(normalized); - } - - public void PrepareData() - { - Directory.CreateDirectory(dir); - - // get model file - string url = "https://storage.googleapis.com/download.tf.org/models/inception5h.zip"; - - Utility.Web.Download(url, dir, "inception5h.zip"); - - Utility.Compress.UnZip(Path.Join(dir, "inception5h.zip"), dir); - - // download sample picture - Directory.CreateDirectory(Path.Join(dir, "img")); - url = $"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/grace_hopper.jpg"; - Utility.Web.Download(url, Path.Join(dir, "img"), "grace_hopper.jpg"); - - url = $"https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/shasta-daisy.jpg"; - Utility.Web.Download(url, Path.Join(dir, "img"), "shasta-daisy.jpg"); - - // load image file - var files = Directory.GetFiles(Path.Join(dir, "img")); - for (int i = 0; i < files.Length; i++) - { - var nd = ReadTensorFromImageFile(files[i]); - file_ndarrays.Add(nd); - } - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/InceptionArchGoogLeNet.cs 
b/test/TensorFlowNET.Examples/ImageProcessing/InceptionArchGoogLeNet.cs deleted file mode 100644 index 704067fc..00000000 --- a/test/TensorFlowNET.Examples/ImageProcessing/InceptionArchGoogLeNet.cs +++ /dev/null @@ -1,133 +0,0 @@ -using NumSharp; -using System; -using System.IO; -using System.Linq; -using Tensorflow; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - /// - /// Inception Architecture for Computer Vision - /// Port from tensorflow\examples\label_image\label_image.py - /// - public class InceptionArchGoogLeNet : IExample - { - public bool Enabled { get; set; } = false; - public string Name => "Inception Arch GoogLeNet"; - public bool IsImportingGraph { get; set; } = false; - - - string dir = "label_image_data"; - string pbFile = "inception_v3_2016_08_28_frozen.pb"; - string labelFile = "imagenet_slim_labels.txt"; - string picFile = "grace_hopper.jpg"; - int input_height = 299; - int input_width = 299; - int input_mean = 0; - int input_std = 255; - string input_name = "import/input"; - string output_name = "import/InceptionV3/Predictions/Reshape_1"; - - public bool Run() - { - PrepareData(); - - var labels = File.ReadAllLines(Path.Join(dir, labelFile)); - - var nd = ReadTensorFromImageFile(Path.Join(dir, picFile), - input_height: input_height, - input_width: input_width, - input_mean: input_mean, - input_std: input_std); - - var graph = new Graph(); - graph.Import(Path.Join(dir, pbFile)); - var input_operation = graph.get_operation_by_name(input_name); - var output_operation = graph.get_operation_by_name(output_name); - - NDArray results; - using (var sess = tf.Session(graph)) - { - results = sess.run(output_operation.outputs[0], - new FeedItem(input_operation.outputs[0], nd)); - } - - results = np.squeeze(results); - - var argsort = results.argsort(); - var top_k = argsort.Data() - .Skip(results.size - 5) - .Reverse() - .ToArray(); - - foreach (float idx in top_k) - Console.WriteLine($"{picFile}: {idx} {labels[(int)idx]}, {results[(int)idx]}"); - - return true; - } - - private NDArray ReadTensorFromImageFile(string file_name, - int input_height = 299, - int input_width = 299, - int input_mean = 0, - int input_std = 255) - { - var graph = tf.Graph().as_default(); - - var file_reader = tf.read_file(file_name, "file_reader"); - var image_reader = tf.image.decode_jpeg(file_reader, channels: 3, name: "jpeg_reader"); - var caster = tf.cast(image_reader, tf.float32); - var dims_expander = tf.expand_dims(caster, 0); - var resize = tf.constant(new int[] { input_height, input_width }); - var bilinear = tf.image.resize_bilinear(dims_expander, resize); - var sub = tf.subtract(bilinear, new float[] { input_mean }); - var normalized = tf.divide(sub, new float[] { input_std }); - - using (var sess = tf.Session(graph)) - return sess.run(normalized); - } - - public void PrepareData() - { - Directory.CreateDirectory(dir); - - // get model file - string url = "https://storage.googleapis.com/download.tf.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz"; - - Utility.Web.Download(url, dir, $"{pbFile}.tar.gz"); - - Utility.Compress.ExtractTGZ(Path.Join(dir, $"{pbFile}.tar.gz"), dir); - - // download sample picture - string pic = "grace_hopper.jpg"; - url = $"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/{pic}"; - Utility.Web.Download(url, dir, pic); - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public 
void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs b/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs deleted file mode 100644 index d0c06704..00000000 --- a/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs +++ /dev/null @@ -1,175 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using NumSharp; -using System; -using System.IO; -using Tensorflow; -using TensorFlowNET.Examples.Utility; -using System.Drawing; -using System.Drawing.Drawing2D; -using System.Linq; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - public class ObjectDetection : IExample - { - public bool Enabled { get; set; } = true; - public string Name => "Object Detection"; - public bool IsImportingGraph { get; set; } = true; - - public float MIN_SCORE = 0.5f; - - string modelDir = "ssd_mobilenet_v1_coco_2018_01_28"; - string imageDir = "images"; - string pbFile = "frozen_inference_graph.pb"; - string labelFile = "mscoco_label_map.pbtxt"; - string picFile = "input.jpg"; - - NDArray imgArr; - - public bool Run() - { - PrepareData(); - - // read in the input image - imgArr = ReadTensorFromImageFile(Path.Join(imageDir, "input.jpg")); - - var graph = IsImportingGraph ? 
ImportGraph() : BuildGraph(); - - using (var sess = tf.Session(graph)) - Predict(sess); - - return true; - } - - public Graph ImportGraph() - { - var graph = new Graph().as_default(); - graph.Import(Path.Join(modelDir, pbFile)); - - return graph; - } - - public void Predict(Session sess) - { - var graph = tf.get_default_graph(); - - Tensor tensorNum = graph.OperationByName("num_detections"); - Tensor tensorBoxes = graph.OperationByName("detection_boxes"); - Tensor tensorScores = graph.OperationByName("detection_scores"); - Tensor tensorClasses = graph.OperationByName("detection_classes"); - Tensor imgTensor = graph.OperationByName("image_tensor"); - Tensor[] outTensorArr = new Tensor[] { tensorNum, tensorBoxes, tensorScores, tensorClasses }; - - var results = sess.run(outTensorArr, new FeedItem(imgTensor, imgArr)); - - buildOutputImage(results); - } - - public void PrepareData() - { - // get model file - string url = "http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz"; - Web.Download(url, modelDir, "ssd_mobilenet_v1_coco.tar.gz"); - - Compress.ExtractTGZ(Path.Join(modelDir, "ssd_mobilenet_v1_coco.tar.gz"), "./"); - - // download sample picture - url = $"https://github.com/tensorflow/models/raw/master/research/object_detection/test_images/image2.jpg"; - Web.Download(url, imageDir, "input.jpg"); - - // download the pbtxt file - url = $"https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt"; - Web.Download(url, modelDir, "mscoco_label_map.pbtxt"); - } - - private NDArray ReadTensorFromImageFile(string file_name) - { - var graph = tf.Graph().as_default(); - - var file_reader = tf.read_file(file_name, "file_reader"); - var decodeJpeg = tf.image.decode_jpeg(file_reader, channels: 3, name: "DecodeJpeg"); - var casted = tf.cast(decodeJpeg, TF_DataType.TF_UINT8); - var dims_expander = tf.expand_dims(casted, 0); - - using (var sess = tf.Session(graph)) - return sess.run(dims_expander); - } - - private void buildOutputImage(NDArray[] resultArr) - { - // get pbtxt items - PbtxtItems pbTxtItems = PbtxtParser.ParsePbtxtFile(Path.Join(modelDir, "mscoco_label_map.pbtxt")); - - // get bitmap - Bitmap bitmap = new Bitmap(Path.Join(imageDir, "input.jpg")); - - var scores = resultArr[2].AsIterator(); - var boxes = resultArr[1].GetData(); - var id = np.squeeze(resultArr[3]).GetData(); - for (int i=0; i< scores.size; i++) - { - float score = scores.MoveNext(); - if (score > MIN_SCORE) - { - float top = boxes[i * 4] * bitmap.Height; - float left = boxes[i * 4 + 1] * bitmap.Width; - float bottom = boxes[i * 4 + 2] * bitmap.Height; - float right = boxes[i * 4 + 3] * bitmap.Width; - - Rectangle rect = new Rectangle() - { - X = (int)left, - Y = (int)top, - Width = (int)(right - left), - Height = (int)(bottom - top) - }; - - string name = pbTxtItems.items.Where(w => w.id == id[i]).Select(s=>s.display_name).FirstOrDefault(); - - drawObjectOnBitmap(bitmap, rect, score, name); - } - } - - string path = Path.Join(imageDir, "output.jpg"); - bitmap.Save(path); - Console.WriteLine($"Processed image is saved as {path}"); - } - - private void drawObjectOnBitmap(Bitmap bmp, Rectangle rect, float score, string name) - { - using (Graphics graphic = Graphics.FromImage(bmp)) - { - graphic.SmoothingMode = SmoothingMode.AntiAlias; - - using (Pen pen = new Pen(Color.Red, 2)) - { - graphic.DrawRectangle(pen, rect); - - Point p = new Point(rect.Right + 5, rect.Top + 5); - string text = string.Format("{0}:{1}%", name, (int)(score * 
100)); - graphic.DrawString(text, new Font("Verdana", 8), Brushes.Red, p); - } - } - } - - public Graph BuildGraph() => throw new NotImplementedException(); - public void Train(Session sess) => throw new NotImplementedException(); - public void Test(Session sess) => throw new NotImplementedException(); - } -} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs deleted file mode 100644 index 7f2d81f4..00000000 --- a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs +++ /dev/null @@ -1,791 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using Google.Protobuf; -using NumSharp; -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Linq; -using System.Threading.Tasks; -using Tensorflow; -using TensorFlowNET.Examples.Utility; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - /// - /// In this tutorial, we will reuse the feature extraction capabilities from powerful image classifiers trained on ImageNet - /// and simply train a new classification layer on top. Transfer learning is a technique that shortcuts much of this - /// by taking a piece of a model that has already been trained on a related task and reusing it in a new model. - /// - /// https://www.tf.org/hub/tutorials/image_retraining - /// - public class RetrainImageClassifier : IExample - { - public int Priority => 16; - - public bool Enabled { get; set; } = true; - public bool IsImportingGraph { get; set; } = true; - - public string Name => "Retrain Image Classifier"; - - const string data_dir = "retrain_images"; - string summaries_dir = Path.Join(data_dir, "retrain_logs"); - string image_dir = Path.Join(data_dir, "flower_photos"); - string bottleneck_dir = Path.Join(data_dir, "bottleneck"); - string output_graph = Path.Join(data_dir, "output_graph.pb"); - string output_labels = Path.Join(data_dir, "output_labels.txt"); - // The location where variable checkpoints will be stored. 
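The retraining example removed here centers on "bottleneck" extraction: each image is pushed through the frozen InceptionV3 graph once, and the resulting feature vector is cached and reused while only a new softmax layer is trained. A minimal sketch of that step, condensed from the removed file and using only calls that appear in it; `imgArr` is a hypothetical preprocessed `1 x 299 x 299 x 3` batch, and the tensor names are the ones the removed code looks up in its cached meta graph:

```cs
// Sketch of the bottleneck pass (condensed from the removed RetrainImageClassifier).
// Assumes the removed file's usings: Tensorflow, NumSharp, static Tensorflow.Binding.
var graph = tf.Graph().as_default();
tf.train.import_meta_graph("graph/InceptionV3.meta");

// Tensor names used by the removed code for this particular meta graph.
Tensor input = graph.OperationByName("Placeholder");
Tensor bottleneck = graph.OperationByName("module_apply_default/hub_output/feature_vector/SpatialSqueeze");

using (var sess = tf.Session(graph))
{
    // imgArr: hypothetical decoded and resized image batch of shape (1, 299, 299, 3).
    var features = np.squeeze(sess.run(bottleneck, new FeedItem(input, imgArr)));
    // `features` is the fixed-length feature vector the new classification layer is trained on.
}
```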
- string CHECKPOINT_NAME = Path.Join(data_dir, "_retrain_checkpoint"); - string tfhub_module = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3"; - string input_tensor_name = "Placeholder"; - string final_tensor_name = "Score"; - float testing_percentage = 0.1f; - float validation_percentage = 0.1f; - float learning_rate = 0.01f; - Tensor resized_image_tensor; - Dictionary> image_lists; - int how_many_training_steps = 100; - int eval_step_interval = 10; - int train_batch_size = 100; - int test_batch_size = -1; - int validation_batch_size = 100; - int intermediate_store_frequency = 0; - int class_count = 0; - const int MAX_NUM_IMAGES_PER_CLASS = 134217727; - Operation train_step; - Tensor final_tensor; - Tensor bottleneck_input; - Tensor cross_entropy; - Tensor ground_truth_input; - Tensor bottleneck_tensor; - bool wants_quantization; - float test_accuracy; - NDArray predictions; - - public bool Run() - { - PrepareData(); - - #region For debug purpose - - // predict images - // Predict(null); - - // load saved pb and test new images. - // Test(null); - - #endregion - - var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); - - using (var sess = tf.Session(graph)) - { - Train(sess); - } - - return test_accuracy > 0.75f; - } - - /// - /// Runs a final evaluation on an eval graph using the test data set. - /// - /// - /// - /// - /// - /// - /// - /// - /// - private (float, NDArray) run_final_eval(Session train_session, object module_spec, int class_count, - Dictionary> image_lists, - Tensor jpeg_data_tensor, Tensor decoded_image_tensor, - Tensor resized_image_tensor, Tensor bottleneck_tensor) - { - var (test_bottlenecks, test_ground_truth, test_filenames) = get_random_cached_bottlenecks(train_session, image_lists, - test_batch_size, "testing", bottleneck_dir, image_dir, jpeg_data_tensor, - decoded_image_tensor, resized_image_tensor, bottleneck_tensor, tfhub_module); - - var (eval_session, _, bottleneck_input, ground_truth_input, evaluation_step, - prediction) = build_eval_session(class_count); - - (float accuracy, NDArray prediction1) = eval_session.run((evaluation_step, prediction), - (bottleneck_input, test_bottlenecks), - (ground_truth_input, test_ground_truth)); - - print($"final test accuracy: {(accuracy * 100).ToString("G4")}% (N={len(test_bottlenecks)})"); - - return (accuracy, prediction1); - } - - private (Session, Tensor, Tensor, Tensor, Tensor, Tensor) - build_eval_session(int class_count) - { - // If quantized, we need to create the correct eval graph for exporting. - var (eval_graph, bottleneck_tensor, resized_input_tensor, wants_quantization) = create_module_graph(); - var eval_sess = tf.Session(graph: eval_graph); - Tensor evaluation_step = null; - Tensor prediction = null; - - var graph = eval_graph.as_default(); - // Add the new layer for exporting. - var (_, _, bottleneck_input, ground_truth_input, final_tensor) = - add_final_retrain_ops(class_count, final_tensor_name, bottleneck_tensor, - wants_quantization, is_training: false); - - // Now we need to restore the values from the training graph to the eval - // graph. - tf.train.Saver().restore(eval_sess, CHECKPOINT_NAME); - - (evaluation_step, prediction) = add_evaluation_step(final_tensor, - ground_truth_input); - - return (eval_sess, resized_input_tensor, bottleneck_input, ground_truth_input, - evaluation_step, prediction); - } - - /// - /// Adds a new softmax and fully-connected layer for training and eval. 
- /// - /// We need to retrain the top layer to identify our new classes, so this function - /// adds the right operations to the graph, along with some variables to hold the - /// weights, and then sets up all the gradients for the backward pass. - /// - /// The set up for the softmax and fully-connected layers is based on: - /// https://www.tf.org/tutorials/mnist/beginners/index.html - /// - /// - /// - /// - /// - /// - /// - private (Operation, Tensor, Tensor, Tensor, Tensor) add_final_retrain_ops(int class_count, string final_tensor_name, - Tensor bottleneck_tensor, bool quantize_layer, bool is_training) - { - var (batch_size, bottleneck_tensor_size) = (bottleneck_tensor.TensorShape.dims[0], bottleneck_tensor.TensorShape.dims[1]); - tf_with(tf.name_scope("input"), scope => - { - bottleneck_input = tf.placeholder_with_default( - bottleneck_tensor, - shape: bottleneck_tensor.TensorShape.dims, - name: "BottleneckInputPlaceholder"); - - ground_truth_input = tf.placeholder(tf.int64, new TensorShape(batch_size), name: "GroundTruthInput"); - }); - - // Organizing the following ops so they are easier to see in TensorBoard. - string layer_name = "final_retrain_ops"; - Tensor logits = null; - tf_with(tf.name_scope(layer_name), scope => - { - RefVariable layer_weights = null; - tf_with(tf.name_scope("weights"), delegate - { - var initial_value = tf.truncated_normal(new int[] { bottleneck_tensor_size, class_count }, stddev: 0.001f); - layer_weights = tf.Variable(initial_value, name: "final_weights"); - variable_summaries(layer_weights); - }); - - RefVariable layer_biases = null; - tf_with(tf.name_scope("biases"), delegate - { - layer_biases = tf.Variable(tf.zeros(new TensorShape(class_count)), name: "final_biases"); - variable_summaries(layer_biases); - }); - - tf_with(tf.name_scope("Wx_plus_b"), delegate - { - logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases; - tf.summary.histogram("pre_activations", logits); - }); - }); - - final_tensor = tf.nn.softmax(logits, name: final_tensor_name); - - // The tf.contrib.quantize functions rewrite the graph in place for - // quantization. The imported model graph has already been rewritten, so upon - // calling these rewrites, only the newly added final layer will be - // transformed. - if (quantize_layer) - { - throw new NotImplementedException("quantize_layer"); - /*if (is_training) - tf.contrib.quantize.create_training_graph(); - else - tf.contrib.quantize.create_eval_graph();*/ - } - - tf.summary.histogram("activations", final_tensor); - - // If this is an eval graph, we don't need to add loss ops or an optimizer. 
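For orientation, the shapes flowing through the layer constructed just above, assuming the InceptionV3 feature-vector module this example downloads (whose bottleneck is 2048 floats wide); `class_count` is the number of image folders found on disk:

```cs
// Shape sketch for add_final_retrain_ops (sizes are assumptions for the InceptionV3 module):
//   bottleneck_input : [batch, 2048]
//   layer_weights    : [2048, class_count]   (truncated-normal init, stddev 0.001)
//   logits           : [batch, class_count]  = tf.matmul(bottleneck_input, W) + b
//   final_tensor     : tf.nn.softmax(logits), exported under the name "Score"
```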
- if (!is_training) - return (null, null, bottleneck_input, ground_truth_input, final_tensor); - - Tensor cross_entropy_mean = null; - tf_with(tf.name_scope("cross_entropy"), delegate - { - cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy( - labels: ground_truth_input, logits: logits); - }); - - tf.summary.scalar("cross_entropy", cross_entropy_mean); - - tf_with(tf.name_scope("train"), delegate - { - var optimizer = tf.train.GradientDescentOptimizer(learning_rate); - train_step = optimizer.minimize(cross_entropy_mean); - }); - - return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input, - final_tensor); - } - - private void variable_summaries(RefVariable var) - { - tf_with(tf.name_scope("summaries"), delegate - { - var mean = tf.reduce_mean(var); - tf.summary.scalar("mean", mean); - Tensor stddev = null; - tf_with(tf.name_scope("stddev"), delegate - { - stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))); - }); - tf.summary.scalar("stddev", stddev); - tf.summary.scalar("max", tf.reduce_max(var)); - tf.summary.scalar("min", tf.reduce_min(var)); - tf.summary.histogram("histogram", var); - }); - } - - private (Graph, Tensor, Tensor, bool) create_module_graph() - { - var (height, width) = (299, 299); - var graph = tf.Graph().as_default(); - tf.train.import_meta_graph("graph/InceptionV3.meta"); - Tensor resized_input_tensor = graph.OperationByName(input_tensor_name); //tf.placeholder(tf.float32, new TensorShape(-1, height, width, 3)); - // var m = hub.Module(module_spec); - Tensor bottleneck_tensor = graph.OperationByName("module_apply_default/hub_output/feature_vector/SpatialSqueeze");// m(resized_input_tensor); - var wants_quantization = false; - return (graph, bottleneck_tensor, resized_input_tensor, wants_quantization); - } - - private (NDArray, long[], string[]) get_random_cached_bottlenecks(Session sess, Dictionary> image_lists, - int how_many, string category, string bottleneck_dir, string image_dir, - Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor, - Tensor bottleneck_tensor, string module_name) - { - var bottlenecks = new List(); - var ground_truths = new List(); - var filenames = new List(); - class_count = image_lists.Keys.Count; - if (how_many >= 0) - { - // Retrieve a random sample of bottlenecks. - foreach (var unused_i in range(how_many)) - { - int label_index = new Random().Next(class_count); - string label_name = image_lists.Keys.ToArray()[label_index]; - int image_index = new Random().Next(MAX_NUM_IMAGES_PER_CLASS); - string image_name = get_image_path(image_lists, label_name, image_index, - image_dir, category); - var bottleneck = get_or_create_bottleneck( - sess, image_lists, label_name, image_index, image_dir, category, - bottleneck_dir, jpeg_data_tensor, decoded_image_tensor, - resized_input_tensor, bottleneck_tensor, module_name); - bottlenecks.Add(bottleneck); - ground_truths.Add(label_index); - filenames.Add(image_name); - } - } - else - { - // Retrieve all bottlenecks. 
- foreach (var (label_index, label_name) in enumerate(image_lists.Keys.ToArray())) - { - foreach (var (image_index, image_name) in enumerate(image_lists[label_name][category])) - { - var bottleneck = get_or_create_bottleneck( - sess, image_lists, label_name, image_index, image_dir, category, - bottleneck_dir, jpeg_data_tensor, decoded_image_tensor, - resized_input_tensor, bottleneck_tensor, module_name); - - bottlenecks.Add(bottleneck); - ground_truths.Add(label_index); - filenames.Add(image_name); - } - } - } - - return (bottlenecks.ToArray(), ground_truths.ToArray(), filenames.ToArray()); - } - - /// - /// Inserts the operations we need to evaluate the accuracy of our results. - /// - /// - /// - /// - private (Tensor, Tensor) add_evaluation_step(Tensor result_tensor, Tensor ground_truth_tensor) - { - Tensor evaluation_step = null, correct_prediction = null, prediction = null; - - tf_with(tf.name_scope("accuracy"), scope => - { - tf_with(tf.name_scope("correct_prediction"), delegate - { - prediction = tf.argmax(result_tensor, 1); - correct_prediction = tf.equal(prediction, ground_truth_tensor); - }); - - tf_with(tf.name_scope("accuracy"), delegate - { - evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); - }); - }); - - tf.summary.scalar("accuracy", evaluation_step); - return (evaluation_step, prediction); - } - - /// - /// Ensures all the training, testing, and validation bottlenecks are cached. - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - private void cache_bottlenecks(Session sess, Dictionary> image_lists, - string image_dir, string bottleneck_dir, Tensor jpeg_data_tensor, Tensor decoded_image_tensor, - Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name) - { - int how_many_bottlenecks = 0; - var kvs = image_lists.ToArray(); - var categories = new string[] {"training", "testing", "validation"}; - Parallel.For(0, kvs.Length, i => - { - var (label_name, label_lists) = kvs[i]; - - Parallel.For(0, categories.Length, j => - { - var category = categories[j]; - var category_list = label_lists[category]; - foreach (var (index, unused_base_name) in enumerate(category_list)) - { - get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir, category, - bottleneck_dir, jpeg_data_tensor, decoded_image_tensor, - resized_input_tensor, bottleneck_tensor, module_name); - how_many_bottlenecks++; - if (how_many_bottlenecks % 300 == 0) - print($"{how_many_bottlenecks} bottleneck files created."); - } - }); - }); - } - - private float[] get_or_create_bottleneck(Session sess, Dictionary> image_lists, - string label_name, int index, string image_dir, string category, string bottleneck_dir, - Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor, - Tensor bottleneck_tensor, string module_name) - { - var label_lists = image_lists[label_name]; - var sub_dir_path = Path.Join(bottleneck_dir, label_name); - Directory.CreateDirectory(sub_dir_path); - string bottleneck_path = get_bottleneck_path(image_lists, label_name, index, - bottleneck_dir, category, module_name); - - if (!File.Exists(bottleneck_path)) - create_bottleneck_file(bottleneck_path, image_lists, label_name, index, - image_dir, category, sess, jpeg_data_tensor, - decoded_image_tensor, resized_input_tensor, - bottleneck_tensor); - var bottleneck_string = File.ReadAllText(bottleneck_path); - var bottleneck_values = Array.ConvertAll(bottleneck_string.Split(','), x => float.Parse(x)); - return bottleneck_values; - } - - private void 
create_bottleneck_file(string bottleneck_path, Dictionary> image_lists, - string label_name, int index, string image_dir, string category, Session sess, - Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor bottleneck_tensor) - { - // Create a single bottleneck file. - print("Creating bottleneck at " + bottleneck_path); - var image_path = get_image_path(image_lists, label_name, index, image_dir, category); - if (!File.Exists(image_path)) - print($"File does not exist {image_path}"); - - var image_data = File.ReadAllBytes(image_path); - var bottleneck_values = run_bottleneck_on_image( - sess, image_data, jpeg_data_tensor, decoded_image_tensor, - resized_input_tensor, bottleneck_tensor); - var values = bottleneck_values.Data(); - var bottleneck_string = string.Join(",", values); - File.WriteAllText(bottleneck_path, bottleneck_string); - } - - /// - /// Runs inference on an image to extract the 'bottleneck' summary layer. - /// - /// Current active TensorFlow Session. - /// Data of raw JPEG data. - /// Input data layer in the graph. - /// Output of initial image resizing and preprocessing. - /// The input node of the recognition graph. - /// Layer before the final softmax. - /// - private NDArray run_bottleneck_on_image(Session sess, byte[] image_data, Tensor image_data_tensor, - Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor bottleneck_tensor) - { - // First decode the JPEG image, resize it, and rescale the pixel values. - var resized_input_values = sess.run(decoded_image_tensor, new FeedItem(image_data_tensor, new Tensor(image_data, TF_DataType.TF_STRING))); - // Then run it through the recognition network. - var bottleneck_values = sess.run(bottleneck_tensor, new FeedItem(resized_input_tensor, resized_input_values))[0]; - bottleneck_values = np.squeeze(bottleneck_values); - return bottleneck_values; - } - - private string get_bottleneck_path(Dictionary> image_lists, string label_name, int index, - string bottleneck_dir, string category, string module_name) - { - module_name = (module_name.Replace("://", "~") // URL scheme. - .Replace('/', '~') // URL and Unix paths. - .Replace(':', '~').Replace('\\', '~')); // Windows paths. - return get_image_path(image_lists, label_name, index, bottleneck_dir, - category) + "_" + module_name + ".txt"; - } - - private string get_image_path(Dictionary> image_lists, string label_name, - int index, string image_dir, string category) - { - if (!image_lists.ContainsKey(label_name)) - print($"Label does not exist {label_name}"); - - var label_lists = image_lists[label_name]; - if (!label_lists.ContainsKey(category)) - print($"Category does not exist {category}"); - var category_list = label_lists[category]; - if (category_list.Length == 0) - print($"Label {label_name} has no images in the category {category}."); - - var mod_index = index % len(category_list); - var base_name = category_list[mod_index].Split(Path.DirectorySeparatorChar).Last(); - var sub_dir = label_name; - var full_path = Path.Join(image_dir, sub_dir, base_name); - return full_path; - } - - /// - /// Saves an graph to file, creating a valid quantized one if necessary. 
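Once `save_graph_to_file` has written the frozen `output_graph.pb`, it can be reloaded without the training graph at all. A minimal sketch mirroring the removed `Predict` method; the paths and the input/output tensor names ("Placeholder" / "Score") come from the fields defined earlier in this file, and `imageTensor` stands in for a preprocessed image batch:

```cs
// Sketch: reload the frozen, retrained classifier and score a single image.
// Assumes the removed file's usings: Tensorflow, NumSharp, System.IO, static Tensorflow.Binding.
var graph = new Graph();
graph.Import("retrain_images/output_graph.pb", "");

Tensor input = graph.OperationByName("Placeholder");   // input_tensor_name
Tensor output = graph.OperationByName("Score");        // final_tensor_name

var labels = File.ReadAllLines("retrain_images/output_labels.txt");

using (var sess = tf.Session(graph))
{
    // imageTensor: hypothetical 1 x 299 x 299 x 3 float batch (see ReadTensorFromImageFile).
    var prob = np.squeeze(sess.run(output, (input, imageTensor)));
    var idx = np.argmax(prob);
    print($"{labels[idx]} {prob[idx]}");
}
```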
- /// - /// - /// - private void save_graph_to_file(string graph_file_name, int class_count) - { - var (sess, _, _, _, _, _) = build_eval_session(class_count); - var graph = sess.graph; - var output_graph_def = tf.graph_util.convert_variables_to_constants( - sess, graph.as_graph_def(), new string[] { final_tensor_name }); - File.WriteAllBytes(graph_file_name, output_graph_def.ToByteArray()); - } - - public void PrepareData() - { - // get a set of images to teach the network about the new classes - string fileName = "flower_photos.tgz"; - string url = $"http://download.tensorflow.org/example_images/{fileName}"; - Web.Download(url, data_dir, fileName); - Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir); - - // download graph meta data - url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/InceptionV3.meta"; - Web.Download(url, "graph", "InceptionV3.meta"); - - // download variables.data checkpoint file. - url = "https://github.com/SciSharp/TensorFlow.NET/raw/master/data/tfhub_modules.zip"; - Web.Download(url, data_dir, "tfhub_modules.zip"); - Compress.UnZip(Path.Join(data_dir, "tfhub_modules.zip"), "tfhub_modules"); - - // Prepare necessary directories that can be used during training - Directory.CreateDirectory(summaries_dir); - Directory.CreateDirectory(bottleneck_dir); - - // Look at the folder structure, and create lists of all the images. - image_lists = create_image_lists(); - class_count = len(image_lists); - if (class_count == 0) - print($"No valid folders of images found at {image_dir}"); - if (class_count == 1) - print("Only one valid folder of images found at " + - image_dir + - " - multiple classes are needed for classification."); - } - - private (Tensor, Tensor) add_jpeg_decoding() - { - // height, width, depth - var input_dim = (299, 299, 3); - var jpeg_data = tf.placeholder(tf.@string, name: "DecodeJPGInput"); - var decoded_image = tf.image.decode_jpeg(jpeg_data, channels: input_dim.Item3); - // Convert from full range of uint8 to range [0,1] of float32. - var decoded_image_as_float = tf.image.convert_image_dtype(decoded_image, tf.float32); - var decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0); - var resize_shape = tf.stack(new int[] { input_dim.Item1, input_dim.Item2 }); - var resize_shape_as_int = tf.cast(resize_shape, dtype: tf.int32); - var resized_image = tf.image.resize_bilinear(decoded_image_4d, resize_shape_as_int); - return (jpeg_data, resized_image); - } - - /// - /// Builds a list of training images from the file system. 
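The split performed by `create_image_lists` below is positional rather than randomized: the first slice of each folder becomes the testing set, the next slice the validation set, and the remainder the training set. With the default `testing_percentage = validation_percentage = 0.1f`, a folder of, say, 100 images (a hypothetical count for illustration) divides as follows:

```cs
// Worked example of the per-folder split in create_image_lists.
int fileCount = 100;                                  // hypothetical folder size
int testing = (int)Math.Floor(fileCount * 0.1f);      // first 10 files   -> "testing"
int validation = (int)Math.Floor(fileCount * 0.1f);   // next 10 files    -> "validation"
int training = fileCount - testing - validation;      // remaining 80     -> "training"
```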
- /// - private Dictionary> create_image_lists() - { - var sub_dirs = tf.gfile.Walk(image_dir) - .Select(x => x.Item1) - .OrderBy(x => x) - .ToArray(); - - var result = new Dictionary>(); - - foreach (var sub_dir in sub_dirs) - { - var dir_name = sub_dir.Split(Path.DirectorySeparatorChar).Last(); - print($"Looking for images in '{dir_name}'"); - var file_list = Directory.GetFiles(sub_dir); - if (len(file_list) < 20) - print($"WARNING: Folder has less than 20 images, which may cause issues."); - - var label_name = dir_name.ToLower(); - result[label_name] = new Dictionary(); - int testing_count = (int)Math.Floor(file_list.Length * testing_percentage); - int validation_count = (int)Math.Floor(file_list.Length * validation_percentage); - result[label_name]["testing"] = file_list.Take(testing_count).ToArray(); - result[label_name]["validation"] = file_list.Skip(testing_count).Take(validation_count).ToArray(); - result[label_name]["training"] = file_list.Skip(testing_count + validation_count).ToArray(); - } - - return result; - } - - public Graph ImportGraph() - { - Graph graph; - - // Set up the pre-trained graph. - (graph, bottleneck_tensor, resized_image_tensor, wants_quantization) = - create_module_graph(); - - // Add the new layer that we'll be training. - (train_step, cross_entropy, bottleneck_input, - ground_truth_input, final_tensor) = add_final_retrain_ops( - class_count, final_tensor_name, bottleneck_tensor, - wants_quantization, is_training: true); - - return graph; - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - var sw = new Stopwatch(); - - // Initialize all weights: for the module to their pretrained values, - // and for the newly added retraining layer to random initial values. - var init = tf.global_variables_initializer(); - sess.run(init); - - var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding(); - - // We'll make sure we've calculated the 'bottleneck' image summaries and - // cached them on disk. - cache_bottlenecks(sess, image_lists, image_dir, - bottleneck_dir, jpeg_data_tensor, - decoded_image_tensor, resized_image_tensor, - bottleneck_tensor, tfhub_module); - - // Create the operations we need to evaluate the accuracy of our new layer. - var (evaluation_step, _) = add_evaluation_step(final_tensor, ground_truth_input); - - // Merge all the summaries and write them out to the summaries_dir - var merged = tf.summary.merge_all(); - var train_writer = tf.summary.FileWriter(summaries_dir + "/train", sess.graph); - var validation_writer = tf.summary.FileWriter(summaries_dir + "/validation", sess.graph); - - // Create a train saver that is used to restore values into an eval graph - // when exporting models. - var train_saver = tf.train.Saver(); - train_saver.save(sess, CHECKPOINT_NAME); - - sw.Restart(); - - for (int i = 0; i < how_many_training_steps; i++) - { - var (train_bottlenecks, train_ground_truth, _) = get_random_cached_bottlenecks( - sess, image_lists, train_batch_size, "training", - bottleneck_dir, image_dir, jpeg_data_tensor, - decoded_image_tensor, resized_image_tensor, bottleneck_tensor, - tfhub_module); - - // Feed the bottlenecks and ground truth into the graph, and run a training - // step. Capture training summaries for TensorBoard with the `merged` op. 
- var results = sess.run( - new ITensorOrOperation[] { merged, train_step }, - new FeedItem(bottleneck_input, train_bottlenecks), - new FeedItem(ground_truth_input, train_ground_truth)); - var train_summary = results[0]; - - // TODO - // train_writer.add_summary(train_summary, i); - - // Every so often, print out how well the graph is training. - bool is_last_step = (i + 1 == how_many_training_steps); - if ((i % eval_step_interval) == 0 || is_last_step) - { - (float train_accuracy, float cross_entropy_value) = sess.run((evaluation_step, cross_entropy), - (bottleneck_input, train_bottlenecks), - (ground_truth_input, train_ground_truth)); - print($"{DateTime.Now}: Step {i + 1}: Train accuracy = {train_accuracy * 100}%, Cross entropy = {cross_entropy_value.ToString("G4")}"); - - var (validation_bottlenecks, validation_ground_truth, _) = get_random_cached_bottlenecks( - sess, image_lists, validation_batch_size, "validation", - bottleneck_dir, image_dir, jpeg_data_tensor, - decoded_image_tensor, resized_image_tensor, bottleneck_tensor, - tfhub_module); - - // Run a validation step and capture training summaries for TensorBoard - // with the `merged` op. - (_, float validation_accuracy) = sess.run((merged, evaluation_step), - (bottleneck_input, validation_bottlenecks), - (ground_truth_input, validation_ground_truth)); - - // validation_writer.add_summary(validation_summary, i); - print($"{DateTime.Now}: Step {i + 1}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)}) {sw.ElapsedMilliseconds}ms"); - sw.Restart(); - } - - // Store intermediate results - int intermediate_frequency = intermediate_store_frequency; - if (intermediate_frequency > 0 && i % intermediate_frequency == 0 && i > 0) - { - - } - } - - // After training is complete, force one last save of the train checkpoint. - train_saver.save(sess, CHECKPOINT_NAME); - - // We've completed all our training, so run a final test evaluation on - // some new images we haven't used before. - (test_accuracy, predictions) = run_final_eval(sess, null, class_count, image_lists, - jpeg_data_tensor, decoded_image_tensor, resized_image_tensor, - bottleneck_tensor); - - // Write out the trained graph and labels with the weights stored as - // constants. 
- print($"Save final result to : {output_graph}"); - save_graph_to_file(output_graph, class_count); - File.WriteAllText(output_labels, string.Join("\n", image_lists.Keys)); - } - - /// - /// Prediction - /// labels mapping, it's from output_lables.txt - /// 0 - daisy - /// 1 - dandelion - /// 2 - roses - /// 3 - sunflowers - /// 4 - tulips - /// - /// - public void Predict(Session sess_) - { - if (!File.Exists(output_graph)) - return; - - var labels = File.ReadAllLines(output_labels); - - // predict image - var img_path = Path.Join(image_dir, "daisy", "5547758_eea9edfd54_n.jpg"); - var fileBytes = ReadTensorFromImageFile(img_path); - - // import graph and variables - var graph = new Graph(); - graph.Import(output_graph, ""); - - Tensor input = graph.OperationByName(input_tensor_name); - Tensor output = graph.OperationByName(final_tensor_name); - - using (var sess = tf.Session(graph)) - { - var result = sess.run(output, (input, fileBytes)); - var prob = np.squeeze(result); - var idx = np.argmax(prob); - print($"Prediction result: [{labels[idx]} {prob[idx]}] for {img_path}."); - } - } - - private NDArray ReadTensorFromImageFile(string file_name, - int input_height = 299, - int input_width = 299, - int input_mean = 0, - int input_std = 255) - { - var graph = tf.Graph().as_default(); - - var file_reader = tf.read_file(file_name, "file_reader"); - var image_reader = tf.image.decode_jpeg(file_reader, channels: 3, name: "jpeg_reader"); - var caster = tf.cast(image_reader, tf.float32); - var dims_expander = tf.expand_dims(caster, 0); - var resize = tf.constant(new int[] { input_height, input_width }); - var bilinear = tf.image.resize_bilinear(dims_expander, resize); - var sub = tf.subtract(bilinear, new float[] { input_mean }); - var normalized = tf.divide(sub, new float[] { input_std }); - - using (var sess = tf.Session(graph)) - return sess.run(normalized); - } - - public void Test(Session sess_) - { - if (!File.Exists(output_graph)) - return; - - var graph = new Graph(); - graph.Import(output_graph); - var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding(); - - tf_with(tf.Session(graph), sess => - { - (test_accuracy, predictions) = run_final_eval(sess, null, class_count, image_lists, - jpeg_data_tensor, decoded_image_tensor, resized_image_tensor, - bottleneck_tensor); - }); - } - } -} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Dataset.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Dataset.cs deleted file mode 100644 index 482280ca..00000000 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Dataset.cs +++ /dev/null @@ -1,54 +0,0 @@ -using NumSharp; -using System; -using System.Collections.Generic; -using System.IO; -using System.Text; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples.ImageProcessing.YOLO -{ - public class Dataset - { - string annot_path; - int[] input_sizes; - int batch_size; - bool data_aug; - int[] train_input_sizes; - NDArray strides; - NDArray anchors; - Dictionary classes; - int num_classes; - int anchor_per_scale; - int max_bbox_per_scale; - string[] annotations; - int num_samples; - int batch_count; - - public int Length = 0; - - public Dataset(string dataset_type, Config cfg) - { - annot_path = dataset_type == "train" ? cfg.TRAIN.ANNOT_PATH : cfg.TEST.ANNOT_PATH; - input_sizes = dataset_type == "train" ? cfg.TRAIN.INPUT_SIZE : cfg.TEST.INPUT_SIZE; - batch_size = dataset_type == "train" ? cfg.TRAIN.BATCH_SIZE : cfg.TEST.BATCH_SIZE; - data_aug = dataset_type == "train" ? 
cfg.TRAIN.DATA_AUG : cfg.TEST.DATA_AUG; - train_input_sizes = cfg.TRAIN.INPUT_SIZE; - strides = np.array(cfg.YOLO.STRIDES); - - classes = Utils.read_class_names(cfg.YOLO.CLASSES); - num_classes = classes.Count; - anchors = np.array(Utils.get_anchors(cfg.YOLO.ANCHORS)); - anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE; - max_bbox_per_scale = 150; - - annotations = load_annotations(); - num_samples = len(annotations); - batch_count = 0; - } - - string[] load_annotations() - { - return File.ReadAllLines(annot_path); - } - } -} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs deleted file mode 100644 index c3201f8c..00000000 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs +++ /dev/null @@ -1,139 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Text; -using Tensorflow; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples.ImageProcessing.YOLO -{ - /// - /// Implementation of YOLO v3 object detector in Tensorflow - /// https://github.com/YunYang1994/tensorflow-yolov3 - /// - public class Main : IExample - { - public bool Enabled { get; set; } = true; - public bool IsImportingGraph { get; set; } = false; - public string Name => "YOLOv3"; - - #region args - Dictionary classes; - int num_classes; - float learn_rate_init; - float learn_rate_end; - int first_stage_epochs; - int second_stage_epochs; - int warmup_periods; - string time; - float moving_ave_decay; - int max_bbox_per_scale; - int steps_per_period; - - Dataset trainset, testset; - - Config cfg; - - Tensor input_data; - Tensor label_sbbox; - Tensor label_mbbox; - Tensor label_lbbox; - Tensor true_sbboxes; - Tensor true_mbboxes; - Tensor true_lbboxes; - Tensor trainable; - - Session sess; - YOLOv3 model; - #endregion - - public bool Run() - { - PrepareData(); - - var graph = IsImportingGraph ? 
ImportGraph() : BuildGraph(); - - var options = new SessionOptions(); - options.SetConfig(new ConfigProto { AllowSoftPlacement = true }); - using (var sess = tf.Session(graph, opts: options)) - { - Train(sess); - } - - return true; - } - - public void Train(Session sess) - { - - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - var graph = new Graph().as_default(); - - tf_with(tf.name_scope("define_input"), scope => - { - input_data = tf.placeholder(dtype: tf.float32, name: "input_data"); - label_sbbox = tf.placeholder(dtype: tf.float32, name: "label_sbbox"); - label_mbbox = tf.placeholder(dtype: tf.float32, name: "label_mbbox"); - label_lbbox = tf.placeholder(dtype: tf.float32, name: "label_lbbox"); - true_sbboxes = tf.placeholder(dtype: tf.float32, name: "sbboxes"); - true_mbboxes = tf.placeholder(dtype: tf.float32, name: "mbboxes"); - true_lbboxes = tf.placeholder(dtype: tf.float32, name: "lbboxes"); - trainable = tf.placeholder(dtype: tf.@bool, name: "training"); - }); - - tf_with(tf.name_scope("define_loss"), scope => - { - model = new YOLOv3(cfg, input_data, trainable); - }); - - tf_with(tf.name_scope("define_weight_decay"), scope => - { - var moving_ave = tf.train.ExponentialMovingAverage(moving_ave_decay).apply((RefVariable[])tf.trainable_variables()); - }); - - return graph; - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void PrepareData() - { - cfg = new Config(Name); - - string dataDir = Path.Combine(Name, "data"); - Directory.CreateDirectory(dataDir); - - classes = Utils.read_class_names(cfg.YOLO.CLASSES); - num_classes = classes.Count; - - learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT; - learn_rate_end = cfg.TRAIN.LEARN_RATE_END; - first_stage_epochs = cfg.TRAIN.FISRT_STAGE_EPOCHS; - second_stage_epochs = cfg.TRAIN.SECOND_STAGE_EPOCHS; - warmup_periods = cfg.TRAIN.WARMUP_EPOCHS; - DateTime now = DateTime.Now; - time = $"{now.Year}-{now.Month}-{now.Day}-{now.Hour}-{now.Minute}-{now.Minute}"; - moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY; - max_bbox_per_scale = 150; - trainset = new Dataset("train", cfg); - testset = new Dataset("test", cfg); - steps_per_period = trainset.Length; - } - } -} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Utils.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Utils.cs deleted file mode 100644 index 3a0d3089..00000000 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Utils.cs +++ /dev/null @@ -1,27 +0,0 @@ -using NumSharp; -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Text; - -namespace TensorFlowNET.Examples.ImageProcessing.YOLO -{ - class Utils - { - public static Dictionary read_class_names(string file) - { - var classes = new Dictionary(); - foreach (var line in File.ReadAllLines(file)) - classes[classes.Count] = line; - return classes; - } - - public static NDArray get_anchors(string file) - { - return np.array(File.ReadAllText(file).Split(',') - .Select(x => float.Parse(x)) - .ToArray()).reshape(3, 3, 2); - } - } -} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs deleted file mode 100644 index de5f0acc..00000000 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/YOLOv3.cs +++ /dev/null @@ -1,65 +0,0 @@ -using NumSharp; -using System; -using System.Collections.Generic; -using 
System.Text; -using Tensorflow; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples.ImageProcessing.YOLO -{ - public class YOLOv3 - { - Config cfg; - Tensor trainable; - Tensor input_data; - Dictionary classes; - int num_class; - NDArray strides; - NDArray anchors; - int anchor_per_scale; - float iou_loss_thresh; - string upsample_method; - Tensor conv_lbbox; - Tensor conv_mbbox; - Tensor conv_sbbox; - - public YOLOv3(Config cfg_, Tensor input_data_, Tensor trainable_) - { - cfg = cfg_; - input_data = input_data_; - trainable = trainable_; - classes = Utils.read_class_names(cfg.YOLO.CLASSES); - num_class = len(classes); - strides = np.array(cfg.YOLO.STRIDES); - anchors = Utils.get_anchors(cfg.YOLO.ANCHORS); - anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE; - iou_loss_thresh = cfg.YOLO.IOU_LOSS_THRESH; - upsample_method = cfg.YOLO.UPSAMPLE_METHOD; - - (conv_lbbox, conv_mbbox, conv_sbbox) = __build_nework(input_data); - - tf_with(tf.variable_scope("pred_sbbox"), scope => - { - // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); - }); - - tf_with(tf.variable_scope("pred_mbbox"), scope => - { - // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); - }); - - tf_with(tf.variable_scope("pred_lbbox"), scope => - { - // pred_sbbox = decode(conv_sbbox, anchors[0], strides[0]); - }); - } - - private (Tensor, Tensor, Tensor) __build_nework(Tensor input_data) - { - Tensor route_1, route_2; - (route_1, route_2, input_data) = backbone.darknet53(input_data, trainable); - - return (conv_lbbox, conv_mbbox, conv_sbbox); - } - } -} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/backbone.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/backbone.cs deleted file mode 100644 index 0e7b1446..00000000 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/backbone.cs +++ /dev/null @@ -1,28 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; -using Tensorflow; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples.ImageProcessing.YOLO -{ - class backbone - { - public static (Tensor, Tensor, Tensor) darknet53(Tensor input_data, Tensor trainable) - { - return tf_with(tf.variable_scope("darknet"), scope => - { - input_data = common.convolutional(input_data, filters_shape: new int[] { 3, 3, 3, 32 }, trainable: trainable, name: "conv0"); - input_data = common.convolutional(input_data, filters_shape: new int[] { 3, 3, 32, 64 }, trainable: trainable, name: "conv1", downsample: true); - - foreach (var i in range(1)) - input_data = common.residual_block(input_data, 64, 32, 64, trainable: trainable, name: $"residual{i + 0}"); - - var route_1 = input_data; - var route_2 = input_data; - - return (route_1, route_2, input_data); - }); - } - } -} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs deleted file mode 100644 index 57105aa1..00000000 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs +++ /dev/null @@ -1,72 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; -using Tensorflow; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples.ImageProcessing.YOLO -{ - class common - { - public static Tensor convolutional(Tensor input_data, int[] filters_shape, Tensor trainable, - string name, bool downsample = false, bool activate = true, - bool bn = true) - { - return tf_with(tf.variable_scope(name), scope => - { - int[] strides; - string padding; - - if (downsample) - { - throw new 
NotImplementedException(""); - } - else - { - strides = new int[] { 1, 1, 1, 1 }; - padding = "SAME"; - } - - var weight = tf.get_variable(name: "weight", dtype: tf.float32, trainable: true, - shape: filters_shape, initializer: tf.random_normal_initializer(stddev: 0.01f)); - - var conv = tf.nn.conv2d(input: input_data, filter: weight, strides: strides, padding: padding); - - if (bn) - { - conv = tf.layers.batch_normalization(conv, beta_initializer: tf.zeros_initializer, - gamma_initializer: tf.ones_initializer, - moving_mean_initializer: tf.zeros_initializer, - moving_variance_initializer: tf.ones_initializer, training: trainable); - } - else - { - throw new NotImplementedException(""); - } - - if (activate) - conv = tf.nn.leaky_relu(conv, alpha: 0.1f); - - return conv; - }); - } - - public static Tensor residual_block(Tensor input_data, int input_channel, int filter_num1, - int filter_num2, Tensor trainable, string name) - { - var short_cut = input_data; - - return tf_with(tf.variable_scope(name), scope => - { - input_data = convolutional(input_data, filters_shape: new int[] { 1, 1, input_channel, filter_num1 }, - trainable: trainable, name: "conv1"); - input_data = convolutional(input_data, filters_shape: new int[] { 3, 3, filter_num1, filter_num2 }, - trainable: trainable, name: "conv2"); - - var residual_output = input_data + short_cut; - - return residual_output; - }); - } - } -} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs deleted file mode 100644 index b5c46151..00000000 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/config.cs +++ /dev/null @@ -1,94 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Text; - -namespace TensorFlowNET.Examples.ImageProcessing.YOLO -{ - public class Config - { - public YoloConfig YOLO; - public TrainConfig TRAIN; - public TestConfig TEST; - - public Config(string root) - { - YOLO = new YoloConfig(root); - TRAIN = new TrainConfig(root); - TEST = new TestConfig(root); - } - - public class YoloConfig - { - string _root; - - public string CLASSES; - public string ANCHORS; - public float MOVING_AVE_DECAY = 0.9995f; - public int[] STRIDES = new int[] { 8, 16, 32 }; - public int ANCHOR_PER_SCALE = 3; - public float IOU_LOSS_THRESH = 0.5f; - public string UPSAMPLE_METHOD = "resize"; - public string ORIGINAL_WEIGHT; - public string DEMO_WEIGHT; - - public YoloConfig(string root) - { - _root = root; - CLASSES = Path.Combine(_root, "data", "classes", "coco.names"); - ANCHORS = Path.Combine(_root, "data", "anchors", "basline_anchors.txt"); - ORIGINAL_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco.ckpt"); - DEMO_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco_demo.ckpt"); - } - } - - public class TrainConfig - { - string _root; - - public int BATCH_SIZE = 6; - public int[] INPUT_SIZE = new int[] { 320, 352, 384, 416, 448, 480, 512, 544, 576, 608 }; - public bool DATA_AUG = true; - public float LEARN_RATE_INIT = 1e-4f; - public float LEARN_RATE_END = 1e-6f; - public int WARMUP_EPOCHS = 2; - public int FISRT_STAGE_EPOCHS = 20; - public int SECOND_STAGE_EPOCHS = 30; - public string INITIAL_WEIGHT; - public string ANNOT_PATH; - - public TrainConfig(string root) - { - _root = root; - INITIAL_WEIGHT = Path.Combine(_root, "data", "checkpoint", "yolov3_coco_demo.ckpt"); - ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_train.txt"); - } - } - - public class TestConfig - { - string _root; - - public int BATCH_SIZE = 
2; - public int[] INPUT_SIZE = new int[] { 544 }; - public bool DATA_AUG = false; - public bool WRITE_IMAGE = true; - public string WRITE_IMAGE_PATH; - public string WEIGHT_FILE; - public bool WRITE_IMAGE_SHOW_LABEL = true; - public bool SHOW_LABEL = true; - public int SECOND_STAGE_EPOCHS = 30; - public float SCORE_THRESHOLD = 0.3f; - public float IOU_THRESHOLD = 0.45f; - public string ANNOT_PATH; - - public TestConfig(string root) - { - _root = root; - ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_test.txt"); - WRITE_IMAGE_PATH = Path.Combine(_root, "data", "detection"); - WEIGHT_FILE = Path.Combine(_root, "checkpoint", "yolov3_test_loss=9.2099.ckpt-5"); - } - } - } -} diff --git a/test/TensorFlowNET.Examples/Keras.cs b/test/TensorFlowNET.Examples/Keras.cs deleted file mode 100644 index aa17e9b2..00000000 --- a/test/TensorFlowNET.Examples/Keras.cs +++ /dev/null @@ -1,99 +0,0 @@ -using System; -using System.Collections.Generic; -using Tensorflow; -using Keras.Layers; -using NumSharp; -using Keras; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - public class Keras : IExample - { - public bool Enabled { get; set; } = true; - public bool IsImportingGraph { get; set; } = false; - - public string Name => "Keras"; - - public bool Run() - { - Console.WriteLine("================================== Keras =================================="); - - #region data - var batch_size = 1000; - var (X, Y) = XOR(batch_size); - //var (X, Y, batch_size) = (np.array(new float[,]{{1, 0 },{1, 1 },{0, 0 },{0, 1 }}), np.array(new int[] { 0, 1, 1, 0 }), 4); - #endregion - - #region features - var (features, labels) = (new Tensor(X), new Tensor(Y)); - var num_steps = 10000; - #endregion - - #region model - var m = new Model(); - - //m.Add(new Dense(8, name: "Hidden", activation: tf.nn.relu())).Add(new Dense(1, name:"Output")); - - m.Add( - new ILayer[] { - new Dense(8, name: "Hidden_1", activation: tf.nn.relu()), - new Dense(1, name: "Output") - }); - - m.train(num_steps, (X, Y)); - #endregion - - return true; - } - - static (NDArray, NDArray) XOR(int samples) - { - var X = new List(); - var Y = new List(); - var r = new Random(); - for (int i = 0; i < samples; i++) - { - var x1 = (float)r.Next(0, 2); - var x2 = (float)r.Next(0, 2); - var y = 0.0f; - if (x1 == x2) - y = 1.0f; - X.Add(new float[] { x1, x2 }); - Y.Add(y); - } - - return (np.array(X.ToArray()), np.array(Y.ToArray())); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void PrepareData() - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/Program.cs b/test/TensorFlowNET.Examples/Program.cs deleted file mode 100644 index 8679a730..00000000 --- a/test/TensorFlowNET.Examples/Program.cs +++ /dev/null @@ -1,120 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Drawing; -using System.Linq; -using System.Reflection; -using Tensorflow; -using Console = Colorful.Console; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - class Program - { - static void Main(string[] args) - { - int finished = 0; - var errors = new List(); - var success = new List(); - - var parsedArgs = ParseArgs(args); - - var examples = Assembly.GetEntryAssembly().GetTypes() - .Where(x => x.GetInterfaces().Contains(typeof(IExample))) - .Select(x => (IExample)Activator.CreateInstance(x)) - .Where(x => x.Enabled) - .OrderBy(x => x.Name) - .ToArray(); - - if (parsedArgs.ContainsKey("ex")) - examples = examples.Where(x => x.Name == parsedArgs["ex"]).ToArray(); - - Console.WriteLine(Environment.OSVersion.ToString(), Color.Yellow); - Console.WriteLine($"TensorFlow Binary v{tf.VERSION}", Color.Yellow); - Console.WriteLine($"TensorFlow.NET v{Assembly.GetAssembly(typeof(TF_DataType)).GetName().Version}", Color.Yellow); - - for (var i = 0; i < examples.Length; i++) - Console.WriteLine($"[{i}]: {examples[i].Name}"); - - var key = "0"; - - if (examples.Length > 1) - { - Console.Write($"Choose one example to run, hit [Enter] to run all: ", Color.Yellow); - key = Console.ReadLine(); - } - - var sw = new Stopwatch(); - for (var i = 0; i < examples.Length; i++) - { - if (i.ToString() != key && key != "") continue; - - var example = examples[i]; - Console.WriteLine($"{DateTime.UtcNow} Starting {example.Name}", Color.White); - - try - { - sw.Restart(); - bool isSuccess = example.Run(); - sw.Stop(); - - if (isSuccess) - success.Add($"Example: {example.Name} in {sw.Elapsed.TotalSeconds}s"); - else - errors.Add($"Example: {example.Name} in {sw.Elapsed.TotalSeconds}s"); - } - catch (Exception ex) - { - errors.Add($"Example: {example.Name}"); - Console.WriteLine(ex); - } - - finished++; - Console.WriteLine($"{DateTime.UtcNow} Completed {example.Name}", Color.White); - } - - success.ForEach(x => Console.WriteLine($"{x} is OK!", Color.Green)); - errors.ForEach(x => Console.WriteLine($"{x} is Failed!", Color.Red)); - - Console.WriteLine($"{finished} of {examples.Length} example(s) are completed."); - Console.ReadLine(); - } - - private static Dictionary ParseArgs(string[] args) - { - var parsed = new Dictionary(); - - for (int i = 0; i < args.Length; i++) - { - string key = args[i].Substring(1); - switch (key) - { - case "ex": - parsed.Add(key, args[++i]); - break; - default: - break; - } - } - - return parsed; - } - } -} diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj deleted file mode 100644 index 55e9b27d..00000000 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj +++ /dev/null @@ -1,32 +0,0 @@ - - - - Exe - netcoreapp2.2 - false - - - - bin\debug-gpu - - - - bin\release-gpu - - - - - - - - - - - - - - - - - - diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj 
b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj deleted file mode 100644 index c675bedc..00000000 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ /dev/null @@ -1,28 +0,0 @@ - - - - Exe - netcoreapp2.2 - false - - - - DEBUG;TRACE - - - - - - - - - - - - - - - - - - diff --git a/test/TensorFlowNET.Examples/TextProcessing/BinaryTextClassification.cs b/test/TensorFlowNET.Examples/TextProcessing/BinaryTextClassification.cs deleted file mode 100644 index dbfbc37d..00000000 --- a/test/TensorFlowNET.Examples/TextProcessing/BinaryTextClassification.cs +++ /dev/null @@ -1,164 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using Tensorflow; -using Newtonsoft.Json; -using System.Linq; -using NumSharp; - -namespace TensorFlowNET.Examples -{ - /// - /// This example classifies movie reviews as positive or negative using the text of the review. - /// This is a binary—or two-class—classification, an important and widely applicable kind of machine learning problem. - /// https://github.com/tensorflow/docs/blob/master/site/en/tutorials/keras/basic_text_classification.ipynb - /// - public class BinaryTextClassification : IExample - { - public bool Enabled { get; set; } = false; - public string Name => "Binary Text Classification"; - public bool IsImportingGraph { get; set; } = true; - - string dir = "binary_text_classification"; - string dataFile = "imdb.zip"; - NDArray train_data, train_labels, test_data, test_labels; - - public bool Run() - { - PrepareData(); - - Console.WriteLine($"Training entries: {train_data.shape[0]}, labels: {train_labels.shape[0]}"); - - // A dictionary mapping words to an integer index - var word_index = GetWordIndex(); - - /*train_data = keras.preprocessing.sequence.pad_sequences(train_data, - value: word_index[""], - padding: "post", - maxlen: 256); - - test_data = keras.preprocessing.sequence.pad_sequences(test_data, - value: word_index[""], - padding: "post", - maxlen: 256);*/ - - // input shape is the vocabulary count used for the movie reviews (10,000 words) - int vocab_size = 10000; - - var model = keras.Sequential(); - var layer = keras.layers.Embedding(vocab_size, 16); - model.add(layer); - - return false; - } - - public void PrepareData() - { - Directory.CreateDirectory(dir); - - // get model file - string url = $"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/{dataFile}"; - - Utility.Web.Download(url, dir, "imdb.zip"); - Utility.Compress.UnZip(Path.Join(dir, $"imdb.zip"), dir); - - // prepare training dataset - var x_train = ReadData(Path.Join(dir, "x_train.txt")); - var labels_train = ReadData(Path.Join(dir, "y_train.txt")); - var indices_train = ReadData(Path.Join(dir, "indices_train.txt")); - x_train = x_train[indices_train]; - labels_train = labels_train[indices_train]; - - var x_test = ReadData(Path.Join(dir, "x_test.txt")); - var labels_test = ReadData(Path.Join(dir, "y_test.txt")); - var indices_test = ReadData(Path.Join(dir, "indices_test.txt")); - x_test = x_test[indices_test]; - labels_test = labels_test[indices_test]; - - // not completed - var xs = x_train.hstack(x_test); - var labels = labels_train.hstack(labels_test); - - var idx = x_train.size; - var y_train = labels_train; - var y_test = labels_test; - - // convert x_train - train_data = new NDArray(np.int32, (x_train.size, 256)); - /*for (int i = 0; i < x_train.size; i++) - train_data[i] = x_train[i].Data()[1].Split(',').Select(x => int.Parse(x)).ToArray();*/ - - test_data = new NDArray(np.int32, (x_test.size, 256)); - /*for (int i 
= 0; i < x_test.size; i++) - test_data[i] = x_test[i].Data()[1].Split(',').Select(x => int.Parse(x)).ToArray();*/ - - train_labels = y_train; - test_labels = y_test; - } - - private NDArray ReadData(string file) - { - var lines = File.ReadAllLines(file); - var nd = new NDArray(lines[0].StartsWith("[") ? typeof(string) : np.int32, new Shape(lines.Length)); - - if (lines[0].StartsWith("[")) - { - for (int i = 0; i < lines.Length; i++) - { - /*var matches = Regex.Matches(lines[i], @"\d+\s*"); - var data = new int[matches.Count]; - for (int j = 0; j < data.Length; j++) - data[j] = Convert.ToInt32(matches[j].Value); - nd[i] = data.ToArray();*/ - nd[i] = lines[i].Substring(1, lines[i].Length - 2).Replace(" ", string.Empty); - } - } - else - { - for (int i = 0; i < lines.Length; i++) - nd[i] = Convert.ToInt32(lines[i]); - } - return nd; - } - - private Dictionary GetWordIndex() - { - var result = new Dictionary(); - var json = File.ReadAllText(Path.Join(dir, "imdb_word_index.json")); - var dict = JsonConvert.DeserializeObject>(json); - - dict.Keys.Select(k => result[k] = dict[k] + 3).ToList(); - result[""] = 0; - result[""] = 1; - result[""] = 2; // unknown - result[""] = 3; - - return result; - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs deleted file mode 100644 index 3519d972..00000000 --- a/test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs +++ /dev/null @@ -1,267 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Linq; -using NumSharp; -using Tensorflow; -using Tensorflow.Sessions; -using TensorFlowNET.Examples.Text; -using TensorFlowNET.Examples.Utility; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - /// - /// https://github.com/dongjun-Lee/text-classification-models-tf - /// - public class CnnTextClassification : IExample - { - public bool Enabled { get; set; } = true; - public string Name => "CNN Text Classification"; - public int? 
DataLimit = null; - public bool IsImportingGraph { get; set; } = false; - - const string dataDir = "cnn_text"; - string dataFileName = "dbpedia_csv.tar.gz"; - - string TRAIN_PATH = $"{dataDir}/dbpedia_csv/train.csv"; - string TEST_PATH = $"{dataDir}/dbpedia_csv/test.csv"; - - int NUM_CLASS = 14; - int BATCH_SIZE = 64; - int NUM_EPOCHS = 10; - int WORD_MAX_LEN = 100; - int CHAR_MAX_LEN = 1014; - - float loss_value = 0; - double max_accuracy = 0; - - int alphabet_size = -1; - int vocabulary_size = -1; - NDArray train_x, valid_x, train_y, valid_y; - - ITextModel textModel; - public string ModelName = "word_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn - - public bool Run() - { - PrepareData(); - var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); - - using (var sess = tf.Session(graph)) - Train(sess); - - return max_accuracy > 0.9; - } - - // TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here - private (NDArray, NDArray, NDArray, NDArray) train_test_split(NDArray x, NDArray y, float test_size = 0.3f) - { - Console.WriteLine("Splitting in Training and Testing data..."); - int len = x.shape[0]; - //int classes = y.Data().Distinct().Count(); - //int samples = len / classes; - int train_size = (int)Math.Round(len * (1 - test_size)); - train_x = x[new Slice(stop: train_size), new Slice()]; - valid_x = x[new Slice(start: train_size), new Slice()]; - train_y = y[new Slice(stop: train_size)]; - valid_y = y[new Slice(start: train_size)]; - Console.WriteLine("\tDONE"); - - return (train_x, valid_x, train_y, valid_y); - } - - private void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary> labels) - { - int i = 0; - var label_keys = labels.Keys.ToArray(); - while (i < shuffled_x.Length) - { - var key = label_keys[random.Next(label_keys.Length)]; - var set = labels[key]; - var index = set.First(); - if (set.Count == 0) - { - labels.Remove(key); // remove the set as it is empty - label_keys = labels.Keys.ToArray(); - } - shuffled_x[i] = x[index]; - shuffled_y[i] = y[index]; - i++; - } - } - - private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs) - { - var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1; - var total_batches = num_batches_per_epoch * num_epochs; - foreach (var epoch in range(num_epochs)) - { - foreach (var batch_num in range(num_batches_per_epoch)) - { - var start_index = batch_num * batch_size; - var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs)); - if (end_index <= start_index) - break; - yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index, end_index)], total_batches); - } - } - } - - public void PrepareData() - { - // full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz - var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip"; - Web.Download(url, dataDir, "dbpedia_subset.zip"); - Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv")); - - Console.WriteLine("Building dataset..."); - var (x, y) = (new int[0][], new int[0]); - - if(ModelName == "char_cnn") - { - (x, y, alphabet_size) = DataHelpers.build_char_dataset(TRAIN_PATH, "char_cnn", CHAR_MAX_LEN); - } - else - { - var word_dict = DataHelpers.build_word_dict(TRAIN_PATH); - vocabulary_size = len(word_dict); - (x, y) = 
DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN); - } - - Console.WriteLine("\tDONE "); - - var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); - Console.WriteLine("Training set size: " + train_x.shape[0]); - Console.WriteLine("Test set size: " + valid_x.shape[0]); - } - - public Graph ImportGraph() - { - var graph = tf.Graph().as_default(); - - // download graph meta data - var meta_file = "word_cnn.meta"; - var meta_path = Path.Combine("graph", meta_file); - if (File.GetLastWriteTime(meta_path) < new DateTime(2019, 05, 11)) - { - // delete old cached file which contains errors - Console.WriteLine("Discarding cached file: " + meta_path); - if(File.Exists(meta_path)) - File.Delete(meta_path); - } - var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file; - Web.Download(url, "graph", meta_file); - - Console.WriteLine("Import graph..."); - tf.train.import_meta_graph(Path.Join("graph", meta_file)); - Console.WriteLine("\tDONE "); - - return graph; - } - - public Graph BuildGraph() - { - var graph = tf.Graph().as_default(); - - switch (ModelName) - { - case "word_cnn": - textModel = new WordCnn(vocabulary_size, WORD_MAX_LEN, NUM_CLASS); - break; - case "char_cnn": - textModel = new CharCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS); - break; - } - - return graph; - } - - public void Train(Session sess) - { - var graph = tf.get_default_graph(); - var stopwatch = Stopwatch.StartNew(); - - sess.run(tf.global_variables_initializer()); - var saver = tf.train.Saver(tf.global_variables()); - - var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS); - var num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1; - - Tensor is_training = graph.OperationByName("is_training"); - Tensor model_x = graph.OperationByName("x"); - Tensor model_y = graph.OperationByName("y"); - Tensor loss = graph.OperationByName("loss/Mean"); - Operation optimizer = graph.OperationByName("loss/Adam"); - Tensor global_step = graph.OperationByName("Variable"); - Tensor accuracy = graph.OperationByName("accuracy/accuracy"); - stopwatch = Stopwatch.StartNew(); - int step = 0; - foreach (var (x_batch, y_batch, total) in train_batches) - { - (_, step, loss_value) = sess.run((optimizer, global_step, loss), - (model_x, x_batch), (model_y, y_batch), (is_training, true)); - if (step == 1 || step % 10 == 0) - Console.WriteLine($"Training on batch {step}/{total} loss: {loss_value.ToString("0.0000")}."); - - if (step % 100 == 0) - { - // Test accuracy with validation data for each epoch. 
- var valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1); - var (sum_accuracy, cnt) = (0.0f, 0); - foreach (var (valid_x_batch, valid_y_batch, total_validation_batches) in valid_batches) - { - var valid_feed_dict = new FeedDict - { - [model_x] = valid_x_batch, - [model_y] = valid_y_batch, - [is_training] = false - }; - float accuracy_value = sess.run(accuracy, (model_x, valid_x_batch), (model_y, valid_y_batch), (is_training, false)); - sum_accuracy += accuracy_value; - cnt += 1; - } - - var valid_accuracy = sum_accuracy / cnt; - - print($"\nValidation Accuracy = {valid_accuracy.ToString("P")}\n"); - - // Save model - if (valid_accuracy > max_accuracy) - { - max_accuracy = valid_accuracy; - saver.save(sess, $"{dataDir}/word_cnn.ckpt", global_step: step); - print("Model is saved.\n"); - } - } - } - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/TextProcessing/DataHelpers.cs b/test/TensorFlowNET.Examples/TextProcessing/DataHelpers.cs deleted file mode 100644 index 92d60fcb..00000000 --- a/test/TensorFlowNET.Examples/TextProcessing/DataHelpers.cs +++ /dev/null @@ -1,218 +0,0 @@ -using NumSharp; -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Security.Cryptography; -using System.Text; -using System.Text.RegularExpressions; -using TensorFlowNET.Examples.Utility; - -namespace TensorFlowNET.Examples -{ - public class DataHelpers - { - public static Dictionary build_word_dict(string path) - { - var contents = File.ReadAllLines(path); - - var words = new List(); - foreach (var content in contents) - words.AddRange(clean_str(content).Split(' ').Where(x => x.Length > 1)); - var word_counter = words.GroupBy(x => x) - .Select(x => new { Word = x.Key, Count = x.Count() }) - .OrderByDescending(x => x.Count) - .ToArray(); - - var word_dict = new Dictionary(); - word_dict[""] = 0; - word_dict[""] = 1; - word_dict[""] = 2; - foreach (var word in word_counter) - word_dict[word.Word] = word_dict.Count; - - return word_dict; - } - - public static (int[][], int[]) build_word_dataset(string path, Dictionary word_dict, int document_max_len) - { - var contents = File.ReadAllLines(path); - var x = contents.Select(c => (clean_str(c) + " ") - .Split(' ').Take(document_max_len) - .Select(w => word_dict.ContainsKey(w) ? word_dict[w] : word_dict[""]).ToArray()) - .ToArray(); - - for (int i = 0; i < x.Length; i++) - if (x[i].Length == document_max_len) - x[i][document_max_len - 1] = word_dict[""]; - else - Array.Resize(ref x[i], document_max_len); - - var y = contents.Select(c => int.Parse(c.Substring(0, c.IndexOf(','))) - 1).ToArray(); - - return (x, y); - } - - public static (int[][], int[], int) build_char_dataset(string path, string model, int document_max_len, int? limit = null, bool shuffle=true) - { - string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} "; - /*if (step == "train") - df = pd.read_csv(TRAIN_PATH, names =["class", "title", "content"]);*/ - var char_dict = new Dictionary(); - char_dict[""] = 0; - char_dict[""] = 1; - foreach (char c in alphabet) - char_dict[c.ToString()] = char_dict.Count; - var contents = File.ReadAllLines(path); - if (shuffle) - new Random(17).Shuffle(contents); - //File.WriteAllLines("text_classification/dbpedia_csv/train_6400.csv", contents.Take(6400)); - var size = limit == null ? 
contents.Length : limit.Value; - - var x = new int[size][]; - var y = new int[size]; - var tenth = size / 10; - var percent = 0; - for (int i = 0; i < size; i++) - { - if ((i + 1) % tenth == 0) - { - percent += 10; - Console.WriteLine($"\t{percent}%"); - } - - string[] parts = contents[i].ToLower().Split(",\"").ToArray(); - string content = parts[2]; - content = content.Substring(0, content.Length - 1); - var a = new int[document_max_len]; - for (int j = 0; j < document_max_len; j++) - { - if (j >= content.Length) - a[j] = char_dict[""]; - else - a[j] = char_dict.ContainsKey(content[j].ToString()) ? char_dict[content[j].ToString()] : char_dict[""]; - } - x[i] = a; - y[i] = int.Parse(parts[0]); - } - - return (x, y, alphabet.Length + 2); - } - - /// - /// Loads MR polarity data from files, splits the data into words and generates labels. - /// Returns split sentences and labels. - /// - /// - /// - /// - public static (string[], NDArray) load_data_and_labels(string positive_data_file, string negative_data_file) - { - Directory.CreateDirectory("CnnTextClassification"); - Utility.Web.Download(positive_data_file, "CnnTextClassification", "rt -polarity.pos"); - Utility.Web.Download(negative_data_file, "CnnTextClassification", "rt-polarity.neg"); - - // Load data from files - var positive_examples = File.ReadAllLines("CnnTextClassification/rt-polarity.pos") - .Select(x => x.Trim()) - .ToArray(); - - var negative_examples = File.ReadAllLines("CnnTextClassification/rt-polarity.neg") - .Select(x => x.Trim()) - .ToArray(); - - var x_text = new List(); - x_text.AddRange(positive_examples); - x_text.AddRange(negative_examples); - x_text = x_text.Select(x => clean_str(x)).ToList(); - - var positive_labels = positive_examples.Select(x => new int[2] { 0, 1 }).ToArray(); - var negative_labels = negative_examples.Select(x => new int[2] { 1, 0 }).ToArray(); - var y = np.concatenate(new NDArray[] { new int[][][] { positive_labels, negative_labels } }); - return (x_text.ToArray(), y); - } - - private static string clean_str(string str) - { - str = Regex.Replace(str, "[^A-Za-z0-9(),!?]", " "); - str = Regex.Replace(str, ",", " "); - return str; - } - - /// - /// Padding - /// - /// - /// the char to pad with - /// a list of list where each sublist has same length - public static (int[][], int[]) pad_sequences(int[][] sequences, int pad_tok = 0) - { - int max_length = sequences.Select(x => x.Length).Max(); - return _pad_sequences(sequences, pad_tok, max_length); - } - - public static (int[][][], int[][]) pad_sequences(int[][][] sequences, int pad_tok = 0) - { - int max_length_word = sequences.Select(x => x.Select(w => w.Length).Max()).Max(); - int[][][] sequence_padded; - var sequence_length = new int[sequences.Length][]; - for (int i = 0; i < sequences.Length; i++) - { - // all words are same length now - var (sp, sl) = _pad_sequences(sequences[i], pad_tok, max_length_word); - sequence_length[i] = sl; - } - - int max_length_sentence = sequences.Select(x => x.Length).Max(); - (sequence_padded, _) = _pad_sequences(sequences, np.repeat(pad_tok, max_length_word).GetData().ToArray(), max_length_sentence); - (sequence_length, _) = _pad_sequences(sequence_length, 0, max_length_sentence); - - return (sequence_padded, sequence_length); - } - - private static (int[][], int[]) _pad_sequences(int[][] sequences, int pad_tok, int max_length) - { - var sequence_length = new int[sequences.Length]; - for (int i = 0; i < sequences.Length; i++) - { - sequence_length[i] = sequences[i].Length; - Array.Resize(ref sequences[i], 
max_length); - } - - return (sequences, sequence_length); - } - - private static (int[][][], int[]) _pad_sequences(int[][][] sequences, int[] pad_tok, int max_length) - { - var sequence_length = new int[sequences.Length]; - for (int i = 0; i < sequences.Length; i++) - { - sequence_length[i] = sequences[i].Length; - Array.Resize(ref sequences[i], max_length); - for (int j = 0; j < max_length - sequence_length[i]; j++) - { - sequences[i][max_length - j - 1] = new int[pad_tok.Length]; - Array.Copy(pad_tok, sequences[i][max_length - j - 1], pad_tok.Length); - } - } - - return (sequences, sequence_length); - } - - public static string CalculateMD5Hash(string input) - { - // step 1, calculate MD5 hash from input - MD5 md5 = System.Security.Cryptography.MD5.Create(); - byte[] inputBytes = System.Text.Encoding.ASCII.GetBytes(input); - byte[] hash = md5.ComputeHash(inputBytes); - - // step 2, convert byte array to hex string - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < hash.Length; i++) - { - sb.Append(hash[i].ToString("X2")); - } - return sb.ToString(); - } - } -} diff --git a/test/TensorFlowNET.Examples/TextProcessing/NER/BiLstmCrfNer.cs b/test/TensorFlowNET.Examples/TextProcessing/NER/BiLstmCrfNer.cs deleted file mode 100644 index 18c6e46f..00000000 --- a/test/TensorFlowNET.Examples/TextProcessing/NER/BiLstmCrfNer.cs +++ /dev/null @@ -1,59 +0,0 @@ -using System; -using System.IO; -using Tensorflow; -using Tensorflow.Estimator; - -namespace TensorFlowNET.Examples -{ - /// - /// Bidirectional LSTM-CRF Models for Sequence Tagging - /// https://github.com/guillaumegenthial/tf_ner/tree/master/models/lstm_crf - /// - public class BiLstmCrfNer : IExample - { - public bool Enabled { get; set; } = false; - public bool IsImportingGraph { get; set; } = false; - - public string Name => "bi-LSTM + CRF NER"; - - public bool Run() - { - PrepareData(); - return false; - } - - public void PrepareData() - { - var hp = new HyperParams("BiLstmCrfNer"); - hp.filepath_words = Path.Combine(hp.data_root_dir, "vocab.words.txt"); - hp.filepath_chars = Path.Combine(hp.data_root_dir, "vocab.chars.txt"); - hp.filepath_tags = Path.Combine(hp.data_root_dir, "vocab.tags.txt"); - hp.filepath_glove = Path.Combine(hp.data_root_dir, "glove.npz"); - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/TextProcessing/NER/CRF.cs b/test/TensorFlowNET.Examples/TextProcessing/NER/CRF.cs deleted file mode 100644 index 7819dcef..00000000 --- a/test/TensorFlowNET.Examples/TextProcessing/NER/CRF.cs +++ /dev/null @@ -1,55 +0,0 @@ -using System; -using Tensorflow; - -namespace TensorFlowNET.Examples -{ - /// - /// The CRF module implements a linear-chain CRF layer for learning to predict tag sequences. - /// This variant of the CRF is factored into unary potentials for every element - /// in the sequence and binary potentials for every transition between output tags. 
- /// - /// tensorflow\contrib\crf\python\ops\crf.py - /// - public class CRF : IExample - { - public bool Enabled { get; set; } = true; - public bool IsImportingGraph { get; set; } = false; - - public string Name => "CRF"; - - public bool Run() - { - return true; - } - - public void PrepareData() - { - - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/TextProcessing/NER/LstmCrfNer.cs b/test/TensorFlowNET.Examples/TextProcessing/NER/LstmCrfNer.cs deleted file mode 100644 index 42eab6c4..00000000 --- a/test/TensorFlowNET.Examples/TextProcessing/NER/LstmCrfNer.cs +++ /dev/null @@ -1,233 +0,0 @@ -using NumSharp; -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using Tensorflow; -using Tensorflow.Estimator; -using TensorFlowNET.Examples.Utility; -using static Tensorflow.Binding; -using static TensorFlowNET.Examples.DataHelpers; - -namespace TensorFlowNET.Examples.Text.NER -{ - /// - /// A NER model using Tensorflow (LSTM + CRF + chars embeddings). - /// State-of-the-art performance (F1 score between 90 and 91). - /// - /// https://github.com/guillaumegenthial/sequence_tagging - /// - public class LstmCrfNer : IExample - { - public bool Enabled { get; set; } = true; - public bool IsImportingGraph { get; set; } = true; - - public string Name => "LSTM + CRF NER"; - - HyperParams hp; - - int nwords, nchars, ntags; - CoNLLDataset dev, train; - - Tensor word_ids_tensor; - Tensor sequence_lengths_tensor; - Tensor char_ids_tensor; - Tensor word_lengths_tensor; - Tensor labels_tensor; - Tensor dropout_tensor; - Tensor lr_tensor; - Operation train_op; - Tensor loss; - Tensor merged; - - public bool Run() - { - PrepareData(); - var graph = tf.Graph().as_default(); - - tf.train.import_meta_graph("graph/lstm_crf_ner.meta"); - - float loss_value = 0f; - - //add_summary(); - word_ids_tensor = graph.OperationByName("word_ids"); - sequence_lengths_tensor = graph.OperationByName("sequence_lengths"); - char_ids_tensor = graph.OperationByName("char_ids"); - word_lengths_tensor = graph.OperationByName("word_lengths"); - labels_tensor = graph.OperationByName("labels"); - dropout_tensor = graph.OperationByName("dropout"); - lr_tensor = graph.OperationByName("lr"); - train_op = graph.OperationByName("train_step/Adam"); - loss = graph.OperationByName("Mean"); - //merged = graph.OperationByName("Merge/MergeSummary"); - - var init = tf.global_variables_initializer(); - - using (var sess = tf.Session()) - { - sess.run(init); - - foreach (var epoch in range(hp.epochs)) - { - Console.Write($"Epoch {epoch + 1} out of {hp.epochs}, "); - loss_value = run_epoch(sess, train, dev, epoch); - print($"train loss: {loss_value}"); - } - } - - return loss_value < 0.1; - } - - private float run_epoch(Session sess, CoNLLDataset train, CoNLLDataset dev, int epoch) - { - float accuracy = 0; - // iterate over dataset - var batches = minibatches(train, hp.batch_size); - foreach (var(words, labels) in batches) - { - var (fd, _) = get_feed_dict(words, labels, hp.lr, hp.dropout); - (_, accuracy) = sess.run((train_op, loss), feed_dict: fd); - } - - return accuracy; - } - - private IEnumerable<((int[][], 
int[])[], int[][])> minibatches(CoNLLDataset data, int minibatch_size) - { - var x_batch = new List<(int[][], int[])>(); - var y_batch = new List(); - foreach(var (x, y) in data.GetItems()) - { - if (len(y_batch) == minibatch_size) - { - yield return (x_batch.ToArray(), y_batch.ToArray()); - x_batch.Clear(); - y_batch.Clear(); - } - - var x3 = (x.Select(x1 => x1.Item1).ToArray(), x.Select(x2 => x2.Item2).ToArray()); - x_batch.Add(x3); - y_batch.Add(y); - } - - if (len(y_batch) > 0) - yield return (x_batch.ToArray(), y_batch.ToArray()); - } - - /// - /// Given some data, pad it and build a feed dictionary - /// - /// - /// list of sentences. A sentence is a list of ids of a list of - /// words. A word is a list of ids - /// - /// list of ids - /// learning rate - /// keep prob - private (FeedItem[], int[]) get_feed_dict((int[][], int[])[] words, int[][] labels, float lr = 0f, float dropout = 0f) - { - int[] sequence_lengths; - int[][] word_lengths; - int[][] word_ids; - int[][][] char_ids; - - if (true) // use_chars - { - (char_ids, word_ids) = (words.Select(x => x.Item1).ToArray(), words.Select(x => x.Item2).ToArray()); - (word_ids, sequence_lengths) = pad_sequences(word_ids, pad_tok: 0); - (char_ids, word_lengths) = pad_sequences(char_ids, pad_tok: 0); - } - - // build feed dictionary - var feeds = new List(); - feeds.Add(new FeedItem(word_ids_tensor, np.array(word_ids))); - feeds.Add(new FeedItem(sequence_lengths_tensor, np.array(sequence_lengths))); - - if(true) // use_chars - { - feeds.Add(new FeedItem(char_ids_tensor, np.array(char_ids))); - feeds.Add(new FeedItem(word_lengths_tensor, np.array(word_lengths))); - } - - (labels, _) = pad_sequences(labels, 0); - feeds.Add(new FeedItem(labels_tensor, np.array(labels))); - - feeds.Add(new FeedItem(lr_tensor, lr)); - - feeds.Add(new FeedItem(dropout_tensor, dropout)); - - return (feeds.ToArray(), sequence_lengths); - } - - public void PrepareData() - { - hp = new HyperParams("LstmCrfNer") - { - epochs = 50, - dropout = 0.5f, - batch_size = 20, - lr_method = "adam", - lr = 0.001f, - lr_decay = 0.9f, - clip = false, - epoch_no_imprv = 3, - hidden_size_char = 100, - hidden_size_lstm = 300 - }; - hp.filepath_dev = hp.filepath_test = hp.filepath_train = Path.Combine(hp.data_root_dir, "test.txt"); - - // Loads vocabulary, processing functions and embeddings - hp.filepath_words = Path.Combine(hp.data_root_dir, "words.txt"); - hp.filepath_tags = Path.Combine(hp.data_root_dir, "tags.txt"); - hp.filepath_chars = Path.Combine(hp.data_root_dir, "chars.txt"); - - string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/lstm_crf_ner.zip"; - Web.Download(url, hp.data_root_dir, "lstm_crf_ner.zip"); - Compress.UnZip(Path.Combine(hp.data_root_dir, "lstm_crf_ner.zip"), hp.data_root_dir); - - // 1. vocabulary - /*vocab_tags = load_vocab(hp.filepath_tags); - - - nwords = vocab_words.Count; - nchars = vocab_chars.Count; - ntags = vocab_tags.Count;*/ - - // 2. 
get processing functions that map str -> id - dev = new CoNLLDataset(hp.filepath_dev, hp); - train = new CoNLLDataset(hp.filepath_train, hp); - - // download graph meta data - var meta_file = "lstm_crf_ner.meta"; - var meta_path = Path.Combine("graph", meta_file); - url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file; - Web.Download(url, "graph", meta_file); - - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/TextProcessing/NamedEntityRecognition.cs b/test/TensorFlowNET.Examples/TextProcessing/NamedEntityRecognition.cs deleted file mode 100644 index 157677ea..00000000 --- a/test/TensorFlowNET.Examples/TextProcessing/NamedEntityRecognition.cs +++ /dev/null @@ -1,51 +0,0 @@ -using System; -using Tensorflow; - -namespace TensorFlowNET.Examples -{ - /// - /// https://github.com/guillaumegenthial/tf_ner - /// - public class NamedEntityRecognition : IExample - { - public bool Enabled { get; set; } = false; - public string Name => "NER"; - public bool IsImportingGraph { get; set; } = false; - - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void PrepareData() - { - throw new NotImplementedException(); - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public bool Run() - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/TextProcessing/Word2Vec.cs b/test/TensorFlowNET.Examples/TextProcessing/Word2Vec.cs deleted file mode 100644 index ba945606..00000000 --- a/test/TensorFlowNET.Examples/TextProcessing/Word2Vec.cs +++ /dev/null @@ -1,243 +0,0 @@ -using NumSharp; -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using Tensorflow; -using TensorFlowNET.Examples.Utility; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples -{ - /// - /// Implement Word2Vec algorithm to compute vector representations of words. 
- /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/word2vec.py - /// - public class Word2Vec : IExample - { - public bool Enabled { get; set; } = true; - public string Name => "Word2Vec"; - public bool IsImportingGraph { get; set; } = true; - - // Training Parameters - float learning_rate = 0.1f; - int batch_size = 128; - int num_steps = 30000; //3000000; - int display_step = 1000; //10000; - int eval_step = 5000;//200000; - - // Evaluation Parameters - string[] eval_words = new string[] { "five", "of", "going", "hardware", "american", "britain" }; - string[] text_words; - List word2id; - int[] data; - - // Word2Vec Parameters - int embedding_size = 200; // Dimension of the embedding vector - int max_vocabulary_size = 50000; // Total number of different words in the vocabulary - int min_occurrence = 10; // Remove all words that does not appears at least n times - int skip_window = 3; // How many words to consider left and right - int num_skips = 2; // How many times to reuse an input to generate a label - int num_sampled = 64; // Number of negative examples to sample - - int data_index = 0; - int top_k = 8; // number of nearest neighbors - float average_loss = 0; - - public bool Run() - { - PrepareData(); - - var graph = tf.Graph().as_default(); - - tf.train.import_meta_graph($"graph{Path.DirectorySeparatorChar}word2vec.meta"); - - // Input data - Tensor X = graph.OperationByName("Placeholder"); - // Input label - Tensor Y = graph.OperationByName("Placeholder_1"); - - // Compute the average NCE loss for the batch - Tensor loss_op = graph.OperationByName("Mean"); - // Define the optimizer - var train_op = graph.OperationByName("GradientDescent"); - Tensor cosine_sim_op = graph.OperationByName("MatMul_1"); - - // Initialize the variables (i.e. assign their default value) - var init = tf.global_variables_initializer(); - - using (var sess = tf.Session(graph)) - { - // Run the initializer - sess.run(init); - - var x_test = (from word in eval_words - join id in word2id on word equals id.Word into wi - from wi2 in wi.DefaultIfEmpty() - select wi2 == null ? 
0 : wi2.Id).ToArray(); - - foreach (var step in range(1, num_steps + 1)) - { - // Get a new batch of data - var (batch_x, batch_y) = next_batch(batch_size, num_skips, skip_window); - - (_, float loss) = sess.run((train_op, loss_op), (X, batch_x), (Y, batch_y)); - average_loss += loss; - - if (step % display_step == 0 || step == 1) - { - if (step > 1) - average_loss /= display_step; - - print($"Step {step}, Average Loss= {average_loss.ToString("F4")}"); - average_loss = 0; - } - - // Evaluation - if (step % eval_step == 0 || step == 1) - { - print("Evaluation..."); - var sim = sess.run(cosine_sim_op, (X, x_test)); - foreach(var i in range(len(eval_words))) - { - var nearest = (0f - sim[i]).argsort() - .Data() - .Skip(1) - .Take(top_k) - .ToArray(); - string log_str = $"\"{eval_words[i]}\" nearest neighbors:"; - foreach (var k in range(top_k)) - log_str = $"{log_str} {word2id.First(x => x.Id == nearest[k]).Word},"; - print(log_str); - } - } - } - } - - return average_loss < 100; - } - - // Generate training batch for the skip-gram model - private (NDArray, NDArray) next_batch(int batch_size, int num_skips, int skip_window) - { - var batch = np.ndarray(new Shape(batch_size), dtype: np.int32); - var labels = np.ndarray((batch_size, 1), dtype: np.int32); - // get window size (words left and right + current one) - int span = 2 * skip_window + 1; - var buffer = new Queue(span); - if (data_index + span > data.Length) - data_index = 0; - data.Skip(data_index).Take(span).ToList().ForEach(x => buffer.Enqueue(x)); - data_index += span; - - foreach (var i in range(batch_size / num_skips)) - { - var context_words = range(span).Where(x => x != skip_window).ToArray(); - var words_to_use = new int[] { 1, 6 }; - foreach(var (j, context_word) in enumerate(words_to_use)) - { - batch[i * num_skips + j] = buffer.ElementAt(skip_window); - labels[i * num_skips + j, 0] = buffer.ElementAt(context_word); - } - - if (data_index == len(data)) - { - //buffer.extend(data[0:span]); - data_index = span; - } - else - { - buffer.Enqueue(data[data_index]); - data_index += 1; - } - } - - // Backtrack a little bit to avoid skipping words in the end of a batch - data_index = (data_index + len(data) - span) % len(data); - - return (batch, labels); - } - - public void PrepareData() - { - // Download graph meta - var url = "https://github.com/SciSharp/TensorFlow.NET/raw/master/graph/word2vec.meta"; - Web.Download(url, "graph", "word2vec.meta"); - - // Download a small chunk of Wikipedia articles collection - url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/text8.zip"; - Web.Download(url, "word2vec", "text8.zip"); - // Unzip the dataset file. 
Text has already been processed - Compress.UnZip($"word2vec{Path.DirectorySeparatorChar}text8.zip", "word2vec"); - - int wordId = 0; - text_words = File.ReadAllText($"word2vec{Path.DirectorySeparatorChar}text8").Trim().ToLower().Split(); - // Build the dictionary and replace rare words with UNK token - word2id = text_words.GroupBy(x => x) - .Select(x => new WordId - { - Word = x.Key, - Occurrence = x.Count() - }) - .Where(x => x.Occurrence >= min_occurrence) // Remove samples with less than 'min_occurrence' occurrences - .OrderByDescending(x => x.Occurrence) // Retrieve the most common words - .Select(x => new WordId - { - Word = x.Word, - Id = ++wordId, // Assign an id to each word - Occurrence = x.Occurrence - }) - .ToList(); - - // Retrieve a word id, or assign it index 0 ('UNK') if not in dictionary - data = (from word in text_words - join id in word2id on word equals id.Word into wi - from wi2 in wi.DefaultIfEmpty() - select wi2 == null ? 0 : wi2.Id).ToArray(); - - word2id.Insert(0, new WordId { Word = "UNK", Id = 0, Occurrence = data.Count(x => x == 0) }); - - print($"Words count: {text_words.Length}"); - print($"Unique words: {text_words.Distinct().Count()}"); - print($"Vocabulary size: {word2id.Count}"); - print($"Most common words: {string.Join(", ", word2id.Take(10))}"); - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - - private class WordId - { - public string Word { get; set; } - public int Id { get; set; } - public int Occurrence { get; set; } - - public override string ToString() - { - return Word + " " + Id + " " + Occurrence; - } - } - } -} diff --git a/test/TensorFlowNET.Examples/TextProcessing/cnn_models/CharCnn.cs b/test/TensorFlowNET.Examples/TextProcessing/cnn_models/CharCnn.cs deleted file mode 100644 index 63132809..00000000 --- a/test/TensorFlowNET.Examples/TextProcessing/cnn_models/CharCnn.cs +++ /dev/null @@ -1,147 +0,0 @@ -using Tensorflow; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples.Text -{ - public class CharCnn : ITextModel - { - public CharCnn(int alphabet_size, int document_max_len, int num_class) - { - var learning_rate = 0.001f; - var filter_sizes = new int[] { 7, 7, 3, 3, 3, 3 }; - var num_filters = 256; - var kernel_initializer = tf.truncated_normal_initializer(stddev: 0.05f); - - var x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x"); - var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y"); - var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training"); - var global_step = tf.Variable(0, trainable: false); - var keep_prob = tf.where(is_training, 0.5f, 1.0f); - - var x_one_hot = tf.one_hot(x, alphabet_size); - var x_expanded = tf.expand_dims(x_one_hot, -1); - - // ============= Convolutional Layers ============= - Tensor pool1 = null, pool2 = null; - Tensor conv3 = null, conv4 = null, conv5 = null, conv6 = null; - Tensor h_pool = null; - - tf_with(tf.name_scope("conv-maxpool-1"), delegate - { - var conv1 = tf.layers.conv2d(x_expanded, - filters: num_filters, - kernel_size: new[] { filter_sizes[0], alphabet_size }, - kernel_initializer: kernel_initializer, - activation: tf.nn.relu()); - - pool1 = 
tf.layers.max_pooling2d(conv1, - pool_size: new[] { 3, 1 }, - strides: new[] { 3, 1 }); - pool1 = tf.transpose(pool1, new[] { 0, 1, 3, 2 }); - }); - - tf_with(tf.name_scope("conv-maxpool-2"), delegate - { - var conv2 = tf.layers.conv2d(pool1, - filters: num_filters, - kernel_size: new[] {filter_sizes[1], num_filters }, - kernel_initializer: kernel_initializer, - activation: tf.nn.relu()); - - pool2 = tf.layers.max_pooling2d(conv2, - pool_size: new[] { 3, 1 }, - strides: new[] { 3, 1 }); - pool2 = tf.transpose(pool2, new[] { 0, 1, 3, 2 }); - }); - - tf_with(tf.name_scope("conv-3"), delegate - { - conv3 = tf.layers.conv2d(pool2, - filters: num_filters, - kernel_size: new[] { filter_sizes[2], num_filters }, - kernel_initializer: kernel_initializer, - activation: tf.nn.relu()); - conv3 = tf.transpose(conv3, new[] { 0, 1, 3, 2 }); - }); - - tf_with(tf.name_scope("conv-4"), delegate - { - conv4 = tf.layers.conv2d(conv3, - filters: num_filters, - kernel_size: new[] { filter_sizes[3], num_filters }, - kernel_initializer: kernel_initializer, - activation: tf.nn.relu()); - conv4 = tf.transpose(conv4, new[] { 0, 1, 3, 2 }); - }); - - tf_with(tf.name_scope("conv-5"), delegate - { - conv5 = tf.layers.conv2d(conv4, - filters: num_filters, - kernel_size: new[] { filter_sizes[4], num_filters }, - kernel_initializer: kernel_initializer, - activation: tf.nn.relu()); - conv5 = tf.transpose(conv5, new[] { 0, 1, 3, 2 }); - }); - - tf_with(tf.name_scope("conv-maxpool-6"), delegate - { - conv6 = tf.layers.conv2d(conv5, - filters: num_filters, - kernel_size: new[] { filter_sizes[5], num_filters }, - kernel_initializer: kernel_initializer, - activation: tf.nn.relu()); - - var pool6 = tf.layers.max_pooling2d(conv6, - pool_size: new[] { 3, 1 }, - strides: new[] { 3, 1 }); - pool6 = tf.transpose(pool6, new[] { 0, 2, 1, 3 }); - - h_pool = tf.reshape(pool6, new[] { -1, 34 * num_filters }); - }); - - // ============= Fully Connected Layers ============= - Tensor fc1_out = null, fc2_out = null; - Tensor logits = null; - Tensor predictions = null; - - tf_with(tf.name_scope("fc-1"), delegate - { - fc1_out = tf.layers.dense(h_pool, - 1024, - activation: tf.nn.relu(), - kernel_initializer: kernel_initializer); - }); - - tf_with(tf.name_scope("fc-2"), delegate - { - fc2_out = tf.layers.dense(fc1_out, - 1024, - activation: tf.nn.relu(), - kernel_initializer: kernel_initializer); - }); - - tf_with(tf.name_scope("fc-3"), delegate - { - logits = tf.layers.dense(fc2_out, - num_class, - kernel_initializer: kernel_initializer); - predictions = tf.argmax(logits, -1, output_type: tf.int32); - }); - - tf_with(tf.name_scope("loss"), delegate - { - var y_one_hot = tf.one_hot(y, num_class); - var loss = tf.reduce_mean( - tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot)); - var optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step: global_step); - }); - - tf_with(tf.name_scope("accuracy"), delegate - { - var correct_predictions = tf.equal(predictions, y); - var accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32), name: "accuracy"); - }); - } - } -} diff --git a/test/TensorFlowNET.Examples/TextProcessing/cnn_models/ITextModel.cs b/test/TensorFlowNET.Examples/TextProcessing/cnn_models/ITextModel.cs deleted file mode 100644 index dd5a8704..00000000 --- a/test/TensorFlowNET.Examples/TextProcessing/cnn_models/ITextModel.cs +++ /dev/null @@ -1,6 +0,0 @@ -namespace TensorFlowNET.Examples.Text -{ - interface ITextModel - { - } -} diff --git 
a/test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs b/test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs deleted file mode 100644 index 6150fa90..00000000 --- a/test/TensorFlowNET.Examples/TextProcessing/cnn_models/VdCnn.cs +++ /dev/null @@ -1,171 +0,0 @@ -using System.Collections.Generic; -using System.Linq; -using Tensorflow; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples.Text -{ - public class VdCnn : ITextModel - { - private int embedding_size; - private int[] filter_sizes; - private int[] num_filters; - private int[] num_blocks; - private float learning_rate; - private IInitializer cnn_initializer; - private IInitializer fc_initializer; - public Tensor x { get; private set; } - public Tensor y { get; private set; } - public Tensor is_training { get; private set; } - private RefVariable global_step; - private RefVariable embeddings; - private Tensor x_emb; - private Tensor x_expanded; - private Tensor logits; - private Tensor predictions; - private Tensor loss; - - public VdCnn(int alphabet_size, int document_max_len, int num_class) - { - embedding_size = 16; - filter_sizes = new int[] { 3, 3, 3, 3, 3 }; - num_filters = new int[] { 64, 64, 128, 256, 512 }; - num_blocks = new int[] { 2, 2, 2, 2 }; - learning_rate = 0.001f; - cnn_initializer = tensorflow.keras.initializers.he_normal(); - fc_initializer = tf.truncated_normal_initializer(stddev: 0.05f); - - x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x"); - y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y"); - is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training"); - global_step = tf.Variable(0, trainable: false); - - // Embedding Layer - tf_with(tf.name_scope("embedding"), delegate - { - var init_embeddings = tf.random_uniform(new int[] { alphabet_size, embedding_size }, -1.0f, 1.0f); - embeddings = tf.get_variable("embeddings", initializer: init_embeddings); - x_emb = tf.nn.embedding_lookup(embeddings, x); - x_expanded = tf.expand_dims(x_emb, -1); - }); - - Tensor conv0 = null; - Tensor conv1 = null; - Tensor conv2 = null; - Tensor conv3 = null; - Tensor conv4 = null; - Tensor h_flat = null; - Tensor fc1_out = null; - Tensor fc2_out = null; - - // First Convolution Layer - tf_with(tf.variable_scope("conv-0"), delegate - { - conv0 = tf.layers.conv2d(x_expanded, - filters: num_filters[0], - kernel_size: new int[] { filter_sizes[0], embedding_size }, - kernel_initializer: cnn_initializer, - activation: tf.nn.relu()); - - conv0 = tf.transpose(conv0, new int[] { 0, 1, 3, 2 }); - }); - - tf_with(tf.name_scope("conv-block-1"), delegate { - conv1 = conv_block(conv0, 1); - }); - - tf_with(tf.name_scope("conv-block-2"), delegate { - conv2 = conv_block(conv1, 2); - }); - - tf_with(tf.name_scope("conv-block-3"), delegate { - conv3 = conv_block(conv2, 3); - }); - - tf_with(tf.name_scope("conv-block-4"), delegate - { - conv4 = conv_block(conv3, 4, max_pool: false); - }); - - // ============= k-max Pooling ============= - tf_with(tf.name_scope("k-max-pooling"), delegate - { - var h = tf.transpose(tf.squeeze(conv4, new int[] { -1 }), new int[] { 0, 2, 1 }); - var top_k = tf.nn.top_k(h, k: 8, sorted: false)[0]; - h_flat = tf.reshape(top_k, new int[] { -1, 512 * 8 }); - }); - - // ============= Fully Connected Layers ============= - tf_with(tf.name_scope("fc-1"), scope => - { - fc1_out = tf.layers.dense(h_flat, 2048, activation: tf.nn.relu(), kernel_initializer: fc_initializer); - }); - - tf_with(tf.name_scope("fc-2"), scope => - { - fc2_out = 
tf.layers.dense(fc1_out, 2048, activation: tf.nn.relu(), kernel_initializer: fc_initializer); - }); - - tf_with(tf.name_scope("fc-3"), scope => - { - logits = tf.layers.dense(fc2_out, num_class, activation: null, kernel_initializer: fc_initializer); - predictions = tf.argmax(logits, -1, output_type: tf.int32); - }); - - // ============= Loss and Accuracy ============= - tf_with(tf.name_scope("loss"), delegate - { - var y_one_hot = tf.one_hot(y, num_class); - loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot)); - - var update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) as List; - tf_with(tf.control_dependencies(update_ops.Select(x => (Operation)x).ToArray()), delegate - { - var adam = tf.train.AdamOptimizer(learning_rate); - adam.minimize(loss, global_step: global_step); - }); - }); - } - - private Tensor conv_block(Tensor input, int i, bool max_pool = true) - { - return tf_with(tf.variable_scope($"conv-block-{i}"), delegate - { - Tensor conv = null; - // Two "conv-batch_norm-relu" layers. - foreach (var j in Enumerable.Range(0, 2)) - { - tf_with(tf.variable_scope($"conv-{j}"), delegate - { - // convolution - conv = tf.layers.conv2d( - input, - filters: num_filters[i], - kernel_size: new int[] { filter_sizes[i], num_filters[i - 1] }, - kernel_initializer: cnn_initializer, - activation: null); - // batch normalization - conv = tf.layers.batch_normalization(conv, training: is_training); - // relu - conv = tf.nn.relu(conv); - conv = tf.transpose(conv, new int[] { 0, 1, 3, 2 }); - }); - } - - if (max_pool) - { - // Max pooling - return tf.layers.max_pooling2d( - conv, - pool_size: new int[] { 3, 1 }, - strides: new int[] { 2, 1 }, - padding: "SAME"); - } - else - { - return conv; - } - }); - } - } -} diff --git a/test/TensorFlowNET.Examples/TextProcessing/cnn_models/WordCnn.cs b/test/TensorFlowNET.Examples/TextProcessing/cnn_models/WordCnn.cs deleted file mode 100644 index ef5bc9db..00000000 --- a/test/TensorFlowNET.Examples/TextProcessing/cnn_models/WordCnn.cs +++ /dev/null @@ -1,99 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-******************************************************************************/ - -using System.Collections.Generic; -using Tensorflow; -using static Tensorflow.Binding; - -namespace TensorFlowNET.Examples.Text -{ - public class WordCnn : ITextModel - { - public WordCnn(int vocabulary_size, int document_max_len, int num_class) - { - var embedding_size = 128; - var learning_rate = 0.001f; - var filter_sizes = new int[3, 4, 5]; - var num_filters = 100; - - var x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x"); - var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y"); - var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training"); - var global_step = tf.Variable(0, trainable: false); - var keep_prob = tf.where(is_training, 0.5f, 1.0f); - Tensor x_emb = null; - - tf_with(tf.name_scope("embedding"), scope => - { - var init_embeddings = tf.random_uniform(new int[] { vocabulary_size, embedding_size }); - var embeddings = tf.get_variable("embeddings", initializer: init_embeddings); - x_emb = tf.nn.embedding_lookup(embeddings, x); - x_emb = tf.expand_dims(x_emb, -1); - }); - - var pooled_outputs = new List(); - for (int len = 0; len < filter_sizes.Rank; len++) - { - int filter_size = filter_sizes.GetLength(len); - var conv = tf.layers.conv2d( - x_emb, - filters: num_filters, - kernel_size: new int[] { filter_size, embedding_size }, - strides: new int[] { 1, 1 }, - padding: "VALID", - activation: tf.nn.relu()); - - var pool = tf.layers.max_pooling2d( - conv, - pool_size: new[] { document_max_len - filter_size + 1, 1 }, - strides: new[] { 1, 1 }, - padding: "VALID"); - - pooled_outputs.Add(pool); - } - - var h_pool = tf.concat(pooled_outputs, 3); - var h_pool_flat = tf.reshape(h_pool, new TensorShape(-1, num_filters * filter_sizes.Rank)); - Tensor h_drop = null; - tf_with(tf.name_scope("dropout"), delegate - { - h_drop = tf.nn.dropout(h_pool_flat, keep_prob); - }); - - Tensor logits = null; - Tensor predictions = null; - tf_with(tf.name_scope("output"), delegate - { - logits = tf.layers.dense(h_drop, num_class); - predictions = tf.argmax(logits, -1, output_type: tf.int32); - }); - - tf_with(tf.name_scope("loss"), delegate - { - var sscel = tf.nn.sparse_softmax_cross_entropy_with_logits(logits: logits, labels: y); - var loss = tf.reduce_mean(sscel); - var adam = tf.train.AdamOptimizer(learning_rate); - var optimizer = adam.minimize(loss, global_step: global_step); - }); - - tf_with(tf.name_scope("accuracy"), delegate - { - var correct_predictions = tf.equal(predictions, y); - var accuracy = tf.reduce_mean(tf.cast(correct_predictions, TF_DataType.TF_FLOAT), name: "accuracy"); - }); - } - } -} diff --git a/test/TensorFlowNET.Examples/Utility/ArrayShuffling.cs b/test/TensorFlowNET.Examples/Utility/ArrayShuffling.cs deleted file mode 100644 index 2ae3e2ea..00000000 --- a/test/TensorFlowNET.Examples/Utility/ArrayShuffling.cs +++ /dev/null @@ -1,20 +0,0 @@ -using System; - -namespace TensorFlowNET.Examples.Utility -{ - public static class ArrayShuffling - { - public static T[] Shuffle(this Random rng, T[] array) - { - int n = array.Length; - while (n > 1) - { - int k = rng.Next(n--); - T temp = array[n]; - array[n] = array[k]; - array[k] = temp; - } - return array; - } - } -} diff --git a/test/TensorFlowNET.Examples/Utility/CoNLLDataset.cs b/test/TensorFlowNET.Examples/Utility/CoNLLDataset.cs deleted file mode 100644 index 14b96656..00000000 --- a/test/TensorFlowNET.Examples/Utility/CoNLLDataset.cs +++ /dev/null @@ -1,105 +0,0 @@ -using 
System.Collections.Generic; -using System.IO; -using System.Linq; -using Tensorflow.Estimator; - -namespace TensorFlowNET.Examples.Utility -{ - public class CoNLLDataset - { - static Dictionary vocab_chars; - static Dictionary vocab_words; - static Dictionary vocab_tags; - - HyperParams _hp; - string _path; - - public CoNLLDataset(string path, HyperParams hp) - { - if (vocab_chars == null) - vocab_chars = load_vocab(hp.filepath_chars); - - if (vocab_words == null) - vocab_words = load_vocab(hp.filepath_words); - - if (vocab_tags == null) - vocab_tags = load_vocab(hp.filepath_tags); - - _path = path; - } - - private (int[], int) processing_word(string word) - { - var char_ids = word.ToCharArray().Select(x => vocab_chars[x.ToString()]).ToArray(); - - // 1. preprocess word - if (true) // lowercase - word = word.ToLower(); - if (false) // isdigit - word = "$NUM$"; - - // 2. get id of word - int id = vocab_words.GetValueOrDefault(word, vocab_words["$UNK$"]); - - return (char_ids, id); - } - - private int processing_tag(string word) - { - // 1. preprocess word - if (false) // lowercase - word = word.ToLower(); - if (false) // isdigit - word = "$NUM$"; - - // 2. get id of word - int id = vocab_tags.GetValueOrDefault(word, -1); - - return id; - } - - private Dictionary load_vocab(string filename) - { - var dict = new Dictionary(); - int i = 0; - File.ReadAllLines(filename) - .Select(x => dict[x] = i++) - .Count(); - return dict; - } - - public IEnumerable<((int[], int)[], int[])> GetItems() - { - var lines = File.ReadAllLines(_path); - - int niter = 0; - var words = new List<(int[], int)>(); - var tags = new List(); - - foreach (var l in lines) - { - string line = l.Trim(); - if (string.IsNullOrEmpty(line) || line.StartsWith("-DOCSTART-")) - { - if (words.Count > 0) - { - niter++; - yield return (words.ToArray(), tags.ToArray()); - words.Clear(); - tags.Clear(); - } - } - else - { - var ls = line.Split(' '); - // process word - var word = processing_word(ls[0]); - var tag = processing_tag(ls[1]); - - words.Add(word); - tags.Add(tag); - } - } - } - } -} diff --git a/test/TensorFlowNET.Examples/Utility/Compress.cs b/test/TensorFlowNET.Examples/Utility/Compress.cs deleted file mode 100644 index 95eb0ddf..00000000 --- a/test/TensorFlowNET.Examples/Utility/Compress.cs +++ /dev/null @@ -1,102 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using ICSharpCode.SharpZipLib.Core; -using ICSharpCode.SharpZipLib.GZip; -using ICSharpCode.SharpZipLib.Tar; -using System; -using System.IO; -using System.IO.Compression; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; - -namespace TensorFlowNET.Examples.Utility -{ - public class Compress - { - public static void ExtractGZip(string gzipFileName, string targetDir) - { - // Use a 4K buffer. Any larger is a waste. 
- byte[] dataBuffer = new byte[4096]; - - using (System.IO.Stream fs = new FileStream(gzipFileName, FileMode.Open, FileAccess.Read)) - { - using (GZipInputStream gzipStream = new GZipInputStream(fs)) - { - // Change this to your needs - string fnOut = Path.Combine(targetDir, Path.GetFileNameWithoutExtension(gzipFileName)); - - using (FileStream fsOut = File.Create(fnOut)) - { - StreamUtils.Copy(gzipStream, fsOut, dataBuffer); - } - } - } - } - - public static void UnZip(String gzArchiveName, String destFolder) - { - var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin"; - if (File.Exists(Path.Combine(destFolder, flag))) return; - - Console.WriteLine($"Extracting."); - var task = Task.Run(() => - { - ZipFile.ExtractToDirectory(gzArchiveName, destFolder); - }); - - while (!task.IsCompleted) - { - Thread.Sleep(200); - Console.Write("."); - } - - File.Create(Path.Combine(destFolder, flag)); - Console.WriteLine(""); - Console.WriteLine("Extracting is completed."); - } - - public static void ExtractTGZ(String gzArchiveName, String destFolder) - { - var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin"; - if (File.Exists(Path.Combine(destFolder, flag))) return; - - Console.WriteLine($"Extracting."); - var task = Task.Run(() => - { - using (var inStream = File.OpenRead(gzArchiveName)) - { - using (var gzipStream = new GZipInputStream(inStream)) - { - using (TarArchive tarArchive = TarArchive.CreateInputTarArchive(gzipStream)) - tarArchive.ExtractContents(destFolder); - } - } - }); - - while (!task.IsCompleted) - { - Thread.Sleep(200); - Console.Write("."); - } - - File.Create(Path.Combine(destFolder, flag)); - Console.WriteLine(""); - Console.WriteLine("Extracting is completed."); - } - } -} diff --git a/test/TensorFlowNET.Examples/Utility/PbtxtParser.cs b/test/TensorFlowNET.Examples/Utility/PbtxtParser.cs deleted file mode 100644 index 4a918017..00000000 --- a/test/TensorFlowNET.Examples/Utility/PbtxtParser.cs +++ /dev/null @@ -1,81 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-******************************************************************************/ - -using Newtonsoft.Json; -using System.Collections.Generic; - -namespace TensorFlowNET.Examples.Utility -{ - public class PbtxtItem - { - public string name { get; set; } - public int id { get; set; } - public string display_name { get; set; } - } - public class PbtxtItems - { - public List items { get; set; } - } - - public class PbtxtParser - { - public static PbtxtItems ParsePbtxtFile(string filePath) - { - string line; - string newText = "{\"items\":["; - - using (System.IO.StreamReader reader = new System.IO.StreamReader(filePath)) - { - - while ((line = reader.ReadLine()) != null) - { - string newline = string.Empty; - - if (line.Contains("{")) - { - newline = line.Replace("item", "").Trim(); - //newText += line.Insert(line.IndexOf("=") + 1, "\"") + "\","; - newText += newline; - } - else if (line.Contains("}")) - { - newText = newText.Remove(newText.Length - 1); - newText += line; - newText += ","; - } - else - { - newline = line.Replace(":", "\":").Trim(); - newline = "\"" + newline;// newline.Insert(0, "\""); - newline += ","; - - newText += newline; - } - - } - - newText = newText.Remove(newText.Length - 1); - newText += "]}"; - - reader.Close(); - } - - PbtxtItems items = JsonConvert.DeserializeObject(newText); - - return items; - } - } -} diff --git a/test/TensorFlowNET.Examples/Utility/Web.cs b/test/TensorFlowNET.Examples/Utility/Web.cs deleted file mode 100644 index 8f300167..00000000 --- a/test/TensorFlowNET.Examples/Utility/Web.cs +++ /dev/null @@ -1,57 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-******************************************************************************/ - -using System; -using System.IO; -using System.Linq; -using System.Net; -using System.Threading; -using System.Threading.Tasks; - -namespace TensorFlowNET.Examples.Utility -{ - public class Web - { - public static bool Download(string url, string destDir, string destFileName) - { - if (destFileName == null) - destFileName = url.Split(Path.DirectorySeparatorChar).Last(); - - Directory.CreateDirectory(destDir); - - string relativeFilePath = Path.Combine(destDir, destFileName); - - if (File.Exists(relativeFilePath)) - { - Console.WriteLine($"{relativeFilePath} already exists."); - return false; - } - - var wc = new WebClient(); - Console.WriteLine($"Downloading {relativeFilePath}"); - var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath)); - while (!download.IsCompleted) - { - Thread.Sleep(1000); - Console.Write("."); - } - Console.WriteLine(""); - Console.WriteLine($"Downloaded {relativeFilePath}"); - - return true; - } - } -} diff --git a/test/TensorFlowNET.Examples/python/binary_text_classification.py b/test/TensorFlowNET.Examples/python/binary_text_classification.py deleted file mode 100644 index f783327c..00000000 --- a/test/TensorFlowNET.Examples/python/binary_text_classification.py +++ /dev/null @@ -1,115 +0,0 @@ - -from __future__ import absolute_import, division, print_function - -import tensorflow as tf -from tensorflow import keras - -import numpy as np - -print(tf.__version__) - -imdb = keras.datasets.imdb - -(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000) - -print("Training entries: {}, labels: {}".format(len(train_data), len(train_labels))) -print(train_data[0]) -len(train_data[0]), len(train_data[1]) - -# A dictionary mapping words to an integer index -word_index = imdb.get_word_index() - -# The first indices are reserved -word_index = {k:(v+3) for k,v in word_index.items()} -word_index["<PAD>"] = 0 -word_index["<START>"] = 1 -word_index["<UNK>"] = 2 # unknown -word_index["<UNUSED>"] = 3 - -reverse_word_index = dict([(value, key) for (key, value) in word_index.items()]) - -def decode_review(text): - return ' '.join([reverse_word_index.get(i, '?') for i in text]) - -decode_review(train_data[0]) - - -train_data = keras.preprocessing.sequence.pad_sequences(train_data, - value=word_index["<PAD>"], - padding='post', - maxlen=256) - -test_data = keras.preprocessing.sequence.pad_sequences(test_data, - value=word_index["<PAD>"], - padding='post', - maxlen=256) - - -print(train_data[0]) - -# input shape is the vocabulary count used for the movie reviews (10,000 words) -vocab_size = 10000 - -model = keras.Sequential() -model.add(keras.layers.Embedding(vocab_size, 16)) -model.add(keras.layers.GlobalAveragePooling1D()) -model.add(keras.layers.Dense(16, activation=tf.nn.relu)) -model.add(keras.layers.Dense(1, activation=tf.nn.sigmoid)) - -model.summary() - -model.compile(optimizer='adam', - loss='binary_crossentropy', - metrics=['accuracy']) - - -x_val = train_data[:10000] -partial_x_train = train_data[10000:] - -y_val = train_labels[:10000] -partial_y_train = train_labels[10000:] - -history = model.fit(partial_x_train, - partial_y_train, - epochs=20, - batch_size=512, - validation_data=(x_val, y_val), - verbose=1) - -results = model.evaluate(test_data, test_labels) -print(results) - -history_dict = history.history -history_dict.keys() - -import matplotlib.pyplot as plt - -acc = history_dict['acc'] -val_acc = history_dict['val_acc'] -loss = history_dict['loss'] -val_loss = 
history_dict['val_loss'] - -epochs = range(1, len(acc) + 1) - -# "bo" is for "blue dot" -plt.plot(epochs, loss, 'bo', label='Training loss') -# b is for "solid blue line" -plt.plot(epochs, val_loss, 'b', label='Validation loss') -plt.title('Training and validation loss') -plt.xlabel('Epochs') -plt.ylabel('Loss') -plt.legend() - -plt.show() - - -plt.clf() # clear figure - -plt.plot(epochs, acc, 'bo', label='Training acc') -plt.plot(epochs, val_acc, 'b', label='Validation acc') -plt.title('Training and validation accuracy') -plt.xlabel('Epochs') -plt.ylabel('Accuracy') -plt.legend() - -plt.show() \ No newline at end of file diff --git a/test/TensorFlowNET.Examples/python/linear_regression.py b/test/TensorFlowNET.Examples/python/linear_regression.py deleted file mode 100644 index c4bb00f0..00000000 --- a/test/TensorFlowNET.Examples/python/linear_regression.py +++ /dev/null @@ -1,107 +0,0 @@ -''' -A linear regression learning algorithm example using TensorFlow library. -Author: Aymeric Damien -Project: https://github.com/aymericdamien/TensorFlow-Examples/ -''' - -from __future__ import print_function - -import tensorflow as tf -import numpy -import matplotlib.pyplot as plt -rng = numpy.random - -# Parameters -learning_rate = 0.01 -training_epochs = 1000 -display_step = 10 - -# Training Data -train_X = numpy.asarray([3.3,4.4,5.5,6.71,6.93,4.168,9.779,6.182,7.59,2.167, - 7.042,10.791,5.313,7.997,5.654,9.27,3.1]) -train_Y = numpy.asarray([1.7,2.76,2.09,3.19,1.694,1.573,3.366,2.596,2.53,1.221, - 2.827,3.465,1.65,2.904,2.42,2.94,1.3]) -n_samples = train_X.shape[0] - -if False: - # tf Graph Input - X = tf.placeholder("float") - Y = tf.placeholder("float") - - # Set model weights - W = tf.Variable(-0.06, name="weight") - b = tf.Variable(-0.73, name="bias") - - # Construct a linear model - mul = tf.multiply(X, W) - pred = tf.add(mul, b) - - # Mean squared error - sub = pred-Y - pow = tf.pow(sub, 2) - - reduce = tf.reduce_sum(pow) - cost = reduce/(2*n_samples) - # Gradient descent - # Note, minimize() knows to modify W and b because Variable objects are trainable=True by default - grad = tf.train.GradientDescentOptimizer(learning_rate) - optimizer = grad.minimize(cost) - # tf.train.export_meta_graph(filename='save_model.meta'); -else: - # tf Graph Input - new_saver = tf.train.import_meta_graph("linear_regression.meta") - nodes = tf.get_default_graph()._nodes_by_name; - optimizer = nodes["GradientDescent"] - cost = nodes["truediv"].outputs[0] - X = nodes["Placeholder"].outputs[0] - Y = nodes["Placeholder_1"].outputs[0] - W = nodes["weight"].outputs[0] - b = nodes["bias"].outputs[0] - pred = nodes["Add"].outputs[0] - -# Initialize the variables (i.e. 
assign their default value) -init = tf.global_variables_initializer() - -# Start training -with tf.Session() as sess: - - # Run the initializer - sess.run(init) - - # Fit all training data - for epoch in range(training_epochs): - for (x, y) in zip(train_X, train_Y): - sess.run(optimizer, feed_dict={X: x, Y: y}) - - # Display logs per epoch step - if (epoch+1) % display_step == 0: - c = sess.run(cost, feed_dict={X: train_X, Y:train_Y}) - print("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(c), \ - "W=", sess.run(W), "b=", sess.run(b)) - - print("Optimization Finished!") - training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y}) - print("Training cost=", training_cost, "W=", sess.run(W), "b=", sess.run(b), '\n') - - # Graphic display - plt.plot(train_X, train_Y, 'ro', label='Original data') - plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line') - plt.legend() - plt.show() - - # Testing example, as requested (Issue #2) - test_X = numpy.asarray([6.83, 4.668, 8.9, 7.91, 5.7, 8.7, 3.1, 2.1]) - test_Y = numpy.asarray([1.84, 2.273, 3.2, 2.831, 2.92, 3.24, 1.35, 1.03]) - - print("Testing... (Mean square loss Comparison)") - testing_cost = sess.run( - tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * test_X.shape[0]), - feed_dict={X: test_X, Y: test_Y}) # same function as cost above - print("Testing cost=", testing_cost) - print("Absolute mean square loss difference:", abs( - training_cost - testing_cost)) - - plt.plot(test_X, test_Y, 'bo', label='Testing data') - plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line') - plt.legend() - plt.show() \ No newline at end of file diff --git a/test/TensorFlowNET.Examples/python/logistic_regression.py b/test/TensorFlowNET.Examples/python/logistic_regression.py deleted file mode 100644 index 236d83d1..00000000 --- a/test/TensorFlowNET.Examples/python/logistic_regression.py +++ /dev/null @@ -1,100 +0,0 @@ -''' -A logistic regression learning algorithm example using TensorFlow library. -This example is using the MNIST database of handwritten digits -(http://yann.lecun.com/exdb/mnist/) -Author: Aymeric Damien -Project: https://github.com/aymericdamien/TensorFlow-Examples/ -''' - -from __future__ import print_function - -import tensorflow as tf - -# Import MNIST data -from tensorflow.examples.tutorials.mnist import input_data -mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) - -# Parameters -learning_rate = 0.01 -training_epochs = 10 -batch_size = 100 -display_step = 1 - -# tf Graph Input -x = tf.placeholder(tf.float32, [None, 784]) # mnist data image of shape 28*28=784 -y = tf.placeholder(tf.float32, [None, 10]) # 0-9 digits recognition => 10 classes - -# Set model weights -W = tf.Variable(tf.zeros([784, 10])) -b = tf.Variable(tf.zeros([10])) - -# Construct model -pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax - -# Minimize error using cross entropy -cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1)) -# Gradient Descent -optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) - -# Initialize the variables (i.e. assign their default value) -init = tf.global_variables_initializer() - -# Start training -with tf.Session() as sess: - - # Run the initializer - sess.run(init) - - # Training cycle - for epoch in range(training_epochs): - avg_cost = 0. 
- total_batch = int(mnist.train.num_examples/batch_size) - # Loop over all batches - for i in range(total_batch): - batch_xs, batch_ys = mnist.train.next_batch(batch_size) - # Run optimization op (backprop) and cost op (to get loss value) - _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, - y: batch_ys}) - # Compute average loss - avg_cost += c / total_batch - # Display logs per epoch step - if (epoch+1) % display_step == 0: - print("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost)) - - print("Optimization Finished!") - - # Test model - correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)) - # Calculate accuracy - accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) - print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels})) - - # predict - # results = sess.run(pred, feed_dict={x: batch_xs[:1]}) - - # save model - saver = tf.train.Saver() - save_path = saver.save(sess, "logistic_regression/model.ckpt") - tf.train.write_graph(sess.graph.as_graph_def(),'logistic_regression','model.pbtxt', as_text=True) - - freeze_graph.freeze_graph(input_graph = 'logistic_regression/model.pbtxt', - input_saver = "", - input_binary = False, - input_checkpoint = 'logistic_regression/model.ckpt', - output_node_names = "Softmax", - restore_op_name = "save/restore_all", - filename_tensor_name = "save/Const:0", - output_graph = 'logistic_regression/model.pb', - clear_devices = True, - initializer_nodes = "") - - # restoring the model - saver = tf.train.import_meta_graph('logistic_regression/tensorflowModel.ckpt.meta') - saver.restore(sess,tf.train.latest_checkpoint('logistic_regression')) - - # predict - # pred = graph._nodes_by_name["Softmax"] - # output = pred.outputs[0] - # x = graph._nodes_by_name["Placeholder"] - # input = x.outputs[0] - # results = sess.run(output, feed_dict={input: batch_xs[:1]}) \ No newline at end of file diff --git a/test/TensorFlowNET.Examples/python/meta_graph.py b/test/TensorFlowNET.Examples/python/meta_graph.py deleted file mode 100644 index cb426091..00000000 --- a/test/TensorFlowNET.Examples/python/meta_graph.py +++ /dev/null @@ -1,67 +0,0 @@ - -import tensorflow as tf -import math - -# Creates an inference graph. -# Hidden 1 -images = tf.constant(1.2, tf.float32, shape=[100, 28]) -with tf.name_scope("hidden1"): - weights = tf.Variable( - tf.truncated_normal([28, 128], - stddev=1.0 / math.sqrt(float(28))), - name="weights") - biases = tf.Variable(tf.zeros([128]), - name="biases") - hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases) -# Hidden 2 -with tf.name_scope("hidden2"): - weights = tf.Variable( - tf.truncated_normal([128, 32], - stddev=1.0 / math.sqrt(float(128))), - name="weights") - biases = tf.Variable(tf.zeros([32]), - name="biases") - hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases) -# Linear -with tf.name_scope("softmax_linear"): - weights = tf.Variable( - tf.truncated_normal([32, 10], - stddev=1.0 / math.sqrt(float(32))), - name="weights") - biases = tf.Variable(tf.zeros([10]), - name="biases") - logits = tf.matmul(hidden2, weights) + biases - tf.add_to_collection("logits", logits) - -init_all_op = tf.global_variables_initializer() - -with tf.Session() as sess: - # Initializes all the variables. - sess.run(init_all_op) - # Runs to logit. - sess.run(logits) - # Creates a saver. - saver0 = tf.train.Saver() - saver0.save(sess, 'my-save-dir/my-model-10000') - # Generates MetaGraphDef. 
- saver0.export_meta_graph('my-save-dir/my-model-10000.meta') - - -# Then later import it and extend it to a training graph. -with tf.Session() as sess: - new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta') - new_saver.restore(sess, 'my-save-dir/my-model-10000') - # Addes loss and train. - labels = tf.constant(0, tf.int32, shape=[100], name="labels") - batch_size = tf.size(labels) - logits = tf.get_collection("logits")[0] - loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, - logits=logits) - - tf.summary.scalar('loss', loss) - # Creates the gradient descent optimizer with the given learning rate. - optimizer = tf.train.GradientDescentOptimizer(0.01) - - # Runs train_op. - train_op = optimizer.minimize(loss) - sess.run(train_op) diff --git a/test/TensorFlowNET.Examples/python/minst_lstm.py b/test/TensorFlowNET.Examples/python/minst_lstm.py deleted file mode 100644 index a58d7fdf..00000000 --- a/test/TensorFlowNET.Examples/python/minst_lstm.py +++ /dev/null @@ -1,78 +0,0 @@ -import tensorflow as tf -from tensorflow.contrib import rnn - -#import mnist dataset -from tensorflow.examples.tutorials.mnist import input_data -mnist=input_data.read_data_sets("/tmp/data/",one_hot=True) - -#define constants -#unrolled through 28 time steps -time_steps=28 -#hidden LSTM units -num_units=128 -#rows of 28 pixels -n_input=28 -#learning rate for adam -learning_rate=0.001 -#mnist is meant to be classified in 10 classes(0-9). -n_classes=10 -#size of batch -batch_size=128 - - -#weights and biases of appropriate shape to accomplish above task -out_weights=tf.Variable(tf.random_normal([num_units,n_classes])) -out_bias=tf.Variable(tf.random_normal([n_classes])) - -#defining placeholders -#input image placeholder -x=tf.placeholder("float",[None,time_steps,n_input]) -#input label placeholder -y=tf.placeholder("float",[None,n_classes]) - -#processing the input tensor from [batch_size,n_steps,n_input] to "time_steps" number of [batch_size,n_input] tensors -input=tf.unstack(x ,time_steps,1) - -#defining the network -lstm_layer=rnn.BasicLSTMCell(num_units,forget_bias=1) -outputs,_=rnn.static_rnn(lstm_layer,input,dtype="float32") - -#converting last output of dimension [batch_size,num_units] to [batch_size,n_classes] by out_weight multiplication -prediction=tf.matmul(outputs[-1],out_weights)+out_bias - -#loss_function -loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y)) -#optimization -opt=tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss) - -#model evaluation -correct_prediction=tf.equal(tf.argmax(prediction,1),tf.argmax(y,1)) -accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) - -#initialize variables -init=tf.global_variables_initializer() -with tf.Session() as sess: - sess.run(init) - iter=1 - while iter<800: - batch_x,batch_y=mnist.train.next_batch(batch_size=batch_size) - - batch_x=batch_x.reshape((batch_size,time_steps,n_input)) - - sess.run(opt, feed_dict={x: batch_x, y: batch_y}) - - if iter %10==0: - acc=sess.run(accuracy,feed_dict={x:batch_x,y:batch_y}) - los=sess.run(loss,feed_dict={x:batch_x,y:batch_y}) - print("For iter ",iter) - print("Accuracy ",acc) - print("Loss ",los) - print("__________________") - - iter=iter+1 - - #calculating test accuracy - test_data = mnist.test.images[:128].reshape((-1, time_steps, n_input)) - test_label = mnist.test.labels[:128] - print("Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label})) - diff --git a/test/TensorFlowNET.Examples/python/minst_rnn.py 
b/test/TensorFlowNET.Examples/python/minst_rnn.py deleted file mode 100644 index c6218776..00000000 --- a/test/TensorFlowNET.Examples/python/minst_rnn.py +++ /dev/null @@ -1,48 +0,0 @@ -import tensorflow as tf - -# hyperparameters -n_neurons = 128 -learning_rate = 0.001 -batch_size = 128 -n_epochs = 10 -# parameters -n_steps = 28 # 28 rows -n_inputs = 28 # 28 cols -n_outputs = 10 # 10 classes -# build a rnn model -X = tf.placeholder(tf.float32, [None, n_steps, n_inputs]) -y = tf.placeholder(tf.int32, [None]) -cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons) -output, state = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32) -logits = tf.layers.dense(state, n_outputs) -cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits) -loss = tf.reduce_mean(cross_entropy) -optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss) -prediction = tf.nn.in_top_k(logits, y, 1) -accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32)) - -# input data -from tensorflow.examples.tutorials.mnist import input_data -mnist = input_data.read_data_sets("MNIST_data/") -X_test = mnist.test.images # X_test shape: [num_test, 28*28] -X_test = X_test.reshape([-1, n_steps, n_inputs]) -y_test = mnist.test.labels - -# initialize the variables -init = tf.global_variables_initializer() -# train the model -with tf.Session() as sess: - sess.run(init) - n_batches = mnist.train.num_examples // batch_size - for epoch in range(n_epochs): - for batch in range(n_batches): - X_train, y_train = mnist.train.next_batch(batch_size) - X_train = X_train.reshape([-1, n_steps, n_inputs]) - sess.run(optimizer, feed_dict={X: X_train, y: y_train}) - loss_train, acc_train = sess.run( - [loss, accuracy], feed_dict={X: X_train, y: y_train}) - print('Epoch: {}, Train Loss: {:.3f}, Train Acc: {:.3f}'.format( - epoch + 1, loss_train, acc_train)) - loss_test, acc_test = sess.run( - [loss, accuracy], feed_dict={X: X_test, y: y_test}) - print('Test Loss: {:.3f}, Test Acc: {:.3f}'.format(loss_test, acc_test)) \ No newline at end of file diff --git a/test/TensorFlowNET.Examples/python/neural_network.py b/test/TensorFlowNET.Examples/python/neural_network.py deleted file mode 100644 index ac9597ca..00000000 --- a/test/TensorFlowNET.Examples/python/neural_network.py +++ /dev/null @@ -1,164 +0,0 @@ - -# imports -import tensorflow as tf -import numpy as np -import matplotlib.pyplot as plt - -img_h = img_w = 28 # MNIST images are 28x28 -img_size_flat = img_h * img_w # 28x28=784, the total number of pixels -n_classes = 10 # Number of classes, one class per digit - -def load_data(mode='train'): - """ - Function to (download and) load the MNIST data - :param mode: train or test - :return: images and the corresponding labels - """ - from tensorflow.examples.tutorials.mnist import input_data - mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) - if mode == 'train': - x_train, y_train, x_valid, y_valid = mnist.train.images, mnist.train.labels, \ - mnist.validation.images, mnist.validation.labels - return x_train, y_train, x_valid, y_valid - elif mode == 'test': - x_test, y_test = mnist.test.images, mnist.test.labels - return x_test, y_test - -def randomize(x, y): - """ Randomizes the order of data samples and their corresponding labels""" - permutation = np.random.permutation(y.shape[0]) - shuffled_x = x[permutation, :] - shuffled_y = y[permutation] - return shuffled_x, shuffled_y - -def get_next_batch(x, y, start, end): - x_batch = x[start:end] - y_batch = y[start:end] - return x_batch, y_batch - -# 
Load MNIST data -x_train, y_train, x_valid, y_valid = load_data(mode='train') -print("Size of:") -print("- Training-set:\t\t{}".format(len(y_train))) -print("- Validation-set:\t{}".format(len(y_valid))) - -print('x_train:\t{}'.format(x_train.shape)) -print('y_train:\t{}'.format(y_train.shape)) -print('x_train:\t{}'.format(x_valid.shape)) -print('y_valid:\t{}'.format(y_valid.shape)) - -print(y_valid[:5, :]) - -# Hyper-parameters -epochs = 10 # Total number of training epochs -batch_size = 100 # Training batch size -display_freq = 100 # Frequency of displaying the training results -learning_rate = 0.001 # The optimization initial learning rate - -h1 = 200 # number of nodes in the 1st hidden layer - -# weight and bais wrappers -def weight_variable(name, shape): - """ - Create a weight variable with appropriate initialization - :param name: weight name - :param shape: weight shape - :return: initialized weight variable - """ - initer = tf.truncated_normal_initializer(stddev=0.01) - return tf.get_variable('W_' + name, - dtype=tf.float32, - shape=shape, - initializer=initer) - - -def bias_variable(name, shape): - """ - Create a bias variable with appropriate initialization - :param name: bias variable name - :param shape: bias variable shape - :return: initialized bias variable - """ - initial = tf.constant(0., shape=shape, dtype=tf.float32) - return tf.get_variable('b_' + name, - dtype=tf.float32, - initializer=initial) - -def fc_layer(x, num_units, name, use_relu=True): - """ - Create a fully-connected layer - :param x: input from previous layer - :param num_units: number of hidden units in the fully-connected layer - :param name: layer name - :param use_relu: boolean to add ReLU non-linearity (or not) - :return: The output array - """ - in_dim = x.get_shape()[1] - W = weight_variable(name, shape=[in_dim, num_units]) - b = bias_variable(name, [num_units]) - layer = tf.matmul(x, W) - layer += b - if use_relu: - layer = tf.nn.relu(layer) - return layer - -# Create the graph for the linear model -# Placeholders for inputs (x) and outputs(y) -x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='X') -y = tf.placeholder(tf.float32, shape=[None, n_classes], name='Y') - -# Create a fully-connected layer with h1 nodes as hidden layer -fc1 = fc_layer(x, h1, 'FC1', use_relu=True) -# Create a fully-connected layer with n_classes nodes as output layer -output_logits = fc_layer(fc1, n_classes, 'OUT', use_relu=False) - -# Define the loss function, optimizer, and accuracy -logits = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=output_logits) -loss = tf.reduce_mean(logits, name='loss') -optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, name='Adam-op').minimize(loss) -correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name='correct_pred') -accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy') - -# Network predictions -cls_prediction = tf.argmax(output_logits, axis=1, name='predictions') - -# export graph -#tf.train.export_meta_graph(filename='neural_network.meta', graph=tf.get_default_graph(), clear_extraneous_savers= True, as_text = True) - -# Create the op for initializing all variables -init = tf.global_variables_initializer() - -# Create an interactive session (to keep the session in the other cells) -sess = tf.InteractiveSession() -# Initialize all variables -sess.run(init) -# Number of training iterations in each epoch -num_tr_iter = int(len(y_train) / batch_size) -for epoch in range(epochs): - print('Training 
epoch: {}'.format(epoch + 1)) - # Randomly shuffle the training data at the beginning of each epoch - x_train, y_train = randomize(x_train, y_train) - for iteration in range(num_tr_iter): - start = iteration * batch_size - end = (iteration + 1) * batch_size - x_batch, y_batch = get_next_batch(x_train, y_train, start, end) - - # Run optimization op (backprop) - feed_dict_batch = {x: x_batch, y: y_batch} - sess.run(optimizer, feed_dict=feed_dict_batch) - - if iteration % display_freq == 0: - # Calculate and display the batch loss and accuracy - loss_batch, acc_batch = sess.run([loss, accuracy], - feed_dict=feed_dict_batch) - - print("iter {0:3d}:\t Loss={1:.2f},\tTraining Accuracy={2:.01%}". - format(iteration, loss_batch, acc_batch)) - - # Run validation after every epoch - feed_dict_valid = {x: x_valid[:1000], y: y_valid[:1000]} - loss_valid, acc_valid = sess.run([loss, accuracy], feed_dict=feed_dict_valid) - print('---------------------------------------------------------') - print("Epoch: {0}, validation loss: {1:.2f}, validation accuracy: {2:.01%}". - format(epoch + 1, loss_valid, acc_valid)) - print('---------------------------------------------------------') \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs b/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs index 6c593929..0b09f783 100644 --- a/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs +++ b/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs @@ -33,5 +33,22 @@ namespace TensorFlowNET.UnitTest.Basics } } } + + [TestMethod] + public void Bug397() + { + // fix bug https://github.com/SciSharp/TensorFlow.NET/issues/397 + var W = tf.Variable(-1, name: "weight_" + 1, dtype: tf.float32); + var init = tf.global_variables_initializer(); + var reluEval = tf.nn.relu(W); + var nonZero = tf.assign(W, reluEval); + + using (var sess = tf.Session()) + { + sess.run(init); + float result = nonZero.eval(); + Assert.IsTrue(result == 0f); + } + } } } \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs b/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs index 58609c17..007b5624 100644 --- a/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs +++ b/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs @@ -11,7 +11,7 @@ namespace TensorFlowNET.UnitTest /// tensorflow\c\c_api_test.cc /// `class CApiGradientsTest` /// - [TestClass] + [TestClass, Ignore] public class CApiGradientsTest : CApiTest, IDisposable { private Graph graph_ = new Graph(); diff --git a/test/TensorFlowNET.UnitTest/CSession.cs b/test/TensorFlowNET.UnitTest/CSession.cs index ae57b075..fa293288 100644 --- a/test/TensorFlowNET.UnitTest/CSession.cs +++ b/test/TensorFlowNET.UnitTest/CSession.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow; +using Tensorflow.Util; namespace TensorFlowNET.UnitTest { @@ -22,9 +23,12 @@ namespace TensorFlowNET.UnitTest public CSession(Graph graph, Status s, bool user_XLA = false) { - var opts = new SessionOptions(); - opts.SetConfig(new ConfigProto { InterOpParallelismThreads = 4 }); - session_ = new Session(graph, opts, s); + lock (Locks.ProcessWide) + { + var opts = new SessionOptions(); + opts.SetConfig(new ConfigProto {InterOpParallelismThreads = 4}); + session_ = new Session(graph, opts, s); + } } public void SetInputs(Dictionary inputs) @@ -64,13 +68,13 @@ namespace TensorFlowNET.UnitTest public unsafe void Run(Status s) { var inputs_ptr = inputs_.ToArray(); - var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray(); + var 
input_values_ptr = input_values_.Select(x => (IntPtr) x).ToArray(); var outputs_ptr = outputs_.ToArray(); var output_values_ptr = output_values_.Select(x => IntPtr.Zero).ToArray(); IntPtr[] targets_ptr = new IntPtr[0]; c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, - outputs_ptr, output_values_ptr, outputs_.Count, + outputs_ptr, output_values_ptr, outputs_.Count, targets_ptr, targets_.Count, IntPtr.Zero, s); @@ -90,4 +94,4 @@ namespace TensorFlowNET.UnitTest ResetOutputValues(); } } -} +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Estimators/RunConfigTest.cs b/test/TensorFlowNET.UnitTest/Estimators/RunConfigTest.cs new file mode 100644 index 00000000..66cb48e3 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Estimators/RunConfigTest.cs @@ -0,0 +1,64 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Tensorflow; +using Tensorflow.Eager; +using Tensorflow.Estimators; + +namespace TensorFlowNET.UnitTest.Estimators +{ + /// <summary> + /// estimator/tensorflow_estimator/python/estimator/run_config_test.py + /// </summary> + [TestClass] + public class RunConfigTest + { + private static readonly string _TEST_DIR = "test_dir"; + private static readonly string _MASTER = "master_"; + private static readonly string _NOT_SUPPORTED_REPLACE_PROPERTY_MSG = "Replacing .*is not supported"; + private static readonly string _SAVE_CKPT_ERR = "`save_checkpoints_steps` and `save_checkpoints_secs` cannot be both set."; + private static readonly string _MODEL_DIR_ERR = "model_dir should be non-empty"; + private static readonly string _MODEL_DIR_TF_CONFIG_ERR = "model_dir in TF_CONFIG should be non-empty"; + private static readonly string _MODEL_DIR_MISMATCH_ERR = "`model_dir` provided in RunConfig construct, if set, must have the same value as the model_dir in TF_CONFIG. 
"; + private static readonly string _SAVE_SUMMARY_STEPS_ERR = "save_summary_steps should be >= 0"; + private static readonly string _SAVE_CKPT_STEPS_ERR = "save_checkpoints_steps should be >= 0"; + private static readonly string _SAVE_CKPT_SECS_ERR = "save_checkpoints_secs should be >= 0"; + private static readonly string _SESSION_CONFIG_ERR = "session_config must be instance of ConfigProto"; + private static readonly string _KEEP_CKPT_MAX_ERR = "keep_checkpoint_max should be >= 0"; + private static readonly string _KEEP_CKPT_HOURS_ERR = "keep_checkpoint_every_n_hours should be > 0"; + private static readonly string _TF_RANDOM_SEED_ERR = "tf_random_seed must be integer"; + private static readonly string _DEVICE_FN_ERR = "device_fn must be callable with exactly one argument \"op\"."; + private static readonly string _ONE_CHIEF_ERR = "The \"cluster\" in TF_CONFIG must have only one \"chief\" node."; + private static readonly string _ONE_MASTER_ERR = "The \"cluster\" in TF_CONFIG must have only one \"master\" node."; + private static readonly string _MISSING_CHIEF_ERR = "If \"cluster\" is set .* it must have one \"chief\" node"; + private static readonly string _MISSING_TASK_TYPE_ERR = "If \"cluster\" is set .* task type must be set"; + private static readonly string _MISSING_TASK_ID_ERR = "If \"cluster\" is set .* task index must be set"; + private static readonly string _INVALID_TASK_INDEX_ERR = "is not a valid task_id"; + private static readonly string _NEGATIVE_TASK_INDEX_ERR = "Task index must be non-negative number."; + private static readonly string _INVALID_TASK_TYPE_ERR = "is not a valid task_type"; + private static readonly string _INVALID_TASK_TYPE_FOR_LOCAL_ERR = "If \"cluster\" is not set in TF_CONFIG, task type must be WORKER."; + private static readonly string _INVALID_TASK_INDEX_FOR_LOCAL_ERR = "If \"cluster\" is not set in TF_CONFIG, task index must be 0."; + private static readonly string _INVALID_EVALUATOR_IN_CLUSTER_WITH_MASTER_ERR = "If `master` node exists in `cluster`, task_type `evaluator` is not supported."; + private static readonly string _INVALID_CHIEF_IN_CLUSTER_WITH_MASTER_ERR = "If `master` node exists in `cluster`, job `chief` is not supported."; + private static readonly string _INVALID_SERVICE_TYPE_ERR = "If \"service\" is set in TF_CONFIG, it must be a dict. 
Given"; + private static readonly string _EXPERIMENTAL_MAX_WORKER_DELAY_SECS_ERR = "experimental_max_worker_delay_secs must be an integer if set."; + private static readonly string _SESSION_CREATION_TIMEOUT_SECS_ERR = "session_creation_timeout_secs should be > 0"; + + [TestMethod] + public void test_default_property_values() + { + var config = new RunConfig(); + + Assert.IsNull(config.model_dir); + Assert.IsNull(config.session_config); + Assert.IsNull(config.tf_random_seed); + Assert.AreEqual(100, config.save_summary_steps); + Assert.AreEqual(600, config.save_checkpoints_secs); + Assert.AreEqual(5, config.keep_checkpoint_max); + Assert.AreEqual(10000, config.keep_checkpoint_every_n_hours); + Assert.IsNull(config.service); + Assert.IsNull(config.device_fn); + Assert.IsNull(config.experimental_max_worker_delay_secs); + Assert.AreEqual(7200, config.session_creation_timeout_secs); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs deleted file mode 100644 index c980692c..00000000 --- a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs +++ /dev/null @@ -1,126 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using Tensorflow; -using TensorFlowNET.Examples; -using static Tensorflow.Binding; - -namespace TensorFlowNET.ExamplesTests -{ - [TestClass] - public class ExamplesTest - { - [TestMethod] - public void BasicOperations() - { - tf.Graph().as_default(); - new BasicOperations() { Enabled = true }.Run(); - } - - [TestMethod] - public void HelloWorld() - { - tf.Graph().as_default(); - new HelloWorld() { Enabled = true }.Run(); - } - - [TestMethod] - public void ImageRecognition() - { - tf.Graph().as_default(); - new HelloWorld() { Enabled = true }.Run(); - } - - [Ignore] - [TestMethod] - public void InceptionArchGoogLeNet() - { - tf.Graph().as_default(); - new InceptionArchGoogLeNet() { Enabled = true }.Run(); - } - - [Ignore] - [TestMethod] - public void KMeansClustering() - { - tf.Graph().as_default(); - new KMeansClustering() { Enabled = true, IsImportingGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run(); - } - - [TestMethod] - public void LinearRegression() - { - tf.Graph().as_default(); - new LinearRegression() { Enabled = true }.Run(); - } - - [TestMethod] - public void LogisticRegression() - { - tf.Graph().as_default(); - new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Run(); - } - - [Ignore] - [TestMethod] - public void NaiveBayesClassifier() - { - tf.Graph().as_default(); - new NaiveBayesClassifier() { Enabled = false }.Run(); - } - - [Ignore] - [TestMethod] - public void NamedEntityRecognition() - { - tf.Graph().as_default(); - new NamedEntityRecognition() { Enabled = true }.Run(); - } - - [TestMethod] - public void NearestNeighbor() - { - tf.Graph().as_default(); - new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run(); - } - - [Ignore] - [TestMethod] - public void WordCnnTextClassification() - => new CnnTextClassification { Enabled = true, ModelName = "word_cnn", DataLimit =100 }.Run(); - - [Ignore] - [TestMethod] - public void CharCnnTextClassification() - => new CnnTextClassification { Enabled = true, ModelName = "char_cnn", DataLimit = 100 }.Run(); - - [Ignore] - [TestMethod] - public void TextClassificationWithMovieReviews() - { - tf.Graph().as_default(); - new BinaryTextClassification() { Enabled = true 
}.Run(); - } - - [TestMethod] - public void NeuralNetXor() - { - tf.Graph().as_default(); - Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = false }.Run()); - } - - [Ignore("Not working")] - [TestMethod] - public void NeuralNetXor_ImportedGraph() - { - tf.Graph().as_default(); - Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = true }.Run()); - } - - [Ignore("Not working")] - [TestMethod] - public void ObjectDetection() - { - tf.Graph().as_default(); - Assert.IsTrue(new ObjectDetection() { Enabled = true, IsImportingGraph = true }.Run()); - } - } -} diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index 443191dd..6a117ac1 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -207,7 +207,7 @@ namespace TensorFlowNET.UnitTest public void ImportGraphDef() { var s = new Status(); - var graph = new Graph(); + var graph = new Graph().as_default(); // Create a simple graph. c_test_util.Placeholder(graph, s); @@ -221,7 +221,7 @@ namespace TensorFlowNET.UnitTest // Import it, with a prefix, in a fresh graph. graph.Dispose(); - graph = new Graph(); + graph = new Graph().as_default(); var opts = c_api.TF_NewImportGraphDefOptions(); c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); @@ -359,7 +359,7 @@ namespace TensorFlowNET.UnitTest public void ImportGraphDef_WithReturnOutputs() { var s = new Status(); - var graph = new Graph(); + var graph = new Graph().as_default(); // Create a graph with two nodes: x and 3 c_test_util.Placeholder(graph, s); @@ -375,7 +375,7 @@ namespace TensorFlowNET.UnitTest // Import it in a fresh graph with return outputs. graph.Dispose(); - graph = new Graph(); + graph = new Graph().as_default(); var opts = new ImportGraphDefOptions(); opts.AddReturnOutput("feed", 0); opts.AddReturnOutput("scalar", 0); @@ -411,17 +411,18 @@ namespace TensorFlowNET.UnitTest } + [Ignore] [TestMethod] public void ImportGraphMeta() { var dir = "my-save-dir/"; using (var sess = tf.Session()) { - var new_saver = tf.train.import_meta_graph(@"D:\tmp\resnet_v2_101_2017_04_14\eval.graph"); + var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta"); new_saver.restore(sess, dir + "my-model-10000"); var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels"); var batch_size = tf.size(labels); - var logits = (tf.get_collection("logits") as List)[0] as Tensor; + var logits = tf.get_collection("logits")[0] as Tensor; var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels, logits: logits); } diff --git a/test/TensorFlowNET.UnitTest/ImageTest.cs b/test/TensorFlowNET.UnitTest/ImageTest.cs index e4f8a835..dd0b8b38 100644 --- a/test/TensorFlowNET.UnitTest/ImageTest.cs +++ b/test/TensorFlowNET.UnitTest/ImageTest.cs @@ -14,15 +14,17 @@ namespace TensorFlowNET.UnitTest [TestClass] public class ImageTest { - string imgPath = "../../../../../data/shasta-daisy.jpg"; + string imgPath = "shasta-daisy.jpg"; Tensor contents; - public ImageTest() + [TestInitialize] + public void Initialize() { imgPath = Path.GetFullPath(imgPath); contents = tf.read_file(imgPath); } + [Ignore("")] [TestMethod] public void decode_image() { diff --git a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs new file mode 100644 index 00000000..f0a79ed6 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs @@ -0,0 +1,331 @@ 
+using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using FluentAssertions; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp; +using Tensorflow; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest +{ + [TestClass] + public class MultithreadingTests + { + [TestMethod] + public void SessionCreation() + { + ops.uid(); //increment id by one + + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + tf.peak_default_graph().Should().BeNull(); + + using (var sess = tf.Session()) + { + var default_graph = tf.peak_default_graph(); + var sess_graph = sess.GetPrivate("_graph"); + sess_graph.Should().NotBeNull(); + default_graph.Should().NotBeNull() + .And.BeEquivalentTo(sess_graph); + } + } + } + + [TestMethod] + public void SessionCreation_x2() + { + ops.uid(); //increment id by one + + MultiThreadedUnitTestExecuter.Run(16, Core); + + //the core method + void Core(int tid) + { + tf.peak_default_graph().Should().BeNull(); + //tf.Session created an other graph + using (var sess = tf.Session()) + { + var default_graph = tf.peak_default_graph(); + var sess_graph = sess.GetPrivate("_graph"); + sess_graph.Should().NotBeNull(); + default_graph.Should().NotBeNull() + .And.BeEquivalentTo(sess_graph); + } + } + } + + [TestMethod] + public void GraphCreation() + { + ops.uid(); //increment id by one + + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + tf.peak_default_graph().Should().BeNull(); + var beforehand = tf.get_default_graph(); //this should create default automatically. + beforehand.graph_key.Should().NotContain("-0/", "Already created a graph in an other thread."); + tf.peak_default_graph().Should().NotBeNull(); + + using (var sess = tf.Session()) + { + var default_graph = tf.peak_default_graph(); + var sess_graph = sess.GetPrivate("_graph"); + sess_graph.Should().NotBeNull(); + default_graph.Should().NotBeNull() + .And.BeEquivalentTo(sess_graph) + .And.BeEquivalentTo(beforehand); + + Console.WriteLine($"{tid}-{default_graph.graph_key}"); + + //var result = sess.run(new object[] {g, a}); + //var actualDeriv = result[0].GetData()[0]; + //var actual = result[1].GetData()[0]; + } + } + } + + + [TestMethod] + public void Marshal_AllocHGlobal() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + for (int i = 0; i < 100; i++) + { + Marshal.FreeHGlobal(Marshal.AllocHGlobal(sizeof(int))); + } + } + } + + [TestMethod] + public void TensorCreation() + { + //lock (Locks.ProcessWide) + // tf.Session(); //create one to increase next id to 1. + + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + using (var sess = tf.Session()) + { + Tensor t = null; + for (int i = 0; i < 100; i++) + { + t = new Tensor(1); + } + } + } + } + + [TestMethod] + public void TensorCreation_Array() + { + //lock (Locks.ProcessWide) + // tf.Session(); //create one to increase next id to 1. + + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + //tf.Session created an other graph + using (var sess = tf.Session()) + { + Tensor t = null; + for (int i = 0; i < 100; i++) + { + t = new Tensor(new int[] {1, 2, 3}); + } + } + } + } + + [TestMethod] + public void TensorCreation_Undressed() + { + //lock (Locks.ProcessWide) + // tf.Session(); //create one to increase next id to 1. 
+ + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + unsafe void Core(int tid) + { + using (var sess = tf.Session()) + { + Tensor t = null; + for (int i = 0; i < 100; i++) + { + var v = (int*) Marshal.AllocHGlobal(sizeof(int)); + c_api.DeallocatorArgs _deallocatorArgs = new c_api.DeallocatorArgs(); + var handle = c_api.TF_NewTensor(typeof(int).as_dtype(), dims: new long[0], num_dims: 0, + data: (IntPtr) v, len: (UIntPtr) sizeof(int), + deallocator: (IntPtr data, IntPtr size, ref c_api.DeallocatorArgs args) => Marshal.FreeHGlobal(data), + ref _deallocatorArgs); + c_api.TF_DeleteTensor(handle); + } + } + } + } + + [TestMethod] + public void SessionRun() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + tf.peak_default_graph().Should().BeNull(); + //graph is created automatically to perform create these operations + var a1 = tf.constant(new[] {2f}, shape: new[] {1}); + var a2 = tf.constant(new[] {3f}, shape: new[] {1}); + var math = a1 + a2; + for (int i = 0; i < 100; i++) + { + using (var sess = tf.Session()) + { + sess.run(math).GetAtIndex(0).Should().Be(5); + } + } + } + } + + [TestMethod] + public void SessionRun_InsideSession() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + using (var sess = tf.Session()) + { + tf.peak_default_graph().Should().NotBeNull(); + //graph is created automatically to perform create these operations + var a1 = tf.constant(new[] {2f}, shape: new[] {1}); + var a2 = tf.constant(new[] {3f}, shape: new[] {1}); + var math = a1 + a2; + + var result = sess.run(math); + result[0].GetAtIndex(0).Should().Be(5); + } + } + } + + [TestMethod] + public void SessionRun_Initialization() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + using (var sess = tf.Session()) + { + tf.peak_default_graph().Should().NotBeNull(); + //graph is created automatically to perform create these operations + var a1 = tf.constant(new[] {2f}, shape: new[] {1}); + var a2 = tf.constant(new[] {3f}, shape: new[] {1}); + var math = a1 + a2; + } + } + } + + [TestMethod] + public void SessionRun_Initialization_OutsideSession() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + tf.peak_default_graph().Should().BeNull(); + //graph is created automatically to perform create these operations + var a1 = tf.constant(new[] {2f}, shape: new[] {1}); + var a2 = tf.constant(new[] {3f}, shape: new[] {1}); + var math = a1 + a2; + } + } + + + [TestMethod] + public void TF_GraphOperationByName() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + tf.peak_default_graph().Should().BeNull(); + //graph is created automatically to perform create these operations + var a1 = tf.constant(new[] {2f}, shape: new[] {1}); + var a2 = tf.constant(new[] {3f}, shape: new[] {1}, name: "ConstantK"); + var math = a1 + a2; + for (int i = 0; i < 100; i++) + { + var op = tf.get_default_graph().OperationByName("ConstantK"); + } + } + } + + private static string modelPath = "./model/"; + + [TestMethod] + public void TF_GraphOperationByName_FromModel() + { + if (!Directory.Exists(modelPath)) + return; + + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + Console.WriteLine(); + for (int j = 0; j < 100; j++) + { + var sess = Session.LoadFromSavedModel(modelPath).as_default(); + var inputs = new[] {"sp", "fuel"}; + + var inp = inputs.Select(name => 
sess.graph.OperationByName(name).output).ToArray(); + var outp = sess.graph.OperationByName("softmax_tensor").output; + + for (var i = 0; i < 100; i++) + { + { + var data = new float[96]; + FeedItem[] feeds = new FeedItem[2]; + + for (int f = 0; f < 2; f++) + feeds[f] = new FeedItem(inp[f], new NDArray(data)); + + try + { + sess.run(outp, feeds); + } catch (Exception ex) + { + Console.WriteLine(ex); + } + } + } + } + } + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 226a4839..b5d37d35 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -1494,5 +1494,23 @@ namespace TensorFlowNET.UnitTest } #endregion } + + [Ignore("Not finished yet")] + [TestMethod] + public void map_fn() + { + var a = tf.constant(new[] { 1, 2, 3, 4 }); + var b = tf.constant(new[] { 17, 12, 11, 10 }); + var ab = tf.stack(new[] { a, b }, 1); + + Func map_operation = (value_ab) => + { + var value_a = value_ab[0]; + var value_b = value_ab[1]; + return value_a + value_b; + }; + + var map_result = tf.map_fn(map_operation, ab); + } } } diff --git a/test/TensorFlowNET.UnitTest/PythonBaseTests.cs b/test/TensorFlowNET.UnitTest/PythonBaseTests.cs index 6fd0d7c3..8dfbdeef 100644 --- a/test/TensorFlowNET.UnitTest/PythonBaseTests.cs +++ b/test/TensorFlowNET.UnitTest/PythonBaseTests.cs @@ -49,27 +49,6 @@ namespace TensorFlowNET.UnitTest Assert.IsFalse(false2); Assert.IsFalse(false3); } - - [TestMethod] - public void hasattr_getattr() - { - var s1 = "Tensorflow v0.1"; - var f = "Tensorflow"; - var r = "Tensorflow.NET"; - var res = s1.Replace(f, r); - - // Test 1 - Assert.IsTrue(hasattr(s1, "Replace")); - - // Test 2 - var o = getattr( s1, "Replace", typeof(string), typeof(string)); - Assert.AreEqual(res, o(f, r)); - - // Test 3 - var l = getattr(s1, "Length"); - Assert.AreEqual(s1.Length, l()); - - } } } diff --git a/test/TensorFlowNET.UnitTest/QueueTest.cs b/test/TensorFlowNET.UnitTest/QueueTest.cs new file mode 100644 index 00000000..731635b7 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/QueueTest.cs @@ -0,0 +1,116 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest +{ + [TestClass] + public class QueueTest + { + [TestMethod] + public void PaddingFIFOQueue() + { + var numbers = tf.placeholder(tf.int32); + var queue = tf.PaddingFIFOQueue(10, tf.int32, new TensorShape(-1)); + var enqueue = queue.enqueue(numbers); + var dequeue_many = queue.dequeue_many(n: 3); + + using(var sess = tf.Session()) + { + sess.run(enqueue, (numbers, new[] { 1 })); + sess.run(enqueue, (numbers, new[] { 2, 3 })); + sess.run(enqueue, (numbers, new[] { 3, 4, 5 })); + + var result = sess.run(dequeue_many[0]); + + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray())); + } + } + + [TestMethod] + public void FIFOQueue() + { + // create a first in first out queue with capacity up to 2 + // and data type set as int32 + var queue = tf.FIFOQueue(2, tf.int32); + // init queue, push 3 elements into queue. 
+ var init = queue.enqueue_many(new[] { 10, 20 }); + // pop out the first element + var x = queue.dequeue(); + // add 1 + var y = x + 1; + // push back into queue + var inc = queue.enqueue(y); + + using (var sess = tf.Session()) + { + // init queue + init.run(); + + // pop out first element and push back calculated y + (int dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(10, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(20, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(11, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(21, dequeued); + + // thread will hang or block if you run sess.run(x) again + // until queue has more element. + } + } + + [TestMethod] + public void PriorityQueue() + { + var queue = tf.PriorityQueue(3, tf.@string); + var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" }); + var x = queue.dequeue(); + + using (var sess = tf.Session()) + { + init.run(); + + var result = sess.run(x); + Assert.AreEqual(result[0].GetInt64(), 2L); + + result = sess.run(x); + Assert.AreEqual(result[0].GetInt64(), 3L); + + result = sess.run(x); + Assert.AreEqual(result[0].GetInt64(), 4L); + } + } + + [TestMethod] + public void RandomShuffleQueue() + { + var queue = tf.RandomShuffleQueue(10, min_after_dequeue: 1, dtype: tf.int32); + var init = queue.enqueue_many(new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }); + var x = queue.dequeue(); + + string results = ""; + using (var sess = tf.Session()) + { + init.run(); + + foreach(var i in range(9)) + results += (int)sess.run(x) + "."; + + // output in random order + Assert.IsFalse(results == "1.2.3.4.5.6.7.8.9."); + } + } + } +} diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs index 62d7c63d..f1453c0e 100644 --- a/test/TensorFlowNET.UnitTest/SessionTest.cs +++ b/test/TensorFlowNET.UnitTest/SessionTest.cs @@ -2,7 +2,14 @@ using NumSharp; using System; using System.Collections.Generic; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; +using FluentAssertions; +using Google.Protobuf; +using NumSharp.Backends; using Tensorflow; +using Tensorflow.Util; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest @@ -14,77 +21,172 @@ namespace TensorFlowNET.UnitTest /// tensorflow\c\c_api_test.cc /// `TEST(CAPI, Session)` /// - [TestMethod] + [TestMethod, Ignore] public void Session() { - var s = new Status(); - var graph = new Graph(); - - // Make a placeholder operation. - var feed = c_test_util.Placeholder(graph, s); - - // Make a constant operation with the scalar "2". - var two = c_test_util.ScalarConst(2, graph, s); - - // Add operation. - var add = c_test_util.Add(feed, two, graph, s); - - var csession = new CSession(graph, s); - ASSERT_EQ(TF_Code.TF_OK, s.Code); - - // Run the graph. - var inputs = new Dictionary(); - inputs.Add(feed, new Tensor(3)); - csession.SetInputs(inputs); - - var outputs = new TF_Output[] { new TF_Output(add, 0) }; - csession.SetOutputs(outputs); - - csession.Run(s); - Tensor outTensor = csession.output_tensor(0); - EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); - EXPECT_EQ(0, outTensor.NDims); - ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); - var output_contents = outTensor.ToArray(); - EXPECT_EQ(3 + 2, output_contents[0]); - - // Add another operation to the graph. - var neg = c_test_util.Neg(add, graph, s); - ASSERT_EQ(TF_Code.TF_OK, s.Code); - - // Run up to the new operation. 
- inputs = new Dictionary(); - inputs.Add(feed, new Tensor(7)); - csession.SetInputs(inputs); - outputs = new TF_Output[] { new TF_Output(neg, 0) }; - csession.SetOutputs(outputs); - csession.Run(s); - ASSERT_EQ(TF_Code.TF_OK, s.Code); - - outTensor = csession.output_tensor(0); - ASSERT_TRUE(outTensor != IntPtr.Zero); - EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); - EXPECT_EQ(0, outTensor.NDims); // scalar - ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); - output_contents = outTensor.ToArray(); - EXPECT_EQ(-(7 + 2), output_contents[0]); - - // Clean up - csession.CloseAndDelete(s); - ASSERT_EQ(TF_Code.TF_OK, s.Code); + lock (Locks.ProcessWide) + { + var s = new Status(); + var graph = new Graph().as_default(); + + // Make a placeholder operation. + var feed = c_test_util.Placeholder(graph, s); + + // Make a constant operation with the scalar "2". + var two = c_test_util.ScalarConst(2, graph, s); + + // Add operation. + var add = c_test_util.Add(feed, two, graph, s); + + var csession = new CSession(graph, s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + + // Run the graph. + var inputs = new Dictionary(); + inputs.Add(feed, new Tensor(3)); + csession.SetInputs(inputs); + + var outputs = new TF_Output[] {new TF_Output(add, 0)}; + csession.SetOutputs(outputs); + + csession.Run(s); + Tensor outTensor = csession.output_tensor(0); + EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); + EXPECT_EQ(0, outTensor.NDims); + ASSERT_EQ((ulong) sizeof(uint), outTensor.bytesize); + var output_contents = outTensor.ToArray(); + EXPECT_EQ(3 + 2, output_contents[0]); + + // Add another operation to the graph. + var neg = c_test_util.Neg(add, graph, s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + + // Run up to the new operation. + inputs = new Dictionary(); + inputs.Add(feed, new Tensor(7)); + csession.SetInputs(inputs); + outputs = new TF_Output[] {new TF_Output(neg, 0)}; + csession.SetOutputs(outputs); + csession.Run(s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + + outTensor = csession.output_tensor(0); + ASSERT_TRUE(outTensor != IntPtr.Zero); + EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); + EXPECT_EQ(0, outTensor.NDims); // scalar + ASSERT_EQ((ulong) sizeof(uint), outTensor.bytesize); + output_contents = outTensor.ToArray(); + EXPECT_EQ(-(7 + 2), output_contents[0]); + + // Clean up + csession.CloseAndDelete(s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + } } [TestMethod] public void EvalTensor() { - var a = constant_op.constant(np.array(3.0).reshape(1, 1)); - var b = constant_op.constant(np.array(2.0).reshape(1, 1)); - var c = math_ops.matmul(a, b, name: "matmul"); - using (var sess = tf.Session()) + lock (this) + { + var a = constant_op.constant(np.array(3.0).reshape(1, 1)); + var b = constant_op.constant(np.array(2.0).reshape(1, 1)); + var c = math_ops.matmul(a, b, name: "matmul"); + using (var sess = tf.Session()) + { + var result = c.eval(sess); + Assert.AreEqual(6, result.GetAtIndex(0)); + } + } + } + + [TestMethod] + public void Eval_SmallString_Scalar() + { + lock (this) + { + var a = constant_op.constant("123 heythere 123 ", TF_DataType.TF_STRING); + var c = tf.strings.substr(a, 4, 8); + using (var sess = tf.Session()) + { + var result = (string) c.eval(sess); + Console.WriteLine(result); + result.Should().Be("heythere"); + } + } + } + + [TestMethod] + public void Eval_LargeString_Scalar() + { + lock (this) { - var result = c.eval(sess); - Assert.AreEqual(6, result.Data()[0]); + const int size = 30_000; + var a = constant_op.constant(new string('a', size), TF_DataType.TF_STRING); + var c = 
tf.strings.substr(a, 0, size - 5000); + using (var sess = tf.Session()) + { + var result = (string) c.eval(sess); + Console.WriteLine((string) result); + result.Should().HaveLength(size - 5000).And.ContainAll("a"); + } } } + + [TestMethod] + public void Autocast_Case1() + { + var sess = tf.Session().as_default(); + var input = tf.placeholder(tf.float64, shape: new TensorShape(6)); + var op = tf.reshape(input, new int[] {2, 3}); + sess.run(tf.global_variables_initializer()); + var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6))); + + ret.Should().BeOfType().And.BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6); + print(ret.dtype); + print(ret); + } + + [TestMethod] + public void Autocast_Case2() + { + var sess = tf.Session().as_default(); + var input = tf.placeholder(tf.float64, shape: new TensorShape(6)); + var op = tf.reshape(input, new int[] {2, 3}); + sess.run(tf.global_variables_initializer()); + var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NPTypeCode.Single) + 0.1f)); + + ret.Should().BeOfType().And.BeShaped(2, 3).And.BeOfValuesApproximately(0.001d, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1); + print(ret.dtype); + print(ret); + } + + [TestMethod] + public void Autocast_Case3() + { + var sess = tf.Session().as_default(); + var input = tf.placeholder(tf.int16, shape: new TensorShape(6)); + var op = tf.reshape(input, new int[] {2, 3}); + sess.run(tf.global_variables_initializer()); + var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NPTypeCode.Single) + 0.1f)); + + ret.Should().BeOfType().And.BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6); + print(ret.dtype); + print(ret); + } + + [TestMethod] + public void Autocast_Case4() + { + var sess = tf.Session().as_default(); + var input = tf.placeholder(tf.@byte, shape: new TensorShape(6)); + var op = tf.reshape(input, new int[] {2, 3}); + sess.run(tf.global_variables_initializer()); + var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NPTypeCode.Single) + 0.1f)); + + ret.Should().BeOfType().And.BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6); + print(ret.dtype); + print(ret); + } } -} +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index 661d85ea..6cc1a87d 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -1,7 +1,7 @@  - netcoreapp2.2 + netcoreapp3.0 false @@ -10,6 +10,8 @@ false Open.snk + + latest @@ -22,6 +24,11 @@ + + + + + @@ -29,11 +36,17 @@ - - - - + + + + + + + + + PreserveNewest + diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj.DotSettings b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj.DotSettings new file mode 100644 index 00000000..6cbf8796 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj.DotSettings @@ -0,0 +1,2 @@ + + True \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/TensorShapeTest.cs b/test/TensorFlowNET.UnitTest/TensorShapeTest.cs new file mode 100644 index 00000000..b7846ce3 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/TensorShapeTest.cs @@ -0,0 +1,67 @@ +using System; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest +{ + [TestClass] + public class TensorShapeTest + { + [TestMethod] + public void Case1() + { + int a = 2; + int b = 3; + var dims = 
new [] { Unknown, a, b}; + new TensorShape(dims).GetPrivate("shape").Should().BeShaped(-1, 2, 3); + } + + [TestMethod] + public void Case2() + { + int a = 2; + int b = 3; + var dims = new[] { Unknown, a, b}; + new TensorShape(new [] {dims}).GetPrivate("shape").Should().BeShaped(-1, 2, 3); + } + + [TestMethod] + public void Case3() + { + int a = 2; + int b = Unknown; + var dims = new [] { Unknown, a, b}; + new TensorShape(new [] {dims}).GetPrivate("shape").Should().BeShaped(-1, 2, -1); + } + + [TestMethod] + public void Case4() + { + TensorShape shape = (Unknown, Unknown); + shape.GetPrivate("shape").Should().BeShaped(-1, -1); + } + + [TestMethod] + public void Case5() + { + TensorShape shape = (1, Unknown, 3); + shape.GetPrivate("shape").Should().BeShaped(1, -1, 3); + } + + [TestMethod] + public void Case6() + { + TensorShape shape = (Unknown, 1, 2, 3, Unknown); + shape.GetPrivate("shape").Should().BeShaped(-1, 1, 2, 3, -1); + } + + [TestMethod] + public void Case7() + { + TensorShape shape = new TensorShape(); + Assert.AreEqual(shape.rank, -1); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs index 11557f14..fe68d718 100644 --- a/test/TensorFlowNET.UnitTest/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorTest.cs @@ -4,85 +4,73 @@ using System; using System.Linq; using System.Runtime.InteropServices; using System.Threading; +using FluentAssertions; using Tensorflow; using static Tensorflow.Binding; +using Tensorflow.Framework; namespace TensorFlowNET.UnitTest { [TestClass] public class TensorTest : CApiTest { - [Ignore("Not for mult-thread")] - public void TensorDeallocationThreadSafety() - { - var tensors = new Tensor[1000]; - foreach (var i in range(1000)) - { - tensors[i] = new Tensor(new int[1000]); - } - SemaphoreSlim s = new SemaphoreSlim(0, 2); - SemaphoreSlim s_done = new SemaphoreSlim(0, 2); - - var t1 = new Thread(() => - { - s.Wait(); - foreach (var t in tensors) - t.Dispose(); - s_done.Release(); - }); - - var t2 = new Thread(() => - { - s.Wait(); - foreach (var t in tensors) - t.Dispose(); - s_done.Release(); - }); - - t1.Start(); - t2.Start(); - s.Release(2); - s_done.Wait(); - s_done.Wait(); - - foreach (var t in tensors) - Assert.IsTrue(t.IsDisposed); - } - [TestMethod] public unsafe void TensorFromFixed() { var array = new float[1000]; var span = new Span(array, 100, 500); - fixed (float* ptr=&MemoryMarshal.GetReference(span)) + fixed (float* ptr = &MemoryMarshal.GetReference(span)) { - using (var t = new Tensor((IntPtr)ptr, new long[] {span.Length}, tf.float32, 4*span.Length)) + using (var t = new Tensor((IntPtr) ptr, new long[] {span.Length}, tf.float32, 4 * span.Length)) { Assert.IsFalse(t.IsDisposed); - Assert.IsFalse(t.IsMemoryOwner); Assert.AreEqual(2000, (int) t.bytesize); } } + fixed (float* ptr = &array[0]) { - using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, 4 * array.Length)) + using (var t = new Tensor((IntPtr) ptr, new long[] {array.Length}, tf.float32, 4 * array.Length)) { Assert.IsFalse(t.IsDisposed); - Assert.IsFalse(t.IsMemoryOwner); - Assert.AreEqual(4000, (int)t.bytesize); + Assert.AreEqual(4000, (int) t.bytesize); } } } + [TestMethod] + public unsafe void TensorFromArray() + { + var array = new float[1000]; + using (var t = new Tensor(array, new long[] {array.Length}, tf.float32)) + { + Assert.IsFalse(t.IsDisposed); + Assert.AreEqual(1000 * sizeof(float), (int) t.bytesize); + } + + using (var t = new Tensor(new float[] {1}, new 
long[] {1}, tf.float32)) + { + Assert.IsFalse(t.IsDisposed); + Assert.AreEqual(1 * sizeof(float), (int) t.bytesize); + } + + using (var t = new Tensor(new float[] {1}, null, tf.float32)) + { + Assert.IsFalse(t.IsDisposed); + Assert.AreEqual(1 * sizeof(float), (int) t.bytesize); + t.shape.Should().BeEmpty(); + } + } + [TestMethod] public void AllocateTensor() { ulong num_bytes = 6 * sizeof(float); - long[] dims = { 2, 3 }; + long[] dims = {2, 3}; Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype); EXPECT_EQ(2, t.NDims); - EXPECT_EQ((int)dims[0], t.shape[0]); + EXPECT_EQ((int) dims[0], t.shape[0]); EXPECT_EQ(num_bytes, t.bytesize); t.Dispose(); } @@ -98,7 +86,7 @@ namespace TensorFlowNET.UnitTest NDArray nd = np.array(2, 3); Tensor t = new Tensor(nd); Tensor o = t.MaybeMove(); - ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own. + ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own. t.Dispose(); } @@ -116,10 +104,10 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); EXPECT_EQ(tensor.rank, nd.ndim); - EXPECT_EQ((int)tensor.shape[0], nd.shape[0]); - EXPECT_EQ((int)tensor.shape[1], nd.shape[1]); - EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float)); - Assert.IsTrue(Enumerable.SequenceEqual(nd.Data(), new float[] { 1, 2, 3, 4, 5, 6 })); + EXPECT_EQ((int) tensor.shape[0], nd.shape[0]); + EXPECT_EQ((int) tensor.shape[1], nd.shape[1]); + EXPECT_EQ(tensor.bytesize, (ulong) nd.size * sizeof(float)); + Assert.IsTrue(Enumerable.SequenceEqual(nd.Data(), new float[] {1, 2, 3, 4, 5, 6})); } /// @@ -130,7 +118,7 @@ namespace TensorFlowNET.UnitTest public void SetShape() { var s = new Status(); - var graph = new Graph(); + var graph = new Graph().as_default(); var feed = c_test_util.Placeholder(graph, s); var feed_out_0 = new TF_Output(feed, 0); @@ -148,7 +136,7 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(-1, num_dims); // Set the shape to be 2 x Unknown - long[] dims = { 2, -1 }; + long[] dims = {2, -1}; c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); @@ -177,8 +165,8 @@ namespace TensorFlowNET.UnitTest c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); EXPECT_EQ(2, num_dims); - EXPECT_EQ(2, (int)returned_dims[0]); - EXPECT_EQ(3, (int)returned_dims[1]); + EXPECT_EQ(2, (int) returned_dims[0]); + EXPECT_EQ(3, (int) returned_dims[1]); // Try to set 'unknown' with same rank on the shape and see that // it doesn't change. 
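// A minimal sketch of the pattern the new graph-mode tests in this patch follow:
// build an op, evaluate it in a Session, then assert on the resulting NDArray with
// the FluentExtension helpers (Should().BeShaped / BeOfValues) added further below.
// "ReshapePattern_Sketch" is a hypothetical test name, not part of the suite.
[TestMethod]
public void ReshapePattern_Sketch()
{
    var a = tf.constant(new[] { 1, 2, 3, 4, 5, 6 });
    var reshaped = tf.reshape(a, new int[] { 2, 3 });
    using (var sess = tf.Session())
    {
        var result = sess.run(reshaped);
        result.Should().BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6);
    }
}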
@@ -189,8 +177,8 @@ namespace TensorFlowNET.UnitTest c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); EXPECT_EQ(2, num_dims); - EXPECT_EQ(2, (int)returned_dims[0]); - EXPECT_EQ(3, (int)returned_dims[1]); + EXPECT_EQ(2, (int) returned_dims[0]); + EXPECT_EQ(3, (int) returned_dims[1]); // Try to fetch a shape with the wrong num_dims c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); @@ -215,5 +203,75 @@ namespace TensorFlowNET.UnitTest // graph.Dispose(); s.Dispose(); } + + [TestMethod] + public void sparse_to_dense() + { + var indices = tf.reshape(tf.range(0, 5), new int[] { 5, 1 }); + var labels = tf.expand_dims(tf.constant(new[] { 0, 1, 2, 3, 4 }),1); + var st = tf.concat(values: new[] { indices, labels }, axis: 1); + var onehot = tf.sparse_to_dense(st, (5, 5), 1); + using (var sess = tf.Session()) + { + var result = sess.run(onehot); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0, 0 }, result[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 1, 0, 0, 0 }, result[1].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 1, 0, 0 }, result[2].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 1, 0 }, result[3].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 1 }, result[4].ToArray())); + }; + } + + [TestMethod] + public void sparse_tensor_to_dense() + { + var decoded_list = tf.SparseTensor(new[,] + { + { 0L, 0L }, + { 1L, 2L } + }, + new int[] { 1, 2 }, + new[] { 3L, 4L }); + + var onehot = tf.sparse_tensor_to_dense(decoded_list); + using (var sess = tf.Session()) + { + var result = sess.run(onehot); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0 }, result[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 2, 0 }, result[1].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray())); + } + } + + [TestMethod] + public void batch_to_space_nd() + { + var inputs = np.arange(24).reshape(4, 2, 3); + var block_shape = new[] { 2, 2 }; + int[,] crops = { { 0, 0 }, { 0, 0 } }; + var tensor = tf.batch_to_space_nd(inputs, block_shape, crops); + + using (var sess = tf.Session()) + { + var result = sess.run(tensor); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 6, 1, 7, 2, 8 }, result[0, 0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 12, 18, 13, 19, 14, 20 }, result[0, 1].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 9, 4, 10, 5, 11 }, result[0, 2].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray())); + } + } + + [TestMethod] + public void boolean_mask() + { + var tensor = new[] { 0, 1, 2, 3 }; + var mask = np.array(new[] { true, false, true, false }); + var masked = tf.boolean_mask(tensor, mask); + using (var sess = tf.Session()) + { + var result = sess.run(masked); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, result.ToArray())); + } + } } -} +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs b/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs new file mode 100644 index 00000000..7bd16888 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Utilities/FluentExtension.cs @@ -0,0 +1,1206 @@ +using System; +using System.Diagnostics; +using System.Linq; +using System.Runtime.CompilerServices; +using FluentAssertions; +using FluentAssertions.Execution; 
+using FluentAssertions.Primitives; +using NumSharp; +using NumSharp.Backends; +using NumSharp.Utilities; + +namespace TensorFlowNET.UnitTest +{ + [DebuggerStepThrough] + public static class FluentExtension + { + public static ShapeAssertions Should(this Shape shape) + { + return new ShapeAssertions(shape); + } + + public static NDArrayAssertions Should(this NDArray arr) + { + return new NDArrayAssertions(arr); + } + + public static NDArrayAssertions Should(this UnmanagedStorage arr) + { + return new NDArrayAssertions(arr); + } + + public static string ToString(this Array arr, bool flat) + { + return new NDArray(arr).ToString(flat); + } + } + + [DebuggerStepThrough] + public class ShapeAssertions : ReferenceTypeAssertions + { + public ShapeAssertions(Shape instance) + { + Subject = instance; + } + + protected override string Identifier => "shape"; + + public AndConstraint BeOfSize(int size, string because = null, params object[] becauseArgs) + { + Subject.Size.Should().Be(size, because, becauseArgs); + return new AndConstraint(this); + } + + public AndConstraint NotBeOfSize(int size, string because = null, params object[] becauseArgs) + { + Subject.Size.Should().NotBe(size, because, becauseArgs); + return new AndConstraint(this); + } + + public AndConstraint BeShaped(params int[] dimensions) + { + if (dimensions == null) + throw new ArgumentNullException(nameof(dimensions)); + + if (dimensions.Length == 0) + throw new ArgumentException("Value cannot be an empty collection.", nameof(dimensions)); + + Subject.Dimensions.Should().BeEquivalentTo(dimensions); + return new AndConstraint(this); + } + + public AndConstraint Be(Shape shape, string because = null, params object[] becauseArgs) + { + Execute.Assertion + .BecauseOf(because, becauseArgs) + .ForCondition(Subject.Equals(shape)) + .FailWith($"Expected shape to be {shape.ToString()} but got {Subject.ToString()}"); + + return new AndConstraint(this); + } + + public AndConstraint BeEquivalentTo(int? size = null, int? 
ndim = null, ITuple shape = null) + { + if (size.HasValue) + { + BeOfSize(size.Value, null); + } + + if (ndim.HasValue) + HaveNDim(ndim.Value); + + if (shape != null) + for (int i = 0; i < shape.Length; i++) + { + Subject.Dimensions[i].Should().Be((int) shape[i]); + } + + return new AndConstraint(this); + } + + public AndConstraint NotBe(Shape shape, string because = null, params object[] becauseArgs) + { + Execute.Assertion + .BecauseOf(because, becauseArgs) + .ForCondition(!Subject.Equals(shape)) + .FailWith($"Expected shape to be {shape.ToString()} but got {Subject.ToString()}"); + + return new AndConstraint(this); + } + + public AndConstraint HaveNDim(int ndim) + { + Subject.Dimensions.Length.Should().Be(ndim); + return new AndConstraint(this); + } + + public AndConstraint BeSliced() + { + Subject.IsSliced.Should().BeTrue(); + return new AndConstraint(this); + } + + public AndConstraint BeScalar() + { + Subject.IsScalar.Should().BeTrue(); + return new AndConstraint(this); + } + + public AndConstraint BeBroadcasted() + { + Subject.IsBroadcasted.Should().BeTrue(); + return new AndConstraint(this); + } + + + public AndConstraint NotBeSliced() + { + Subject.IsSliced.Should().BeFalse(); + return new AndConstraint(this); + } + + public AndConstraint NotBeScalar() + { + Subject.IsScalar.Should().BeFalse(); + return new AndConstraint(this); + } + + public AndConstraint NotBeBroadcasted() + { + Subject.IsBroadcasted.Should().BeFalse(); + return new AndConstraint(this); + } + + public AndConstraint BeNDim(int ndim) + { + Subject.Dimensions.Length.Should().Be(ndim); + return new AndConstraint(this); + } + } + + //[DebuggerStepThrough] + public class NDArrayAssertions : ReferenceTypeAssertions + { + public NDArrayAssertions(NDArray instance) + { + Subject = instance; + } + + public NDArrayAssertions(UnmanagedStorage instance) + { + Subject = new NDArray(instance); + } + + protected override string Identifier => "shape"; + + public AndConstraint BeOfSize(int size, string because = null, params object[] becauseArgs) + { + Subject.size.Should().Be(size, because, becauseArgs); + return new AndConstraint(this); + } + + public AndConstraint BeShaped(params int[] dimensions) + { + if (dimensions == null) + throw new ArgumentNullException(nameof(dimensions)); + + if (dimensions.Length == 0) + throw new ArgumentException("Value cannot be an empty collection.", nameof(dimensions)); + + Subject.Unsafe.Storage.Shape.Dimensions.Should().BeEquivalentTo(dimensions); + return new AndConstraint(this); + } + + public AndConstraint BeShaped(int? size = null, int? 
ndim = null, ITuple shape = null) + { + if (size.HasValue) + { + BeOfSize(size.Value, null); + } + + if (ndim.HasValue) + HaveNDim(ndim.Value); + + if (shape != null) + for (int i = 0; i < shape.Length; i++) + { + Subject.Unsafe.Storage.Shape.Dimensions[i].Should().Be((int) shape[i]); + } + + return new AndConstraint(this); + } + + public AndConstraint NotBeShaped(Shape shape, string because = null, params object[] becauseArgs) + { + Execute.Assertion + .BecauseOf(because, becauseArgs) + .ForCondition(!Subject.Unsafe.Storage.Shape.Equals(shape)) + .FailWith($"Expected shape to be {shape.ToString()} but got {Subject.ToString()}"); + + return new AndConstraint(this); + } + + public AndConstraint HaveNDim(int ndim) + { + Subject.Unsafe.Storage.Shape.Dimensions.Length.Should().Be(ndim); + return new AndConstraint(this); + } + + public AndConstraint BeBroadcasted() + { + Subject.Unsafe.Storage.Shape.IsBroadcasted.Should().BeTrue(); + return new AndConstraint(this); + } + + public AndConstraint NotBeBroadcasted() + { + Subject.Unsafe.Storage.Shape.IsBroadcasted.Should().BeFalse(); + return new AndConstraint(this); + } + + public AndConstraint BeSliced() + { + Subject.Unsafe.Storage.Shape.IsSliced.Should().BeTrue(); + return new AndConstraint(this); + } + + public AndConstraint BeScalar() + { + Subject.Unsafe.Storage.Shape.IsScalar.Should().BeTrue(); + return new AndConstraint(this); + } + + public AndConstraint BeScalar(object value) + { + Subject.Unsafe.Storage.Shape.IsScalar.Should().BeTrue(); + Subject.GetValue().Should().Be(value); + return new AndConstraint(this); + } + + public AndConstraint BeOfType(NPTypeCode typeCode) + { + Subject.typecode.Should().Be(typeCode); + return new AndConstraint(this); + } + + public AndConstraint BeOfType(Type typeCode) + { + Subject.dtype.Should().Be(typeCode); + return new AndConstraint(this); + } + + public AndConstraint BeOfType() + { + Subject.typecode.Should().Be(InfoOf.NPTypeCode); + return new AndConstraint(this); + } + + public AndConstraint NotBeSliced() + { + Subject.Unsafe.Storage.Shape.IsSliced.Should().BeFalse(); + return new AndConstraint(this); + } + + public AndConstraint NotBeScalar() + { + Subject.Unsafe.Storage.Shape.IsScalar.Should().BeFalse(); + return new AndConstraint(this); + } + + + public AndConstraint BeNDim(int ndim) + { + Subject.Unsafe.Storage.Shape.Dimensions.Length.Should().Be(ndim); + return new AndConstraint(this); + } + + public AndConstraint Be(NDArray expected) + { + Execute.Assertion + .ForCondition(np.array_equal(Subject, expected)) + .FailWith($"Expected the subject and other ndarray to be equals.\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{expected.ToString(false)}"); + + return new AndConstraint(this); + } + + public AndConstraint BeOfValues(params object[] values) + { + if (values == null) + throw new ArgumentNullException(nameof(values)); + + Subject.size.Should().Be(values.Length, "the method BeOfValues also confirms the sizes are matching with given values."); + +#if _REGEN + #region Compute + switch (Subject.typecode) + { + %foreach supported_dtypes,supported_dtypes_lowercase% + case NPTypeCode.#1: + { + var iter = Subject.AsIterator<#2>(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.To#1(values[i]); + var nextval = next(); + + 
Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: #1).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + break; + } + % + default: + throw new NotSupportedException(); + } + #endregion +#else + + #region Compute + + switch (Subject.typecode) + { + case NPTypeCode.Boolean: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToBoolean(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Byte: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToByte(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Byte).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Int16: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToInt16(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Int16).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.UInt16: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToUInt16(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: UInt16).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Int32: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + 
.FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToInt32(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Int32).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.UInt32: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToUInt32(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: UInt32).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Int64: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToInt64(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Int64).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.UInt64: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToUInt64(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: UInt64).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Char: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToChar(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Char).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Double: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { 
+ Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToDouble(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Double).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Single: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToSingle(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Single).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Decimal: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToDecimal(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Decimal).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + default: + throw new NotSupportedException(); + } + + #endregion + +#endif + + + return new AndConstraint(this); + } + + public AndConstraint AllValuesBe(object val) + { +#if _REGEN + #region Compute + switch (Subject.typecode) + { + %foreach supported_dtypes,supported_dtypes_lowercase% + case NPTypeCode.#1: + { + var iter = Subject.AsIterator<#2>(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.To#1(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: #1).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + break; + } + % + default: + throw new NotSupportedException(); + } + #endregion +#else + + #region Compute + + switch (Subject.typecode) + { + case NPTypeCode.Boolean: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToBoolean(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Byte: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected 
= Convert.ToByte(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Byte).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Int16: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToInt16(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Int16).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.UInt16: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToUInt16(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: UInt16).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Int32: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToInt32(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Int32).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.UInt32: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToUInt32(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: UInt32).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Int64: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToInt64(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Int64).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.UInt64: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToUInt64(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: UInt64).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Char: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToChar(val); + for (int i = 0; hasnext(); 
i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Char).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Double: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToDouble(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Double).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Single: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToSingle(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Single).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Decimal: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + var expected = Convert.ToDecimal(val); + for (int i = 0; hasnext(); i++) + { + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {2}th value to be {0}, but found {1} (dtype: Decimal).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n{val}", expected, nextval, i); + } + + break; + } + + default: + throw new NotSupportedException(); + } + + #endregion + +#endif + + + return new AndConstraint(this); + } + + public AndConstraint BeOfValuesApproximately(double sensitivity, params object[] values) + { + if (values == null) + throw new ArgumentNullException(nameof(values)); + + Subject.size.Should().Be(values.Length, "the method BeOfValuesApproximately also confirms the sizes are matching with given values."); + +#if _REGEN + #region Compute + switch (Subject.typecode) + { + %foreach supported_dtypes,supported_dtypes_lowercase% + case NPTypeCode.#1: + { + var iter = Subject.AsIterator<#2>(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.To#1(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + break; + } + % + default: + throw new NotSupportedException(); + } + #endregion +#else + + #region Compute + + switch (Subject.typecode) + { + case NPTypeCode.Boolean: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); 
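+                            // Boolean values cannot be compared within a numeric tolerance, so this case
+                            // falls back to exact equality; the numeric cases below check
+                            // Math.Abs(expected - nextval) <= sensitivity instead.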
+ + var expected = Convert.ToBoolean(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(expected == nextval) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Byte: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToByte(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Int16: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToInt16(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.UInt16: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToUInt16(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Int32: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToInt32(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.UInt32: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; 
i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToUInt32(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Int64: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToInt64(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.UInt64: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToUInt64(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs((double) (expected - nextval)) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Char: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToChar(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Double: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToDouble(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, 
nextval, i); + } + + break; + } + + case NPTypeCode.Single: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToSingle(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + case NPTypeCode.Decimal: + { + var iter = Subject.AsIterator(); + var next = iter.MoveNext; + var hasnext = iter.HasNext; + for (int i = 0; i < values.Length; i++) + { + Execute.Assertion + .ForCondition(hasnext()) + .FailWith($"Expected the NDArray to have atleast {values.Length} but in fact it has size of {i}."); + + var expected = Convert.ToDecimal(values[i]); + var nextval = next(); + + Execute.Assertion + .ForCondition(Math.Abs(expected - nextval) <= (decimal) sensitivity) + .FailWith($"Expected NDArray's {{2}}th value to be {{0}}, but found {{1}} (dtype: Boolean).\n------- Subject -------\n{Subject.ToString(false)}\n------- Expected -------\n[{string.Join(", ", values.Select(v => v.ToString()))}]", expected, nextval, i); + } + + break; + } + + default: + throw new NotSupportedException(); + } + + #endregion + +#endif + + + return new AndConstraint(this); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs b/test/TensorFlowNET.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs new file mode 100644 index 00000000..ac4dee69 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs @@ -0,0 +1,173 @@ +using System; +using System.Diagnostics; +using System.Threading; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace TensorFlowNET.UnitTest +{ + public delegate void MultiThreadedTestDelegate(int threadid); + + /// + /// Creates a synchronized eco-system of running code. 
+ /// + public class MultiThreadedUnitTestExecuter : IDisposable + { + public int ThreadCount { get; } + public Thread[] Threads { get; } + public Exception[] Exceptions { get; } + private readonly SemaphoreSlim barrier_threadstarted; + private readonly ManualResetEventSlim barrier_corestart; + private readonly SemaphoreSlim done_barrier2; + + public Action PostRun { get; set; } + + #region Static + + [DebuggerHidden] + public static void Run(int threadCount, MultiThreadedTestDelegate workload) + { + if (workload == null) throw new ArgumentNullException(nameof(workload)); + if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount)); + new MultiThreadedUnitTestExecuter(threadCount).Run(workload); + } + + [DebuggerHidden] + public static void Run(int threadCount, params MultiThreadedTestDelegate[] workloads) + { + if (workloads == null) throw new ArgumentNullException(nameof(workloads)); + if (workloads.Length == 0) throw new ArgumentException("Value cannot be an empty collection.", nameof(workloads)); + if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount)); + new MultiThreadedUnitTestExecuter(threadCount).Run(workloads); + } + + [DebuggerHidden] + public static void Run(int threadCount, MultiThreadedTestDelegate workload, Action postRun) + { + if (workload == null) throw new ArgumentNullException(nameof(workload)); + if (postRun == null) throw new ArgumentNullException(nameof(postRun)); + if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount)); + new MultiThreadedUnitTestExecuter(threadCount) {PostRun = postRun}.Run(workload); + } + + #endregion + + + /// Initializes a new instance of the class. + public MultiThreadedUnitTestExecuter(int threadCount) + { + if (threadCount <= 0) + throw new ArgumentOutOfRangeException(nameof(threadCount)); + ThreadCount = threadCount; + Threads = new Thread[ThreadCount]; + Exceptions = new Exception[ThreadCount]; + done_barrier2 = new SemaphoreSlim(0, threadCount); + barrier_corestart = new ManualResetEventSlim(); + barrier_threadstarted = new SemaphoreSlim(0, threadCount); + } + + [DebuggerHidden] + public void Run(params MultiThreadedTestDelegate[] workloads) + { + if (workloads == null) + throw new ArgumentNullException(nameof(workloads)); + if (workloads.Length != 1 && workloads.Length % ThreadCount != 0) + throw new InvalidOperationException($"Run method must accept either 1 workload or n-threads workloads. 
Got {workloads.Length} workloads."); + + if (ThreadCount == 1) + { + Exception ex = null; + new Thread(() => + { + try + { + workloads[0](0); + } catch (Exception e) + { + if (Debugger.IsAttached) + throw; + ex = e; + } finally + { + done_barrier2.Release(1); + } + }).Start(); + + done_barrier2.Wait(); + + if (ex != null) + throw new Exception($"Thread 0 has failed: ", ex); + + PostRun?.Invoke(this); + + return; + } + + //thread core + Exception ThreadCore(MultiThreadedTestDelegate core, int threadid) + { + barrier_threadstarted.Release(1); + barrier_corestart.Wait(); + //workload + try + { + core(threadid); + } catch (Exception e) + { + if (Debugger.IsAttached) + throw; + return e; + } finally + { + done_barrier2.Release(1); + } + + return null; + } + + //initialize all threads + if (workloads.Length == 1) + { + var workload = workloads[0]; + for (int i = 0; i < ThreadCount; i++) + { + var i_local = i; + Threads[i] = new Thread(() => Exceptions[i_local] = ThreadCore(workload, i_local)); + } + } else + { + for (int i = 0; i < ThreadCount; i++) + { + var i_local = i; + var workload = workloads[i_local % workloads.Length]; + Threads[i] = new Thread(() => Exceptions[i_local] = ThreadCore(workload, i_local)); + } + } + + //run all threads + for (int i = 0; i < ThreadCount; i++) Threads[i].Start(); + //wait for threads to be started and ready + for (int i = 0; i < ThreadCount; i++) barrier_threadstarted.Wait(); + + //signal threads to start + barrier_corestart.Set(); + + //wait for threads to finish + for (int i = 0; i < ThreadCount; i++) done_barrier2.Wait(); + + //handle fails + for (int i = 0; i < ThreadCount; i++) + if (Exceptions[i] != null) + throw new Exception($"Thread {i} has failed: ", Exceptions[i]); + + //checks after ended + PostRun?.Invoke(this); + } + + public void Dispose() + { + barrier_threadstarted.Dispose(); + barrier_corestart.Dispose(); + done_barrier2.Dispose(); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Utilities/PrivateObject.cs b/test/TensorFlowNET.UnitTest/Utilities/PrivateObject.cs new file mode 100644 index 00000000..acb8c41e --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Utilities/PrivateObject.cs @@ -0,0 +1,914 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. 
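// PrivateObject below is a trimmed-down copy of the MSTest helper that reaches
// non-public members through reflection; only GetFieldOrProperty / SetFieldOrProperty
// (plus Equals / GetHashCode) are kept, the rest stays commented out.
// The TensorShapeTest cases above read TensorShape's internal "shape" member through a
// GetPrivate helper, which presumably wraps this class. A minimal direct-usage sketch,
// assuming TensorShape stores its NumSharp Shape in a non-public member named "shape":
//
//     var po = new PrivateObject(new TensorShape(1, 2, 3));
//     var inner = (Shape) po.GetFieldOrProperty("shape");
//     inner.Should().BeShaped(1, 2, 3);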
+ +namespace Microsoft.VisualStudio.TestTools.UnitTesting +{ + using System; + using System.Collections.Generic; + //using System.Diagnostics; + //using System.Diagnostics.CodeAnalysis; + using System.Globalization; + using System.Reflection; + + /// + /// This class represents the live NON public INTERNAL object in the system + /// + internal class PrivateObject + { + #region Data + + // bind everything + private const BindingFlags BindToEveryThing = BindingFlags.Default | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Public; + + private static BindingFlags constructorFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.CreateInstance | BindingFlags.NonPublic; + + private object target; // automatically initialized to null + private Type originalType; // automatically initialized to null + + //private Dictionary> methodCache; // automatically initialized to null + + #endregion + + #region Constructors + + ///// + ///// Initializes a new instance of the class that contains + ///// the already existing object of the private class + ///// + ///// object that serves as starting point to reach the private members + ///// the derefrencing string using . that points to the object to be retrived as in m_X.m_Y.m_Z + //[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an object, so 'obj' seems reasonable")] + //public PrivateObject(object obj, string memberToAccess) + //{ + // Helper.CheckParameterNotNull(obj, "obj", string.Empty); + // ValidateAccessString(memberToAccess); + + // PrivateObject temp = obj as PrivateObject; + // if (temp == null) + // { + // temp = new PrivateObject(obj); + // } + + // // Split The access string + // string[] arr = memberToAccess.Split(new char[] { '.' }); + + // for (int i = 0; i < arr.Length; i++) + // { + // object next = temp.InvokeHelper(arr[i], BindToEveryThing | BindingFlags.Instance | BindingFlags.GetField | BindingFlags.GetProperty, null, CultureInfo.InvariantCulture); + // temp = new PrivateObject(next); + // } + + // this.target = temp.target; + // this.originalType = temp.originalType; + //} + + ///// + ///// Initializes a new instance of the class that wraps the + ///// specified type. + ///// + ///// Name of the assembly + ///// fully qualified name + ///// Argmenets to pass to the constructor + //public PrivateObject(string assemblyName, string typeName, params object[] args) + // : this(assemblyName, typeName, null, args) + //{ + //} + + ///// + ///// Initializes a new instance of the class that wraps the + ///// specified type. + ///// + ///// Name of the assembly + ///// fully qualified name + ///// An array of objects representing the number, order, and type of the parameters for the constructor to get + ///// Argmenets to pass to the constructor + //public PrivateObject(string assemblyName, string typeName, Type[] parameterTypes, object[] args) + // : this(Type.GetType(string.Format(CultureInfo.InvariantCulture, "{0}, {1}", typeName, assemblyName), false), parameterTypes, args) + //{ + // Helper.CheckParameterNotNull(assemblyName, "assemblyName", string.Empty); + // Helper.CheckParameterNotNull(typeName, "typeName", string.Empty); + //} + + ///// + ///// Initializes a new instance of the class that wraps the + ///// specified type. 
+ ///// + ///// type of the object to create + ///// Argmenets to pass to the constructor + //public PrivateObject(Type type, params object[] args) + // : this(type, null, args) + //{ + // Helper.CheckParameterNotNull(type, "type", string.Empty); + //} + + ///// + ///// Initializes a new instance of the class that wraps the + ///// specified type. + ///// + ///// type of the object to create + ///// An array of objects representing the number, order, and type of the parameters for the constructor to get + ///// Argmenets to pass to the constructor + //public PrivateObject(Type type, Type[] parameterTypes, object[] args) + //{ + // Helper.CheckParameterNotNull(type, "type", string.Empty); + // object o; + // if (parameterTypes != null) + // { + // ConstructorInfo ci = type.GetConstructor(BindToEveryThing, null, parameterTypes, null); + // if (ci == null) + // { + // throw new ArgumentException(FrameworkMessages.PrivateAccessorConstructorNotFound); + // } + + // try + // { + // o = ci.Invoke(args); + // } + // catch (TargetInvocationException e) + // { + // Debug.Assert(e.InnerException != null, "Inner exception should not be null."); + // if (e.InnerException != null) + // { + // throw e.InnerException; + // } + + // throw; + // } + // } + // else + // { + // o = Activator.CreateInstance(type, constructorFlags, null, args, null); + // } + + // this.ConstructFrom(o); + //} + + /// + /// Initializes a new instance of the class that wraps + /// the given object. + /// + /// object to wrap + //[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an object, so 'obj' seems reasonable")] + public PrivateObject(object obj) + { + Helper.CheckParameterNotNull(obj, "obj", string.Empty); + this.ConstructFrom(obj); + } + + /// + /// Initializes a new instance of the class that wraps + /// the given object. + /// + /// object to wrap + /// PrivateType object + //[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an an object, so 'obj' seems reasonable")] + public PrivateObject(object obj, PrivateType type) + { + Helper.CheckParameterNotNull(type, "type", string.Empty); + this.target = obj; + this.originalType = type.ReferencedType; + } + + #endregion + + ///// + ///// Gets or sets the target + ///// + //public object Target + //{ + // get + // { + // return this.target; + // } + + // set + // { + // Helper.CheckParameterNotNull(value, "Target", string.Empty); + // this.target = value; + // this.originalType = value.GetType(); + // } + //} + + ///// + ///// Gets the type of underlying object + ///// + //public Type RealType + //{ + // get + // { + // return this.originalType; + // } + //} + + //private Dictionary> GenericMethodCache + //{ + // get + // { + // if (this.methodCache == null) + // { + // this.BuildGenericMethodCacheForType(this.originalType); + // } + + // Debug.Assert(this.methodCache != null, "Invalid method cache for type."); + + // return this.methodCache; + // } + //} + + /// + /// returns the hash code of the target object + /// + /// int representing hashcode of the target object + public override int GetHashCode() + { + //Debug.Assert(this.target != null, "target should not be null."); + return this.target.GetHashCode(); + } + + /// + /// Equals + /// + /// Object with whom to compare + /// returns true if the objects are equal. 
+ public override bool Equals(object obj) + { + if (this != obj) + { + //Debug.Assert(this.target != null, "target should not be null."); + if (typeof(PrivateObject) == obj?.GetType()) + { + return this.target.Equals(((PrivateObject) obj).target); + } else + { + return false; + } + } + + return true; + } + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// Arguments to pass to the member to invoke. + ///// Result of method call + //public object Invoke(string name, params object[] args) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // return this.Invoke(name, null, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// An array of objects representing the number, order, and type of the parameters for the method to get. + ///// Arguments to pass to the member to invoke. + ///// Result of method call + //public object Invoke(string name, Type[] parameterTypes, object[] args) + //{ + // return this.Invoke(name, parameterTypes, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// An array of objects representing the number, order, and type of the parameters for the method to get. + ///// Arguments to pass to the member to invoke. + ///// An array of types corresponding to the types of the generic arguments. + ///// Result of method call + //public object Invoke(string name, Type[] parameterTypes, object[] args, Type[] typeArguments) + //{ + // return this.Invoke(name, BindToEveryThing, parameterTypes, args, CultureInfo.InvariantCulture, typeArguments); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// Arguments to pass to the member to invoke. + ///// Culture info + ///// Result of method call + //public object Invoke(string name, object[] args, CultureInfo culture) + //{ + // return this.Invoke(name, null, args, culture); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// An array of objects representing the number, order, and type of the parameters for the method to get. + ///// Arguments to pass to the member to invoke. + ///// Culture info + ///// Result of method call + //public object Invoke(string name, Type[] parameterTypes, object[] args, CultureInfo culture) + //{ + // return this.Invoke(name, BindToEveryThing, parameterTypes, args, culture); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// Arguments to pass to the member to invoke. + ///// Result of method call + //public object Invoke(string name, BindingFlags bindingFlags, params object[] args) + //{ + // return this.Invoke(name, bindingFlags, null, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// An array of objects representing the number, order, and type of the parameters for the method to get. + ///// Arguments to pass to the member to invoke. 
+ ///// Result of method call + //public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args) + //{ + // return this.Invoke(name, bindingFlags, parameterTypes, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// Arguments to pass to the member to invoke. + ///// Culture info + ///// Result of method call + //public object Invoke(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture) + //{ + // return this.Invoke(name, bindingFlags, null, args, culture); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// An array of objects representing the number, order, and type of the parameters for the method to get. + ///// Arguments to pass to the member to invoke. + ///// Culture info + ///// Result of method call + //public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture) + //{ + // return this.Invoke(name, bindingFlags, parameterTypes, args, culture, null); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// An array of objects representing the number, order, and type of the parameters for the method to get. + ///// Arguments to pass to the member to invoke. + ///// Culture info + ///// An array of types corresponding to the types of the generic arguments. + ///// Result of method call + //public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture, Type[] typeArguments) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // if (parameterTypes != null) + // { + // bindingFlags |= BindToEveryThing | BindingFlags.Instance; + + // // Fix up the parameter types + // MethodInfo member = this.originalType.GetMethod(name, bindingFlags, null, parameterTypes, null); + + // // If the method was not found and type arguments were provided for generic paramaters, + // // attempt to look up a generic method. + // if ((member == null) && (typeArguments != null)) + // { + // // This method may contain generic parameters...if so, the previous call to + // // GetMethod() will fail because it doesn't fully support generic parameters. + + // // Look in the method cache to see if there is a generic method + // // on the incoming type that contains the correct signature. 
+ // member = this.GetGenericMethodFromCache(name, parameterTypes, typeArguments, bindingFlags, null); + // } + + // if (member == null) + // { + // throw new ArgumentException( + // string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); + // } + + // try + // { + // if (member.IsGenericMethodDefinition) + // { + // MethodInfo constructed = member.MakeGenericMethod(typeArguments); + // return constructed.Invoke(this.target, bindingFlags, null, args, culture); + // } + // else + // { + // return member.Invoke(this.target, bindingFlags, null, args, culture); + // } + // } + // catch (TargetInvocationException e) + // { + // Debug.Assert(e.InnerException != null, "Inner exception should not be null."); + // if (e.InnerException != null) + // { + // throw e.InnerException; + // } + + // throw; + // } + // } + // else + // { + // return this.InvokeHelper(name, bindingFlags | BindingFlags.InvokeMethod, args, culture); + // } + //} + + ///// + ///// Gets the array element using array of subsrcipts for each dimension + ///// + ///// Name of the member + ///// the indices of array + ///// An arrya of elements. + //public object GetArrayElement(string name, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // return this.GetArrayElement(name, BindToEveryThing, indices); + //} + + ///// + ///// Sets the array element using array of subsrcipts for each dimension + ///// + ///// Name of the member + ///// Value to set + ///// the indices of array + //public void SetArrayElement(string name, object value, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // this.SetArrayElement(name, BindToEveryThing, value, indices); + //} + + ///// + ///// Gets the array element using array of subsrcipts for each dimension + ///// + ///// Name of the member + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// the indices of array + ///// An arrya of elements. + //public object GetArrayElement(string name, BindingFlags bindingFlags, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // Array arr = (Array)this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture); + // return arr.GetValue(indices); + //} + + ///// + ///// Sets the array element using array of subsrcipts for each dimension + ///// + ///// Name of the member + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// Value to set + ///// the indices of array + //public void SetArrayElement(string name, BindingFlags bindingFlags, object value, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // Array arr = (Array)this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture); + // arr.SetValue(value, indices); + //} + + ///// + ///// Get the field + ///// + ///// Name of the field + ///// The field. 
+ //public object GetField(string name) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // return this.GetField(name, BindToEveryThing); + //} + + ///// + ///// Sets the field + ///// + ///// Name of the field + ///// value to set + //public void SetField(string name, object value) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // this.SetField(name, BindToEveryThing, value); + //} + + ///// + ///// Gets the field + ///// + ///// Name of the field + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// The field. + //public object GetField(string name, BindingFlags bindingFlags) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // return this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture); + //} + + ///// + ///// Sets the field + ///// + ///// Name of the field + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// value to set + //public void SetField(string name, BindingFlags bindingFlags, object value) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // this.InvokeHelper(name, BindingFlags.SetField | bindingFlags, new object[] { value }, CultureInfo.InvariantCulture); + //} + + /// + /// Get the field or property + /// + /// Name of the field or property + /// The field or property. + public object GetFieldOrProperty(string name) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + return this.GetFieldOrProperty(name, BindToEveryThing); + } + + /// + /// Sets the field or property + /// + /// Name of the field or property + /// value to set + public void SetFieldOrProperty(string name, object value) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + this.SetFieldOrProperty(name, BindToEveryThing, value); + } + + /// + /// Gets the field or property + /// + /// Name of the field or property + /// A bitmask comprised of one or more that specify how the search is conducted. + /// The field or property. + public object GetFieldOrProperty(string name, BindingFlags bindingFlags) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + return this.InvokeHelper(name, BindingFlags.GetField | BindingFlags.GetProperty | bindingFlags, null, CultureInfo.InvariantCulture); + } + + /// + /// Sets the field or property + /// + /// Name of the field or property + /// A bitmask comprised of one or more that specify how the search is conducted. + /// value to set + public void SetFieldOrProperty(string name, BindingFlags bindingFlags, object value) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + this.InvokeHelper(name, BindingFlags.SetField | BindingFlags.SetProperty | bindingFlags, new object[] {value}, CultureInfo.InvariantCulture); + } + + ///// + ///// Gets the property + ///// + ///// Name of the property + ///// Arguments to pass to the member to invoke. + ///// The property. + //public object GetProperty(string name, params object[] args) + //{ + // return this.GetProperty(name, null, args); + //} + + ///// + ///// Gets the property + ///// + ///// Name of the property + ///// An array of objects representing the number, order, and type of the parameters for the indexed property. + ///// Arguments to pass to the member to invoke. + ///// The property. 
+ //public object GetProperty(string name, Type[] parameterTypes, object[] args) + //{ + // return this.GetProperty(name, BindToEveryThing, parameterTypes, args); + //} + + ///// + ///// Set the property + ///// + ///// Name of the property + ///// value to set + ///// Arguments to pass to the member to invoke. + //public void SetProperty(string name, object value, params object[] args) + //{ + // this.SetProperty(name, null, value, args); + //} + + ///// + ///// Set the property + ///// + ///// Name of the property + ///// An array of objects representing the number, order, and type of the parameters for the indexed property. + ///// value to set + ///// Arguments to pass to the member to invoke. + //public void SetProperty(string name, Type[] parameterTypes, object value, object[] args) + //{ + // this.SetProperty(name, BindToEveryThing, value, parameterTypes, args); + //} + + ///// + ///// Gets the property + ///// + ///// Name of the property + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// Arguments to pass to the member to invoke. + ///// The property. + //public object GetProperty(string name, BindingFlags bindingFlags, params object[] args) + //{ + // return this.GetProperty(name, bindingFlags, null, args); + //} + + ///// + ///// Gets the property + ///// + ///// Name of the property + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// An array of objects representing the number, order, and type of the parameters for the indexed property. + ///// Arguments to pass to the member to invoke. + ///// The property. + //public object GetProperty(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // if (parameterTypes != null) + // { + // PropertyInfo pi = this.originalType.GetProperty(name, bindingFlags, null, null, parameterTypes, null); + // if (pi == null) + // { + // throw new ArgumentException( + // string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); + // } + + // return pi.GetValue(this.target, args); + // } + // else + // { + // return this.InvokeHelper(name, bindingFlags | BindingFlags.GetProperty, args, null); + // } + //} + + ///// + ///// Sets the property + ///// + ///// Name of the property + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// value to set + ///// Arguments to pass to the member to invoke. + //public void SetProperty(string name, BindingFlags bindingFlags, object value, params object[] args) + //{ + // this.SetProperty(name, bindingFlags, value, null, args); + //} + + ///// + ///// Sets the property + ///// + ///// Name of the property + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// value to set + ///// An array of objects representing the number, order, and type of the parameters for the indexed property. + ///// Arguments to pass to the member to invoke. 
+ //public void SetProperty(string name, BindingFlags bindingFlags, object value, Type[] parameterTypes, object[] args) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + + // if (parameterTypes != null) + // { + // PropertyInfo pi = this.originalType.GetProperty(name, bindingFlags, null, null, parameterTypes, null); + // if (pi == null) + // { + // throw new ArgumentException( + // string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); + // } + + // pi.SetValue(this.target, value, args); + // } + // else + // { + // object[] pass = new object[(args?.Length ?? 0) + 1]; + // pass[0] = value; + // args?.CopyTo(pass, 1); + // this.InvokeHelper(name, bindingFlags | BindingFlags.SetProperty, pass, null); + // } + //} + + #region Private Helpers + + ///// + ///// Validate access string + ///// + ///// access string + //private static void ValidateAccessString(string access) + //{ + // Helper.CheckParameterNotNull(access, "access", string.Empty); + // if (access.Length == 0) + // { + // throw new ArgumentException(FrameworkMessages.AccessStringInvalidSyntax); + // } + + // string[] arr = access.Split('.'); + // foreach (string str in arr) + // { + // if ((str.Length == 0) || (str.IndexOfAny(new char[] { ' ', '\t', '\n' }) != -1)) + // { + // throw new ArgumentException(FrameworkMessages.AccessStringInvalidSyntax); + // } + // } + //} + + /// + /// Invokes the memeber + /// + /// Name of the member + /// Additional attributes + /// Arguments for the invocation + /// Culture + /// Result of the invocation + private object InvokeHelper(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + //Debug.Assert(this.target != null, "Internal Error: Null reference is returned for internal object"); + + // Invoke the actual Method + try + { + return this.originalType.InvokeMember(name, bindingFlags, null, this.target, args, culture); + } catch (TargetInvocationException e) + { + //Debug.Assert(e.InnerException != null, "Inner exception should not be null."); + if (e.InnerException != null) + { + throw e.InnerException; + } + + throw; + } + } + + private void ConstructFrom(object obj) + { + Helper.CheckParameterNotNull(obj, "obj", string.Empty); + this.target = obj; + this.originalType = obj.GetType(); + } + + //private void BuildGenericMethodCacheForType(Type t) + //{ + // Debug.Assert(t != null, "type should not be null."); + // this.methodCache = new Dictionary>(); + + // MethodInfo[] members = t.GetMethods(BindToEveryThing); + // LinkedList listByName; // automatically initialized to null + + // foreach (MethodInfo member in members) + // { + // if (member.IsGenericMethod || member.IsGenericMethodDefinition) + // { + // if (!this.GenericMethodCache.TryGetValue(member.Name, out listByName)) + // { + // listByName = new LinkedList(); + // this.GenericMethodCache.Add(member.Name, listByName); + // } + + // Debug.Assert(listByName != null, "list should not be null."); + // listByName.AddLast(member); + // } + // } + //} + + ///// + ///// Extracts the most appropriate generic method signature from the current private type. + ///// + ///// The name of the method in which to search the signature cache. + ///// An array of types corresponding to the types of the parameters in which to search. + ///// An array of types corresponding to the types of the generic arguments. + ///// to further filter the method signatures. + ///// Modifiers for parameters. 
+ ///// A methodinfo instance. + //private MethodInfo GetGenericMethodFromCache(string methodName, Type[] parameterTypes, Type[] typeArguments, BindingFlags bindingFlags, ParameterModifier[] modifiers) + //{ + // Debug.Assert(!string.IsNullOrEmpty(methodName), "Invalid method name."); + // Debug.Assert(parameterTypes != null, "Invalid parameter type array."); + // Debug.Assert(typeArguments != null, "Invalid type arguments array."); + + // // Build a preliminary list of method candidates that contain roughly the same signature. + // var methodCandidates = this.GetMethodCandidates(methodName, parameterTypes, typeArguments, bindingFlags, modifiers); + + // // Search of ambiguous methods (methods with the same signature). + // MethodInfo[] finalCandidates = new MethodInfo[methodCandidates.Count]; + // methodCandidates.CopyTo(finalCandidates, 0); + + // if ((parameterTypes != null) && (parameterTypes.Length == 0)) + // { + // for (int i = 0; i < finalCandidates.Length; i++) + // { + // MethodInfo methodInfo = finalCandidates[i]; + + // if (!RuntimeTypeHelper.CompareMethodSigAndName(methodInfo, finalCandidates[0])) + // { + // throw new AmbiguousMatchException(); + // } + // } + + // // All the methods have the exact same name and sig so return the most derived one. + // return RuntimeTypeHelper.FindMostDerivedNewSlotMeth(finalCandidates, finalCandidates.Length) as MethodInfo; + // } + + // // Now that we have a preliminary list of candidates, select the most appropriate one. + // return RuntimeTypeHelper.SelectMethod(bindingFlags, finalCandidates, parameterTypes, modifiers) as MethodInfo; + //} + + //private LinkedList GetMethodCandidates(string methodName, Type[] parameterTypes, Type[] typeArguments, BindingFlags bindingFlags, ParameterModifier[] modifiers) + //{ + // Debug.Assert(!string.IsNullOrEmpty(methodName), "methodName should not be null."); + // Debug.Assert(parameterTypes != null, "parameterTypes should not be null."); + // Debug.Assert(typeArguments != null, "typeArguments should not be null."); + + // LinkedList methodCandidates = new LinkedList(); + // LinkedList methods = null; + + // if (!this.GenericMethodCache.TryGetValue(methodName, out methods)) + // { + // return methodCandidates; + // } + + // Debug.Assert(methods != null, "methods should not be null."); + + // foreach (MethodInfo candidate in methods) + // { + // bool paramMatch = true; + // ParameterInfo[] candidateParams = null; + // Type[] genericArgs = candidate.GetGenericArguments(); + // Type sourceParameterType = null; + + // if (genericArgs.Length != typeArguments.Length) + // { + // continue; + // } + + // // Since we can't just get the correct MethodInfo from Reflection, + // // we will just match the number of parameters, their order, and their type + // var methodCandidate = candidate; + // candidateParams = methodCandidate.GetParameters(); + + // if (candidateParams.Length != parameterTypes.Length) + // { + // continue; + // } + + // // Exact binding + // if ((bindingFlags & BindingFlags.ExactBinding) != 0) + // { + // int i = 0; + + // foreach (ParameterInfo candidateParam in candidateParams) + // { + // sourceParameterType = parameterTypes[i++]; + + // if (candidateParam.ParameterType.ContainsGenericParameters) + // { + // // Since we have a generic parameter here, just make sure the IsArray matches. 
+ // if (candidateParam.ParameterType.IsArray != sourceParameterType.IsArray) + // { + // paramMatch = false; + // break; + // } + // } + // else + // { + // if (candidateParam.ParameterType != sourceParameterType) + // { + // paramMatch = false; + // break; + // } + // } + // } + + // if (paramMatch) + // { + // methodCandidates.AddLast(methodCandidate); + // continue; + // } + // } + // else + // { + // methodCandidates.AddLast(methodCandidate); + // } + // } + + // return methodCandidates; + //} + + #endregion + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Utilities/PrivateObjectExtensions.cs b/test/TensorFlowNET.UnitTest/Utilities/PrivateObjectExtensions.cs new file mode 100644 index 00000000..f40cc727 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Utilities/PrivateObjectExtensions.cs @@ -0,0 +1,314 @@ +// +// Copyright (c) 2019 cactuaroid All Rights Reserved +// +// +// Released under the MIT license +// https://github.com/cactuaroid/PrivateObjectExtensions +// + +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Linq; +using System.Reflection; + +namespace System +{ + /// + /// Extension methods for PrivateObject + /// + public static class PrivateObjectExtensions + { + private static readonly BindingFlags Static = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly | BindingFlags.Static; + private static readonly BindingFlags Instance = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly | BindingFlags.Instance; + + /// + /// Get from private (and any other) field/property. + /// If the real type of specified object doesn't contain the specified field/property, + /// base types are searched automatically. + /// + /// The object to get from + /// The name of the field/property + /// The object got from the field/property + /// 'name' is not found. + /// Arguments contain null. + public static object GetPrivate(this object obj, string name) + { + if (obj == null) { throw new ArgumentNullException("obj"); } + + return GetPrivate(obj, name, obj.GetType(), null); + } + + /// + /// Get from private (and any other) field/property. + /// If the real type of specified object doesn't contain the specified field/property, + /// base types are searched automatically. + /// + /// The type of the field/property + /// The object to get from + /// The name of the field/property + /// The object got from the field/property + /// 'name' is not found. + /// Arguments contain null. + public static T GetPrivate(this object obj, string name) + { + if (obj == null) { throw new ArgumentNullException("obj"); } + + return (T)GetPrivate(obj, name, obj.GetType(), typeof(T)); + } + + /// + /// Get from private (and any other) field/property with assuming the specified object as specified type. + /// If the specified type doesn't contain the specified field/property, + /// base types are searched automatically. + /// + /// The object to get from + /// The name of the field/property + /// The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored. + /// The object got from the field/property + /// 'name' is not found. + /// 'objType' is not assignable from 'obj'. + /// Arguments contain null. + public static object GetPrivate(this object obj, string name, Type objType) + { + return GetPrivate(obj, name, objType, null); + } + + /// + /// Get from private (and any other) field/property with assuming the specified object as specified type. 
+ /// If the specified type doesn't contain the specified field/property, + /// base types are searched automatically. + /// + /// The type of the field/property + /// The object to get from + /// The name of the field/property + /// The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored. + /// The object got from the field/property + /// 'name' is not found. + /// 'objType' is not assignable from 'obj'. + /// Arguments contain null. + public static T GetPrivate(this object obj, string name, Type objType) + { + return (T)GetPrivate(obj, name, objType, typeof(T)); + } + + private static object GetPrivate(object obj, string name, Type objType, Type memberType) + { + if (obj == null) { throw new ArgumentNullException("obj"); } + if (name == null) { throw new ArgumentNullException("name"); } + if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); } + if (objType == null) { throw new ArgumentNullException("objType"); } + if (!objType.IsAssignableFrom(obj.GetType())) { throw new ArgumentException($"{objType} is not assignable from {obj.GetType()}.", "objType"); } + + bool memberTypeMatching(Type actualType) => actualType == memberType; + + if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Instance, out var ownerType)) + { + return new PrivateObject(obj, new PrivateType(ownerType)).GetFieldOrProperty(name); + } + else if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Static, out ownerType)) + { + return new PrivateType(ownerType).GetStaticFieldOrProperty(name); + } + + throw new ArgumentException(((memberType != null) ? memberType + " " : "") + name + " is not found."); + } + + /// + /// Get from private (and any other) static field/property. + /// + /// The type to get from + /// The name of the static field/property + /// The object got from the static field/property + /// 'name' is not found. + /// Arguments contain null. + public static object GetPrivate(this Type type, string name) + { + return GetPrivate(type, name, null); + } + + /// + /// Get from private (and any other) static field/property. + /// + /// The type of the field/property + /// The type to get from + /// The name of the static field/property + /// The object got from the static field/property + /// 'name' is not found. + /// Arguments contain null. + public static T GetPrivate(this Type type, string name) + { + return (T)GetPrivate(type, name, typeof(T)); + } + + private static object GetPrivate(this Type type, string name, Type memberType) + { + if (type == null) { throw new ArgumentNullException("type"); } + if (name == null) { throw new ArgumentNullException("name"); } + if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); } + + bool memberTypeMatching(Type actualType) => actualType == memberType; + + if (type.ContainsFieldOrProperty(name, memberType, memberTypeMatching, Static)) + { + return new PrivateType(type).GetStaticFieldOrProperty(name); + } + + throw new ArgumentException(((memberType != null) ? memberType + " " : "") + name + " is not found."); + } + + /// + /// Set to private (and any other) field/property. + /// If the real type of specified object doesn't contain the specified field/property, + /// base types are searched automatically. + /// + /// The object to set to + /// The name of the field/property + /// The value to set for 'name' + /// 'name' is not found. + /// Arguments contain null. 
+ public static void SetPrivate(this object obj, string name, T value) + { + if (obj == null) { throw new ArgumentNullException("obj"); } + + SetPrivate(obj, name, value, obj.GetType()); + } + + /// + /// Set to private (and any other) field/property with assuming the specified object as specified type. + /// If the specified type doesn't contain the specified field/property, + /// base types are searched automatically. + /// + /// The object to set to + /// The name of the field/property + /// The value to set for 'name' + /// The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored. + /// 'name' is not found. + /// 'objType' is not assignable from 'obj'. + /// Arguments contain null. + public static void SetPrivate(this object obj, string name, T value, Type objType) + { + if (obj == null) { throw new ArgumentNullException("obj"); } + if (name == null) { throw new ArgumentNullException("name"); } + if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); } + if (value == null) { throw new ArgumentNullException("value"); } + if (objType == null) { throw new ArgumentNullException("objType"); } + if (!objType.IsAssignableFrom(obj.GetType())) { throw new ArgumentException($"{objType} is not assignable from {obj.GetType()}.", "objType"); } + + if (TrySetPrivate(obj, name, value, objType)) { return; } + + // retry for the case of getter only property + if (TrySetPrivate(obj, GetBackingFieldName(name), value, objType)) { return; } + + throw new ArgumentException($"{typeof(T)} {name} is not found."); + } + + private static bool TrySetPrivate(object obj, string name, T value, Type objType) + { + var memberType = typeof(T); + bool memberTypeMatching(Type actualType) => actualType.IsAssignableFrom(memberType); + + try + { + if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Instance, out var ownerType)) + { + new PrivateObject(obj, new PrivateType(ownerType)).SetFieldOrProperty(name, value); + return true; + } + else if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Static, out ownerType)) + { + new PrivateType(ownerType).SetStaticFieldOrProperty(name, value); + return true; + } + } + catch(MissingMethodException) + { + // When getter only property name is given, the property is found but fails to set. + return false; + } + + return false; + } + + /// + /// Set to private (and any other) static field/property. + /// + /// The type to set to + /// The name of the field/property + /// The value to set for 'name' + /// 'name' is not found. + /// Arguments contain null. 
+ public static void SetPrivate(this Type type, string name, T value) + { + if (type == null) { throw new ArgumentNullException("type"); } + if (name == null) { throw new ArgumentNullException("name"); } + if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); } + + if (TrySetPrivate(type, name, value)) { return; } + + // retry for the case of getter only property + if (TrySetPrivate(type, GetBackingFieldName(name), value)) { return; } + + throw new ArgumentException($"{typeof(T)} {name} is not found."); + } + + private static bool TrySetPrivate(this Type type, string name, T value) + { + var memberType = typeof(T); + bool memberTypeMatching(Type actualType) => actualType.IsAssignableFrom(memberType); + + try + { + if (type.ContainsFieldOrProperty(name, memberType, memberTypeMatching, Static)) + { + new PrivateType(type).SetStaticFieldOrProperty(name, value); + return true; + } + } + catch (MissingMethodException) + { + // When getter only property name is given, the property is found but fails to set. + return false; + } + + return false; + } + + private static string GetBackingFieldName(string propertyName) + => $"<{propertyName}>k__BackingField"; // generated backing field name + + private static bool TryFindFieldOrPropertyOwnerType(Type objType, string name, Type memberType, Func memberTypeMatching, BindingFlags bindingFlag, out Type ownerType) + { + ownerType = FindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, bindingFlag); + + return (ownerType != null); + } + + private static Type FindFieldOrPropertyOwnerType(Type objectType, string name, Type memberType, Func memberTypeMatching, BindingFlags bindingFlags) + { + if (objectType == null) { return null; } + + if (objectType.ContainsFieldOrProperty(name, memberType, memberTypeMatching, bindingFlags)) + { + return objectType; + } + + return FindFieldOrPropertyOwnerType(objectType.BaseType, name, memberType, memberTypeMatching, bindingFlags); + } + + private static bool ContainsFieldOrProperty(this Type objectType, string name, Type memberType, Func memberTypeMatching, BindingFlags bindingFlags) + { + var fields = objectType + .GetFields(bindingFlags) + .Select((x) => new { Type = x.FieldType, Member = x as MemberInfo }); + + var properties = objectType + .GetProperties(bindingFlags) + .Select((x) => new { Type = x.PropertyType, Member = x as MemberInfo }); + + var members = fields.Concat(properties); + + return members.Any((actual) => + (memberType == null || memberTypeMatching.Invoke(actual.Type)) + && actual.Member.Name == name); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Utilities/PrivateType.cs b/test/TensorFlowNET.UnitTest/Utilities/PrivateType.cs new file mode 100644 index 00000000..a2d0b3c3 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Utilities/PrivateType.cs @@ -0,0 +1,572 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.VisualStudio.TestTools.UnitTesting +{ + using System; + //using System.Diagnostics; + using System.Globalization; + using System.Reflection; + + /// + /// This class represents a private class for the Private Accessor functionality. 
+ /// + internal class PrivateType + { + /// + /// Binds to everything + /// + private const BindingFlags BindToEveryThing = BindingFlags.Default + | BindingFlags.NonPublic | BindingFlags.Instance + | BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy; + + /// + /// The wrapped type. + /// + private Type type; + + ///// + ///// Initializes a new instance of the class that contains the private type. + ///// + ///// Assembly name + ///// fully qualified name of the + //public PrivateType(string assemblyName, string typeName) + //{ + // Helper.CheckParameterNotNullOrEmpty(assemblyName, "assemblyName", string.Empty); + // Helper.CheckParameterNotNullOrEmpty(typeName, "typeName", string.Empty); + // Assembly asm = Assembly.Load(assemblyName); + + // this.type = asm.GetType(typeName, true); + //} + + /// + /// Initializes a new instance of the class that contains + /// the private type from the type object + /// + /// The wrapped Type to create. + public PrivateType(Type type) + { + if (type == null) + { + throw new ArgumentNullException("type"); + } + + this.type = type; + } + + /// + /// Gets the referenced type + /// + public Type ReferencedType => this.type; + + ///// + ///// Invokes static member + ///// + ///// Name of the member to InvokeHelper + ///// Arguements to the invoction + ///// Result of invocation + //public object InvokeStatic(string name, params object[] args) + //{ + // return this.InvokeStatic(name, null, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes static member + ///// + ///// Name of the member to InvokeHelper + ///// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invoction + ///// Result of invocation + //public object InvokeStatic(string name, Type[] parameterTypes, object[] args) + //{ + // return this.InvokeStatic(name, parameterTypes, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes static member + ///// + ///// Name of the member to InvokeHelper + ///// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invoction + ///// An array of types corresponding to the types of the generic arguments. 
+ ///// Result of invocation + //public object InvokeStatic(string name, Type[] parameterTypes, object[] args, Type[] typeArguments) + //{ + // return this.InvokeStatic(name, BindToEveryThing, parameterTypes, args, CultureInfo.InvariantCulture, typeArguments); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Arguements to the invocation + ///// Culture + ///// Result of invocation + //public object InvokeStatic(string name, object[] args, CultureInfo culture) + //{ + // return this.InvokeStatic(name, null, args, culture); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invocation + ///// Culture info + ///// Result of invocation + //public object InvokeStatic(string name, Type[] parameterTypes, object[] args, CultureInfo culture) + //{ + // return this.InvokeStatic(name, BindingFlags.InvokeMethod, parameterTypes, args, culture); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Additional invocation attributes + ///// Arguements to the invocation + ///// Result of invocation + //public object InvokeStatic(string name, BindingFlags bindingFlags, params object[] args) + //{ + // return this.InvokeStatic(name, bindingFlags, null, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Additional invocation attributes + ///// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invocation + ///// Result of invocation + //public object InvokeStatic(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args) + //{ + // return this.InvokeStatic(name, bindingFlags, parameterTypes, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Additional invocation attributes + ///// Arguements to the invocation + ///// Culture + ///// Result of invocation + //public object InvokeStatic(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture) + //{ + // return this.InvokeStatic(name, bindingFlags, null, args, culture); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Additional invocation attributes + ///// /// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invocation + ///// Culture + ///// Result of invocation + //public object InvokeStatic(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture) + //{ + // return this.InvokeStatic(name, bindingFlags, parameterTypes, args, culture, null); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Additional invocation attributes + ///// /// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invocation + ///// Culture + ///// An array of types corresponding to the types of the generic arguments. 
+ ///// Result of invocation + //public object InvokeStatic(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture, Type[] typeArguments) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // if (parameterTypes != null) + // { + // MethodInfo member = this.type.GetMethod(name, bindingFlags | BindToEveryThing | BindingFlags.Static, null, parameterTypes, null); + // if (member == null) + // { + // throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); + // } + + // try + // { + // if (member.IsGenericMethodDefinition) + // { + // MethodInfo constructed = member.MakeGenericMethod(typeArguments); + // return constructed.Invoke(null, bindingFlags, null, args, culture); + // } + // else + // { + // return member.Invoke(null, bindingFlags, null, args, culture); + // } + // } + // catch (TargetInvocationException e) + // { + // Debug.Assert(e.InnerException != null, "Inner Exception should not be null."); + // if (e.InnerException != null) + // { + // throw e.InnerException; + // } + + // throw; + // } + // } + // else + // { + // return this.InvokeHelperStatic(name, bindingFlags | BindingFlags.InvokeMethod, args, culture); + // } + //} + + ///// + ///// Gets the element in static array + ///// + ///// Name of the array + ///// + ///// A one-dimensional array of 32-bit integers that represent the indexes specifying + ///// the position of the element to get. For instance, to access a[10][11] the indices would be {10,11} + ///// + ///// element at the specified location + //public object GetStaticArrayElement(string name, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // return this.GetStaticArrayElement(name, BindToEveryThing, indices); + //} + + ///// + ///// Sets the memeber of the static array + ///// + ///// Name of the array + ///// value to set + ///// + ///// A one-dimensional array of 32-bit integers that represent the indexes specifying + ///// the position of the element to set. For instance, to access a[10][11] the array would be {10,11} + ///// + //public void SetStaticArrayElement(string name, object value, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // this.SetStaticArrayElement(name, BindToEveryThing, value, indices); + //} + + ///// + ///// Gets the element in satatic array + ///// + ///// Name of the array + ///// Additional InvokeHelper attributes + ///// + ///// A one-dimensional array of 32-bit integers that represent the indexes specifying + ///// the position of the element to get. For instance, to access a[10][11] the array would be {10,11} + ///// + ///// element at the spcified location + //public object GetStaticArrayElement(string name, BindingFlags bindingFlags, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // Array arr = (Array)this.InvokeHelperStatic(name, BindingFlags.GetField | BindingFlags.GetProperty | bindingFlags, null, CultureInfo.InvariantCulture); + // return arr.GetValue(indices); + //} + + ///// + ///// Sets the memeber of the static array + ///// + ///// Name of the array + ///// Additional InvokeHelper attributes + ///// value to set + ///// + ///// A one-dimensional array of 32-bit integers that represent the indexes specifying + ///// the position of the element to set. 
For instance, to access a[10][11] the array would be {10,11} + ///// + //public void SetStaticArrayElement(string name, BindingFlags bindingFlags, object value, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // Array arr = (Array)this.InvokeHelperStatic(name, BindingFlags.GetField | BindingFlags.GetProperty | BindingFlags.Static | bindingFlags, null, CultureInfo.InvariantCulture); + // arr.SetValue(value, indices); + //} + + ///// + ///// Gets the static field + ///// + ///// Name of the field + ///// The static field. + //public object GetStaticField(string name) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // return this.GetStaticField(name, BindToEveryThing); + //} + + ///// + ///// Sets the static field + ///// + ///// Name of the field + ///// Arguement to the invocation + //public void SetStaticField(string name, object value) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // this.SetStaticField(name, BindToEveryThing, value); + //} + + ///// + ///// Gets the static field using specified InvokeHelper attributes + ///// + ///// Name of the field + ///// Additional invocation attributes + ///// The static field. + //public object GetStaticField(string name, BindingFlags bindingFlags) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // return this.InvokeHelperStatic(name, BindingFlags.GetField | BindingFlags.Static | bindingFlags, null, CultureInfo.InvariantCulture); + //} + + ///// + ///// Sets the static field using binding attributes + ///// + ///// Name of the field + ///// Additional InvokeHelper attributes + ///// Arguement to the invocation + //public void SetStaticField(string name, BindingFlags bindingFlags, object value) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // this.InvokeHelperStatic(name, BindingFlags.SetField | bindingFlags | BindingFlags.Static, new[] { value }, CultureInfo.InvariantCulture); + //} + + /// + /// Gets the static field or property + /// + /// Name of the field or property + /// The static field or property. + public object GetStaticFieldOrProperty(string name) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + return this.GetStaticFieldOrProperty(name, BindToEveryThing); + } + + /// + /// Sets the static field or property + /// + /// Name of the field or property + /// Value to be set to field or property + public void SetStaticFieldOrProperty(string name, object value) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + this.SetStaticFieldOrProperty(name, BindToEveryThing, value); + } + + /// + /// Gets the static field or property using specified InvokeHelper attributes + /// + /// Name of the field or property + /// Additional invocation attributes + /// The static field or property. 
+ public object GetStaticFieldOrProperty(string name, BindingFlags bindingFlags) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + return this.InvokeHelperStatic(name, BindingFlags.GetField | BindingFlags.GetProperty | BindingFlags.Static | bindingFlags, null, CultureInfo.InvariantCulture); + } + + /// + /// Sets the static field or property using binding attributes + /// + /// Name of the field or property + /// Additional invocation attributes + /// Value to be set to field or property + public void SetStaticFieldOrProperty(string name, BindingFlags bindingFlags, object value) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + this.InvokeHelperStatic(name, BindingFlags.SetField | BindingFlags.SetProperty | bindingFlags | BindingFlags.Static, new[] {value}, CultureInfo.InvariantCulture); + } + + ///// + ///// Gets the static property + ///// + ///// Name of the field or property + ///// Arguements to the invocation + ///// The static property. + //public object GetStaticProperty(string name, params object[] args) + //{ + // return this.GetStaticProperty(name, BindToEveryThing, args); + //} + + ///// + ///// Sets the static property + ///// + ///// Name of the property + ///// Value to be set to field or property + ///// Arguments to pass to the member to invoke. + //public void SetStaticProperty(string name, object value, params object[] args) + //{ + // this.SetStaticProperty(name, BindToEveryThing, value, null, args); + //} + + ///// + ///// Sets the static property + ///// + ///// Name of the property + ///// Value to be set to field or property + ///// An array of objects representing the number, order, and type of the parameters for the indexed property. + ///// Arguments to pass to the member to invoke. + //public void SetStaticProperty(string name, object value, Type[] parameterTypes, object[] args) + //{ + // this.SetStaticProperty(name, BindingFlags.SetProperty, value, parameterTypes, args); + //} + + ///// + ///// Gets the static property + ///// + ///// Name of the property + ///// Additional invocation attributes. + ///// Arguments to pass to the member to invoke. + ///// The static property. + //public object GetStaticProperty(string name, BindingFlags bindingFlags, params object[] args) + //{ + // return this.GetStaticProperty(name, BindingFlags.GetProperty | BindingFlags.Static | bindingFlags, null, args); + //} + + ///// + ///// Gets the static property + ///// + ///// Name of the property + ///// Additional invocation attributes. + ///// An array of objects representing the number, order, and type of the parameters for the indexed property. + ///// Arguments to pass to the member to invoke. + ///// The static property. + //public object GetStaticProperty(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // if (parameterTypes != null) + // { + // PropertyInfo pi = this.type.GetProperty(name, bindingFlags | BindingFlags.Static, null, null, parameterTypes, null); + // if (pi == null) + // { + // throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); + // } + + // return pi.GetValue(null, args); + // } + // else + // { + // return this.InvokeHelperStatic(name, bindingFlags | BindingFlags.GetProperty, args, null); + // } + //} + + ///// + ///// Sets the static property + ///// + ///// Name of the property + ///// Additional invocation attributes. 
+ ///// Value to be set to field or property + ///// Optional index values for indexed properties. The indexes of indexed properties are zero-based. This value should be null for non-indexed properties. + //public void SetStaticProperty(string name, BindingFlags bindingFlags, object value, params object[] args) + //{ + // this.SetStaticProperty(name, bindingFlags, value, null, args); + //} + + ///// + ///// Sets the static property + ///// + ///// Name of the property + ///// Additional invocation attributes. + ///// Value to be set to field or property + ///// An array of objects representing the number, order, and type of the parameters for the indexed property. + ///// Arguments to pass to the member to invoke. + //public void SetStaticProperty(string name, BindingFlags bindingFlags, object value, Type[] parameterTypes, object[] args) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + + // if (parameterTypes != null) + // { + // PropertyInfo pi = this.type.GetProperty(name, bindingFlags | BindingFlags.Static, null, null, parameterTypes, null); + // if (pi == null) + // { + // throw new ArgumentException( + // string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); + // } + + // pi.SetValue(null, value, args); + // } + // else + // { + // object[] pass = new object[(args?.Length ?? 0) + 1]; + // pass[0] = value; + // args?.CopyTo(pass, 1); + // this.InvokeHelperStatic(name, bindingFlags | BindingFlags.SetProperty, pass, null); + // } + //} + + /// + /// Invokes the static method + /// + /// Name of the member + /// Additional invocation attributes + /// Arguements to the invocation + /// Culture + /// Result of invocation + private object InvokeHelperStatic(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + try + { + return this.type.InvokeMember(name, bindingFlags | BindToEveryThing | BindingFlags.Static, null, null, args, culture); + } catch (TargetInvocationException e) + { + //Debug.Assert(e.InnerException != null, "Inner Exception should not be null."); + if (e.InnerException != null) + { + throw e.InnerException; + } + + throw; + } + } + } + + /// + /// The helper. + /// + internal static class Helper + { + /// + /// The check parameter not null. + /// + /// + /// The parameter. + /// + /// + /// The parameter name. + /// + /// + /// The message. + /// + /// Throws argument null exception when parameter is null. + internal static void CheckParameterNotNull(object param, string parameterName, string message) + { + if (param == null) + { + throw new ArgumentNullException(parameterName, message); + } + } + + /// + /// The check parameter not null or empty. + /// + /// + /// The parameter. + /// + /// + /// The parameter name. + /// + /// + /// The message. + /// + /// Throws ArgumentException when parameter is null. 
+ //internal static void CheckParameterNotNullOrEmpty(string param, string parameterName, string message) + //{ + // if (string.IsNullOrEmpty(param)) + // { + // throw new ArgumentException(message, parameterName); + // } + //} + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 4d9d1059..e1a91560 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -29,7 +29,7 @@ namespace TensorFlowNET.UnitTest } /// - /// https://www.tf.org/api_docs/python/tf/variable_scope + /// https://www.tensorflow.org/api_docs/python/tf/variable_scope /// how to create a new variable /// [TestMethod] diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs index 627d7c2f..988afa17 100644 --- a/test/TensorFlowNET.UnitTest/c_test_util.cs +++ b/test/TensorFlowNET.UnitTest/c_test_util.cs @@ -12,42 +12,51 @@ namespace TensorFlowNET.UnitTest { public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") { - var desc = c_api.TF_NewOperation(graph, "AddN", name); - - var inputs = new TF_Output[] + lock (Locks.ProcessWide) { - new TF_Output(l, 0), - new TF_Output(r, 0), - }; + var desc = c_api.TF_NewOperation(graph, "AddN", name); - c_api.TF_AddInputList(desc, inputs, inputs.Length); + var inputs = new TF_Output[] + { + new TF_Output(l, 0), + new TF_Output(r, 0), + }; - var op = c_api.TF_FinishOperation(desc, s); - s.Check(); + c_api.TF_AddInputList(desc, inputs, inputs.Length); - return op; + var op = c_api.TF_FinishOperation(desc, s); + s.Check(); + + return op; + } } [SuppressMessage("ReSharper", "RedundantAssignment")] public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) { - using (var buffer = new Buffer()) + lock (Locks.ProcessWide) { - c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); - attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream()); - } + using (var buffer = new Buffer()) + { + c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); + attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + } - return s.Code == TF_Code.TF_OK; + return s.Code == TF_Code.TF_OK; + } } public static GraphDef GetGraphDef(Graph graph) { - using (var s = new Status()) - using (var buffer = new Buffer()) + lock (Locks.ProcessWide) { - c_api.TF_GraphToGraphDef(graph, buffer, s); - s.Check(); - return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + using (var s = new Status()) + using (var buffer = new Buffer()) + { + c_api.TF_GraphToGraphDef(graph, buffer, s); + s.Check(); + return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + } } } @@ -58,6 +67,7 @@ namespace TensorFlowNET.UnitTest { return false; } + bool found_t = false; bool found_n = false; foreach (var attr in node_def.Attr) @@ -67,19 +77,16 @@ namespace TensorFlowNET.UnitTest if (attr.Value.Type == DataType.DtInt32) { found_t = true; - } - else + } else { return false; } - } - else if (attr.Key == "N") + } else if (attr.Key == "N") { if (attr.Value.I == n) { found_n = true; - } - else + } else { return false; } @@ -92,7 +99,7 @@ namespace TensorFlowNET.UnitTest public static bool IsNeg(NodeDef node_def, string input) { return node_def.Op == "Neg" && node_def.Name == "neg" && - node_def.Input.Count == 1 && node_def.Input[0] == input; + node_def.Input.Count == 1 && node_def.Input[0] == input; } public static bool 
IsPlaceholder(NodeDef node_def) @@ -111,13 +118,11 @@ namespace TensorFlowNET.UnitTest if (attr.Value.Type == DataType.DtInt32) { found_dtype = true; - } - else + } else { return false; } - } - else if (attr.Key == "shape") + } else if (attr.Key == "shape") { found_shape = true; } @@ -132,72 +137,82 @@ namespace TensorFlowNET.UnitTest { return false; } + bool found_dtype = false; bool found_value = false; - foreach (var attr in node_def.Attr) { + foreach (var attr in node_def.Attr) + { if (attr.Key == "dtype") { if (attr.Value.Type == DataType.DtInt32) { found_dtype = true; - } - else + } else { return false; } - } - else if (attr.Key == "value") + } else if (attr.Key == "value") { if (attr.Value.Tensor != null && attr.Value.Tensor.IntVal.Count == 1 && attr.Value.Tensor.IntVal[0] == v) { found_value = true; - } - else + } else { return false; } } } + return found_dtype && found_value; } public static Operation Neg(Operation n, Graph graph, Status s, string name = "neg") { - OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); - var neg_input = new TF_Output(n, 0); - c_api.TF_AddInput(desc, neg_input); - var op = c_api.TF_FinishOperation(desc, s); - s.Check(); + lock (Locks.ProcessWide) + { + OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); + var neg_input = new TF_Output(n, 0); + c_api.TF_AddInput(desc, neg_input); + var op = c_api.TF_FinishOperation(desc, s); + s.Check(); - return op; + return op; + } } public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) { - var desc = c_api.TF_NewOperation(graph, "Placeholder", name); - c_api.TF_SetAttrType(desc, "dtype", dtype); - if (dims != null) + lock (Locks.ProcessWide) { - c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); - } - var op = c_api.TF_FinishOperation(desc, s); - s.Check(); + var desc = c_api.TF_NewOperation(graph, "Placeholder", name); + c_api.TF_SetAttrType(desc, "dtype", dtype); + if (dims != null) + { + c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); + } + + var op = c_api.TF_FinishOperation(desc, s); + s.Check(); - return op; + return op; + } } public static Operation Const(Tensor t, Graph graph, Status s, string name) { - var desc = c_api.TF_NewOperation(graph, "Const", name); - c_api.TF_SetAttrTensor(desc, "value", t, s); - s.Check(); - c_api.TF_SetAttrType(desc, "dtype", t.dtype); - var op = c_api.TF_FinishOperation(desc, s); - s.Check(); - - return op; + lock (Locks.ProcessWide) + { + var desc = c_api.TF_NewOperation(graph, "Const", name); + c_api.TF_SetAttrTensor(desc, "value", t, s); + s.Check(); + c_api.TF_SetAttrType(desc, "dtype", t.dtype); + var op = c_api.TF_FinishOperation(desc, s); + s.Check(); + + return op; + } } public static Operation ScalarConst(int v, Graph graph, Status s, string name = "scalar") @@ -205,4 +220,4 @@ namespace TensorFlowNET.UnitTest return Const(new Tensor(v), graph, s, name); } } -} +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/layers_test/flatten.cs b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs new file mode 100644 index 00000000..fa8ec792 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/layers_test/flatten.cs @@ -0,0 +1,58 @@ +using System; +using FluentAssertions; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.layers_test +{ + [TestClass] + public class flatten + { + [TestMethod] + public void Case1() + { + 
var sess = tf.Session().as_default(); + + var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, 3, 1, 2)); + sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); + } + + [TestMethod] + public void Case2() + { + var sess = tf.Session().as_default(); + + var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6)); + sess.run(tf.layers.flatten(input), (input, np.arange(6))).Should().BeShaped(6, 1); + } + + [TestMethod] + public void Case3() + { + var sess = tf.Session().as_default(); + + var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape()); + new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6)))).Should().Throw(); + } + + [TestMethod] + public void Case4() + { + var sess = tf.Session().as_default(); + + var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, Unknown, 1, 2)); + sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); + } + + [TestMethod] + public void Case5() + { + var sess = tf.Session().as_default(); + + var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(Unknown, 4, 3, 1, 2)); + sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs b/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs index 53334349..4e2e5871 100644 --- a/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs +++ b/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs @@ -1,6 +1,6 @@ -using System.Collections; +using System; +using System.Collections; using System.Collections.Generic; -using Colorful; using Microsoft.VisualStudio.TestTools.UnitTesting; using Newtonsoft.Json.Linq; using NumSharp; diff --git a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs b/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs index 310ac634..08c8da2a 100644 --- a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs +++ b/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs @@ -27,7 +27,7 @@ namespace TensorFlowNET.UnitTest.ops_test using (var g = tf.Graph().as_default()) { var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); - var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); + var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); var op = g._create_op_from_tf_operation(c_op); Assert.AreEqual("myop", op.name); @@ -68,7 +68,7 @@ namespace TensorFlowNET.UnitTest.ops_test var true_fn = new Func(() => { - var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]); + var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]); var new_ops = g._add_new_tf_operations(); self.assertEqual(len(new_ops), 1); return x;
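
A minimal usage sketch for the `GetPrivate`/`SetPrivate` extension methods declared in `PrivateObjectExtensions.cs` above, assuming a hypothetical `Counter` class with a private field and a getter-only auto-property; the class, its members, and the test name are illustrative only and do not exist in this change set:

```cs
using System;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace TensorFlowNET.UnitTest.Utilities
{
    [TestClass]
    public class PrivateObjectExtensionsSketch
    {
        // Hypothetical type used only for illustration.
        private class Counter
        {
            private int _count = 3;                   // private instance field
            private string Label { get; } = "start";  // getter-only auto-property
        }

        [TestMethod]
        public void ReadAndWritePrivateMembers()
        {
            var counter = new Counter();

            // GetPrivate<T> reads a private member and casts the result.
            Assert.AreEqual(3, counter.GetPrivate<int>("_count"));

            // SetPrivate writes a private field.
            counter.SetPrivate("_count", 42);
            Assert.AreEqual(42, counter.GetPrivate<int>("_count"));

            // For getter-only auto-properties, SetPrivate retries against the
            // compiler-generated backing field ("<Label>k__BackingField").
            counter.SetPrivate("Label", "updated");
            Assert.AreEqual("updated", counter.GetPrivate<string>("Label"));
        }
    }
}
```

Member lookup walks base types and falls back to static members, so the same calls also reach inherited or static fields without extra binding flags.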
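
The `c_test_util.cs` hunks above wrap every block of raw `c_api` calls in `lock (Locks.ProcessWide)` so concurrent tests cannot interleave native graph mutations. A minimal sketch of the same pattern, assuming a hypothetical `Identity` helper that would sit inside the existing `c_test_util` class next to `Add`/`Neg`/`Const` and share its usings:

```cs
// Hypothetical helper in the style of c_test_util.Add/Neg/Const above.
public static Operation Identity(Operation input, Graph graph, Status s, string name = "identity")
{
    lock (Locks.ProcessWide)  // serialize all native c_api calls process-wide
    {
        var desc = c_api.TF_NewOperation(graph, "Identity", name);
        c_api.TF_AddInput(desc, new TF_Output(input, 0));

        var op = c_api.TF_FinishOperation(desc, s);
        s.Check();  // throw if the native call reported an error

        return op;
    }
}
```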