| @@ -8,6 +8,7 @@ | |||||
| [](https://www.nuget.org/packages/TensorFlow.NET) | [](https://www.nuget.org/packages/TensorFlow.NET) | ||||
| [](https://tensorflownet.readthedocs.io/en/latest/?badge=latest) | [](https://tensorflownet.readthedocs.io/en/latest/?badge=latest) | ||||
| [](https://996.icu/#/en_US) | [](https://996.icu/#/en_US) | ||||
| [](https://mybinder.org/v2/gh/javiercp/BinderTF.NET/master?urlpath=lab) | |||||
| TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). | TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). | ||||
| @@ -3,17 +3,11 @@ Microsoft Visual Studio Solution File, Format Version 12.00 | |||||
| # Visual Studio Version 16 | # Visual Studio Version 16 | ||||
| VisualStudioVersion = 16.0.29102.190 | VisualStudioVersion = 16.0.29102.190 | ||||
| MinimumVisualStudioVersion = 10.0.40219.1 | MinimumVisualStudioVersion = 10.0.40219.1 | ||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.UnitTest", "test\TensorFlowNET.UnitTest\TensorFlowNET.UnitTest.csproj", "{029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}" | |||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlow.Binding", "src\TensorFlowNET.Core\TensorFlow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" | |||||
| EndProject | EndProject | ||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\TensorFlowNET.Core\TensorFlowNET.Core.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" | |||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Benchmark", "src\TensorFlowNet.Benchmarks\Benchmark.csproj", "{3A6EB896-604F-4E25-B677-B8103BCF3D2E}" | |||||
| EndProject | EndProject | ||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Models", "src\TensorFlowNET.Models\TensorFlowNET.Models.csproj", "{D03F94CF-B283-4730-B177-21A57641061F}" | |||||
| EndProject | |||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Text", "src\TensorFlowNET.Text\TensorFlowNET.Text.csproj", "{904472F8-40E1-4650-AA6F-C7F209B3691B}" | |||||
| EndProject | |||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Hub", "src\TensorFlowNET.Hub\TensorFlowNET.Hub.csproj", "{4EAFAE19-C832-47C6-B01E-0F4268C9072C}" | |||||
| EndProject | |||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Datasets", "src\TensorFlowNET.Datasets\TensorFlowNET.Datasets.csproj", "{494D6CAD-2C0D-4C0B-90E2-B097DB039383}" | |||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "UnitTest", "test\TensorFlowNET.UnitTest\UnitTest.csproj", "{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}" | |||||
| EndProject | EndProject | ||||
| Global | Global | ||||
| GlobalSection(SolutionConfigurationPlatforms) = preSolution | GlobalSection(SolutionConfigurationPlatforms) = preSolution | ||||
| @@ -25,18 +19,6 @@ Global | |||||
| Release|x64 = Release|x64 | Release|x64 = Release|x64 | ||||
| EndGlobalSection | EndGlobalSection | ||||
| GlobalSection(ProjectConfigurationPlatforms) = postSolution | GlobalSection(ProjectConfigurationPlatforms) = postSolution | ||||
| {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
| {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
| {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Debug|x64.Build.0 = Debug|Any CPU | |||||
| {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Publish|Any CPU.Build.0 = Release|Any CPU | |||||
| {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Publish|x64.ActiveCfg = Release|Any CPU | |||||
| {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Publish|x64.Build.0 = Release|Any CPU | |||||
| {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
| {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Release|x64.ActiveCfg = Release|Any CPU | |||||
| {029A8CF1-CF95-4DCB-98AA-9D3D96A83B3E}.Release|x64.Build.0 = Release|Any CPU | |||||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | ||||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.ActiveCfg = Debug|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.ActiveCfg = Debug|Any CPU | ||||
| @@ -49,54 +31,30 @@ Global | |||||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.ActiveCfg = Release|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.ActiveCfg = Release|Any CPU | ||||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.Build.0 = Release|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.Build.0 = Release|Any CPU | ||||
| {D03F94CF-B283-4730-B177-21A57641061F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {D03F94CF-B283-4730-B177-21A57641061F}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
| {D03F94CF-B283-4730-B177-21A57641061F}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
| {D03F94CF-B283-4730-B177-21A57641061F}.Debug|x64.Build.0 = Debug|Any CPU | |||||
| {D03F94CF-B283-4730-B177-21A57641061F}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {D03F94CF-B283-4730-B177-21A57641061F}.Publish|Any CPU.Build.0 = Release|Any CPU | |||||
| {D03F94CF-B283-4730-B177-21A57641061F}.Publish|x64.ActiveCfg = Release|Any CPU | |||||
| {D03F94CF-B283-4730-B177-21A57641061F}.Publish|x64.Build.0 = Release|Any CPU | |||||
| {D03F94CF-B283-4730-B177-21A57641061F}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {D03F94CF-B283-4730-B177-21A57641061F}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
| {D03F94CF-B283-4730-B177-21A57641061F}.Release|x64.ActiveCfg = Release|Any CPU | |||||
| {D03F94CF-B283-4730-B177-21A57641061F}.Release|x64.Build.0 = Release|Any CPU | |||||
| {904472F8-40E1-4650-AA6F-C7F209B3691B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {904472F8-40E1-4650-AA6F-C7F209B3691B}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
| {904472F8-40E1-4650-AA6F-C7F209B3691B}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
| {904472F8-40E1-4650-AA6F-C7F209B3691B}.Debug|x64.Build.0 = Debug|Any CPU | |||||
| {904472F8-40E1-4650-AA6F-C7F209B3691B}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {904472F8-40E1-4650-AA6F-C7F209B3691B}.Publish|Any CPU.Build.0 = Release|Any CPU | |||||
| {904472F8-40E1-4650-AA6F-C7F209B3691B}.Publish|x64.ActiveCfg = Release|Any CPU | |||||
| {904472F8-40E1-4650-AA6F-C7F209B3691B}.Publish|x64.Build.0 = Release|Any CPU | |||||
| {904472F8-40E1-4650-AA6F-C7F209B3691B}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {904472F8-40E1-4650-AA6F-C7F209B3691B}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
| {904472F8-40E1-4650-AA6F-C7F209B3691B}.Release|x64.ActiveCfg = Release|Any CPU | |||||
| {904472F8-40E1-4650-AA6F-C7F209B3691B}.Release|x64.Build.0 = Release|Any CPU | |||||
| {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
| {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
| {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Debug|x64.Build.0 = Debug|Any CPU | |||||
| {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Publish|Any CPU.Build.0 = Release|Any CPU | |||||
| {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Publish|x64.ActiveCfg = Release|Any CPU | |||||
| {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Publish|x64.Build.0 = Release|Any CPU | |||||
| {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
| {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Release|x64.ActiveCfg = Release|Any CPU | |||||
| {4EAFAE19-C832-47C6-B01E-0F4268C9072C}.Release|x64.Build.0 = Release|Any CPU | |||||
| {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
| {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
| {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Debug|x64.Build.0 = Debug|Any CPU | |||||
| {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Publish|Any CPU.Build.0 = Release|Any CPU | |||||
| {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Publish|x64.ActiveCfg = Release|Any CPU | |||||
| {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Publish|x64.Build.0 = Release|Any CPU | |||||
| {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
| {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Release|x64.ActiveCfg = Release|Any CPU | |||||
| {494D6CAD-2C0D-4C0B-90E2-B097DB039383}.Release|x64.Build.0 = Release|Any CPU | |||||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.Build.0 = Debug|Any CPU | |||||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.Build.0 = Debug|Any CPU | |||||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.ActiveCfg = Debug|Any CPU | |||||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.Build.0 = Debug|Any CPU | |||||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.ActiveCfg = Release|Any CPU | |||||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.Build.0 = Release|Any CPU | |||||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.Build.0 = Debug|Any CPU | |||||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.Build.0 = Debug|Any CPU | |||||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.ActiveCfg = Debug|Any CPU | |||||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.Build.0 = Debug|Any CPU | |||||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.ActiveCfg = Release|Any CPU | |||||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.Build.0 = Release|Any CPU | |||||
| EndGlobalSection | EndGlobalSection | ||||
| GlobalSection(SolutionProperties) = preSolution | GlobalSection(SolutionProperties) = preSolution | ||||
| HideSolutionNode = FALSE | HideSolutionNode = FALSE | ||||
| @@ -17,7 +17,9 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | |||||
| using System.Linq; | using System.Linq; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -76,7 +78,14 @@ namespace Tensorflow | |||||
| public Tensor concat(IList<Tensor> values, int axis, string name = "concat") | public Tensor concat(IList<Tensor> values, int axis, string name = "concat") | ||||
| { | { | ||||
| if (values.Count == 1) | if (values.Count == 1) | ||||
| throw new NotImplementedException("tf.concat length is 1"); | |||||
| { | |||||
| return tf_with(ops.name_scope(name), scope => | |||||
| { | |||||
| var tensor = ops.convert_to_tensor(axis, name: "concat_dim", dtype: dtypes.int32); | |||||
| Debug.Assert(tensor.TensorShape.ndim == 0); | |||||
| return identity(values[0], name: scope); | |||||
| }); | |||||
| } | |||||
| return gen_array_ops.concat_v2(values.ToArray(), axis, name: name); | return gen_array_ops.concat_v2(values.ToArray(), axis, name: name); | ||||
| } | } | ||||
| @@ -111,7 +120,7 @@ namespace Tensorflow | |||||
| /// <param name="input"></param> | /// <param name="input"></param> | ||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor identity(Tensor input, string name = null) | |||||
| public Tensor identity(Tensor input, string name = null) | |||||
| => array_ops.identity(input, name: name); | => array_ops.identity(input, name: name); | ||||
| /// <summary> | /// <summary> | ||||
| @@ -150,10 +159,10 @@ namespace Tensorflow | |||||
| /// <param name="axis"></param> | /// <param name="axis"></param> | ||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor reverse(Tensor tensor, int[] axis, string name = null) | |||||
| public Tensor reverse(Tensor tensor, int[] axis, string name = null) | |||||
| => gen_array_ops.reverse(tensor, axis, name: name); | => gen_array_ops.reverse(tensor, axis, name: name); | ||||
| public static Tensor reverse(Tensor tensor, Tensor axis, string name = null) | |||||
| public Tensor reverse(Tensor tensor, Tensor axis, string name = null) | |||||
| => gen_array_ops.reverse(tensor, axis, name: name); | => gen_array_ops.reverse(tensor, axis, name: name); | ||||
| /// <summary> | /// <summary> | ||||
| @@ -277,5 +286,14 @@ namespace Tensorflow | |||||
| /// <returns>A `Tensor` with all elements set to zero.</returns> | /// <returns>A `Tensor` with all elements set to zero.</returns> | ||||
| public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | ||||
| => array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize); | => array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize); | ||||
| /// <summary> | |||||
| /// Stops gradient computation. | |||||
| /// </summary> | |||||
| /// <param name="x"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor stop_gradient(Tensor x, string name = null) | |||||
| => gen_array_ops.stop_gradient(x, name: name); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,57 +0,0 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using static Tensorflow.Binding; | |||||
| using Tensorflow.Estimators; | |||||
| using Tensorflow.Data; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class tensorflow | |||||
| { | |||||
| public Estimator_Internal estimator { get; } = new Estimator_Internal(); | |||||
| public class Estimator_Internal | |||||
| { | |||||
| public Estimator Estimator(Action model_fn, RunConfig config) | |||||
| => new Estimator(model_fn: model_fn, config: config); | |||||
| public RunConfig RunConfig(string model_dir) | |||||
| => new RunConfig(model_dir: model_dir); | |||||
| public void train_and_evaluate(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec) | |||||
| => Training.train_and_evaluate(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec); | |||||
| public TrainSpec TrainSpec(Func<DatasetV1Adapter> input_fn, int max_steps) | |||||
| => new TrainSpec(input_fn: input_fn, max_steps: max_steps); | |||||
| /// <summary> | |||||
| /// Create an `Exporter` to use with `tf.estimator.EvalSpec`. | |||||
| /// </summary> | |||||
| /// <param name="name"></param> | |||||
| /// <param name="serving_input_receiver_fn"></param> | |||||
| /// <param name="as_text"></param> | |||||
| /// <returns></returns> | |||||
| public FinalExporter FinalExporter(string name, Action serving_input_receiver_fn, bool as_text = false) | |||||
| => new FinalExporter(name: name, serving_input_receiver_fn: serving_input_receiver_fn, | |||||
| as_text: as_text); | |||||
| public EvalSpec EvalSpec(string name, Action input_fn, FinalExporter exporters) | |||||
| => new EvalSpec(name: name, input_fn: input_fn, exporters: exporters); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -24,6 +24,9 @@ namespace Tensorflow | |||||
| public GraphKeys GraphKeys { get; } = new GraphKeys(); | public GraphKeys GraphKeys { get; } = new GraphKeys(); | ||||
| public void reset_default_graph() | |||||
| => ops.reset_default_graph(); | |||||
| public Graph get_default_graph() | public Graph get_default_graph() | ||||
| { | { | ||||
| return ops.get_default_graph(); | return ops.get_default_graph(); | ||||
| @@ -37,6 +37,22 @@ namespace Tensorflow | |||||
| fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated, | fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated, | ||||
| acceptable_fraction: acceptable_fraction, dct_method: dct_method); | acceptable_fraction: acceptable_fraction, dct_method: dct_method); | ||||
| /// <summary> | |||||
| /// Extracts crops from the input image tensor and resizes them using bilinear sampling or nearest neighbor sampling (possibly with aspect ratio change) to a common output size specified by crop_size. This is more general than the crop_to_bounding_box op which extracts a fixed size slice from the input image and does not allow resizing or aspect ratio change. | |||||
| /// Returns a tensor with crops from the input image at positions defined at the bounding box locations in boxes.The cropped boxes are all resized(with bilinear or nearest neighbor interpolation) to a fixed size = [crop_height, crop_width].The result is a 4 - D tensor[num_boxes, crop_height, crop_width, depth].The resizing is corner aligned. In particular, if boxes = [[0, 0, 1, 1]], the method will give identical results to using tf.image.resize_bilinear() or tf.image.resize_nearest_neighbor() (depends on the method argument) with align_corners = True. | |||||
| /// </summary> | |||||
| /// <param name="image">A Tensor. Must be one of the following types: uint8, uint16, int8, int16, int32, int64, half, float32, float64. A 4-D tensor of shape [batch, image_height, image_width, depth]. Both image_height and image_width need to be positive.</param> | |||||
| /// <param name="boxes">A Tensor of type float32. A 2-D tensor of shape [num_boxes, 4]. The i-th row of the tensor specifies the coordinates of a box in the box_ind[i] image and is specified in normalized coordinates [y1, x1, y2, x2]. A normalized coordinate value of y is mapped to the image coordinate at y * (image_height - 1), so as the [0, 1] interval of normalized image height is mapped to [0, image_height - 1] in image height coordinates. We do allow y1 > y2, in which case the sampled crop is an up-down flipped version of the original image. The width dimension is treated similarly. Normalized coordinates outside the [0, 1] range are allowed, in which case we use extrapolation_value to extrapolate the input image values.</param> | |||||
| /// <param name="box_ind">A Tensor of type int32. A 1-D tensor of shape [num_boxes] with int32 values in [0, batch). The value of box_ind[i] specifies the image that the i-th box refers to.</param> | |||||
| /// <param name="crop_size">A Tensor of type int32. A 1-D tensor of 2 elements, size = [crop_height, crop_width]. All cropped image patches are resized to this size. The aspect ratio of the image content is not preserved. Both crop_height and crop_width need to be positive.</param> | |||||
| /// <param name="method">An optional string from: "bilinear", "nearest". Defaults to "bilinear". A string specifying the sampling method for resizing. It can be either "bilinear" or "nearest" and default to "bilinear". Currently two sampling methods are supported: Bilinear and Nearest Neighbor.</param> | |||||
| /// <param name="extrapolation_value">An optional float. Defaults to 0. Value used for extrapolation, when applicable.</param> | |||||
| /// <param name="name">A name for the operation (optional).</param> | |||||
| /// <returns>A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth].</returns> | |||||
| public Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null) => | |||||
| image_ops_impl.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name); | |||||
| public Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null) | public Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, string name = null) | ||||
| => gen_image_ops.resize_bilinear(images, size, align_corners: align_corners, name: name); | => gen_image_ops.resize_bilinear(images, size, align_corners: align_corners, name: name); | ||||
| @@ -66,20 +66,20 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// Initializer capable of adapting its scale to the shape of weights tensors. | /// Initializer capable of adapting its scale to the shape of weights tensors. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="scale"></param> | |||||
| /// <param name="factor"></param> | |||||
| /// <param name="mode"></param> | /// <param name="mode"></param> | ||||
| /// <param name="distribution"></param> | /// <param name="distribution"></param> | ||||
| /// <param name="seed"></param> | /// <param name="seed"></param> | ||||
| /// <param name="dtype"></param> | /// <param name="dtype"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public IInitializer variance_scaling_initializer(float scale = 1.0f, | |||||
| string mode = "fan_in", | |||||
| string distribution = "truncated_normal", | |||||
| public IInitializer variance_scaling_initializer(float factor = 1.0f, | |||||
| string mode = "FAN_IN", | |||||
| bool uniform = false, | |||||
| int? seed = null, | int? seed = null, | ||||
| TF_DataType dtype = TF_DataType.TF_FLOAT) => new VarianceScaling( | TF_DataType dtype = TF_DataType.TF_FLOAT) => new VarianceScaling( | ||||
| scale: scale, | |||||
| factor: factor, | |||||
| mode: mode, | mode: mode, | ||||
| distribution: distribution, | |||||
| uniform: uniform, | |||||
| seed: seed, | seed: seed, | ||||
| dtype: dtype); | dtype: dtype); | ||||
| } | } | ||||
| @@ -434,14 +434,17 @@ namespace Tensorflow | |||||
| public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null, | public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null, | ||||
| bool keepdims = false, string name = null) | bool keepdims = false, string name = null) | ||||
| { | { | ||||
| if(!axis.HasValue && reduction_indices.HasValue) | |||||
| if (!axis.HasValue && reduction_indices.HasValue && !keepdims) | |||||
| return math_ops.reduce_sum(input, reduction_indices.Value); | return math_ops.reduce_sum(input, reduction_indices.Value); | ||||
| else if (axis.HasValue && !reduction_indices.HasValue) | |||||
| else if (axis.HasValue && !reduction_indices.HasValue && !keepdims) | |||||
| return math_ops.reduce_sum(input, axis.Value); | return math_ops.reduce_sum(input, axis.Value); | ||||
| return math_ops.reduce_sum(input, keepdims: keepdims, name: name); | |||||
| else if (axis.HasValue && !reduction_indices.HasValue && keepdims) | |||||
| return math_ops.reduce_sum(input, keepdims: keepdims, axis: axis.Value, name: name); | |||||
| else | |||||
| return math_ops.reduce_sum(input, keepdims: keepdims, name: name); | |||||
| } | } | ||||
| public Tensor reduce_sum(Tensor input, int[] axis, int? reduction_indices = null, | |||||
| public Tensor reduce_sum(Tensor input, TensorShape axis, int? reduction_indices = null, | |||||
| bool keepdims = false, string name = null) | bool keepdims = false, string name = null) | ||||
| => math_ops.reduce_sum(input, axis, keepdims: keepdims, name: name); | => math_ops.reduce_sum(input, axis, keepdims: keepdims, name: name); | ||||
| @@ -456,6 +459,9 @@ namespace Tensorflow | |||||
| public Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) | public Tensor reduce_max(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) | ||||
| => math_ops.reduce_max(input_tensor, axis, keepdims, name); | => math_ops.reduce_max(input_tensor, axis, keepdims, name); | ||||
| public Tensor reduce_max(Tensor input_tensor, int axis, bool keepdims = false, string name = null) | |||||
| => math_ops.reduce_max(input_tensor, axis, keepdims, name); | |||||
| public Tensor reduce_min(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) | public Tensor reduce_min(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) | ||||
| => math_ops.reduce_min(input_tensor, axis, keepdims, name); | => math_ops.reduce_min(input_tensor, axis, keepdims, name); | ||||
| @@ -468,6 +474,9 @@ namespace Tensorflow | |||||
| public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) | public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) | ||||
| => math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices); | => math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices); | ||||
| public Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null) | |||||
| => math_ops.reduce_mean(input_tensors, axis: axis, keepdims: keepdims, name: name); | |||||
| public Tensor round(Tensor x, string name = null) | public Tensor round(Tensor x, string name = null) | ||||
| => gen_math_ops.round(x, name: name); | => gen_math_ops.round(x, name: name); | ||||
| @@ -111,6 +111,7 @@ namespace Tensorflow | |||||
| name: name); | name: name); | ||||
| public IActivation relu() => new relu(); | public IActivation relu() => new relu(); | ||||
| public IActivation swish() => new swish(); | |||||
| public Tensor relu(Tensor features, string name = null) => gen_nn_ops.relu(features, name); | public Tensor relu(Tensor features, string name = null) => gen_nn_ops.relu(features, name); | ||||
| @@ -206,6 +207,9 @@ namespace Tensorflow | |||||
| public Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null) | public Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null) | ||||
| => nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name); | => nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name); | ||||
| public Tensor sigmoid<T>(T x, string name = null) | |||||
| => math_ops.sigmoid(x, name: name); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -28,21 +28,21 @@ namespace Tensorflow | |||||
| /// <param name="seed"></param> | /// <param name="seed"></param> | ||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public Tensor random_normal(int[] shape, | |||||
| public Tensor random_normal(TensorShape shape, | |||||
| float mean = 0.0f, | float mean = 0.0f, | ||||
| float stddev = 1.0f, | float stddev = 1.0f, | ||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
| int? seed = null, | int? seed = null, | ||||
| string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); | string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); | ||||
| public Tensor random_uniform(int[] shape, | |||||
| public Tensor random_uniform(TensorShape shape, | |||||
| float minval = 0, | float minval = 0, | ||||
| float maxval = 1, | float maxval = 1, | ||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
| int? seed = null, | int? seed = null, | ||||
| string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name); | string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name); | ||||
| public Tensor truncated_normal(int[] shape, | |||||
| public Tensor truncated_normal(TensorShape shape, | |||||
| float mean = 0.0f, | float mean = 0.0f, | ||||
| float stddev = 1.0f, | float stddev = 1.0f, | ||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
| @@ -62,5 +62,13 @@ namespace Tensorflow | |||||
| /// </returns> | /// </returns> | ||||
| public Tensor random_shuffle(Tensor value, int? seed = null, string name = null) | public Tensor random_shuffle(Tensor value, int? seed = null, string name = null) | ||||
| => random_ops.random_shuffle(value, seed: seed, name: name); | => random_ops.random_shuffle(value, seed: seed, name: name); | ||||
| public void set_random_seed(int seed) | |||||
| => ops.get_default_graph().seed = seed; | |||||
| public Tensor multinomial(Tensor logits, int num_samples, int? seed = null, | |||||
| string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid) | |||||
| => random_ops.multinomial(logits, num_samples, seed: seed, | |||||
| name: name, output_dtype: output_dtype); | |||||
| } | } | ||||
| } | } | ||||
| @@ -15,6 +15,7 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.Keras.Optimizers; | |||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -73,6 +74,26 @@ namespace Tensorflow | |||||
| public CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null) | public CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null) | ||||
| => checkpoint_management.get_checkpoint_state(checkpoint_dir, latest_filename: latest_filename); | => checkpoint_management.get_checkpoint_state(checkpoint_dir, latest_filename: latest_filename); | ||||
| public Tensor polynomial_decay(float learning_rate, | |||||
| RefVariable global_step, | |||||
| float decay_steps, | |||||
| float end_learning_rate = 0.0001f, | |||||
| float power = 1.0f, | |||||
| bool cycle = false, | |||||
| string name = null) | |||||
| { | |||||
| var decayed = new PolynomialDecay(learning_rate, | |||||
| decay_steps, | |||||
| end_learning_rate: end_learning_rate, | |||||
| power: power, | |||||
| cycle: cycle, | |||||
| name: name); | |||||
| var decayed_lr = decayed.__call__(global_step); | |||||
| return decayed_lr; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -27,6 +27,15 @@ namespace Tensorflow | |||||
| .ToArray(); | .ToArray(); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Returns an Op that initializes a list of variables. | |||||
| /// </summary> | |||||
| /// <param name="var_list">List of `Variable` objects to initialize.</param> | |||||
| /// <param name="name">Optional name for the returned operation.</param> | |||||
| /// <returns>An Op that run the initializers of all the specified variables.</returns> | |||||
| public Operation variables_initializer(VariableV1[] var_list, string name = "init") | |||||
| => variables.variables_initializer(var_list, name: name); | |||||
| public Operation global_variables_initializer() | public Operation global_variables_initializer() | ||||
| { | { | ||||
| var g = variables.global_variables(); | var g = variables.global_variables(); | ||||
| @@ -54,9 +63,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| var scope = Tensorflow.variable_scope.get_variable_scope(); | var scope = Tensorflow.variable_scope.get_variable_scope(); | ||||
| var store = Tensorflow.variable_scope._get_default_variable_store(); | var store = Tensorflow.variable_scope._get_default_variable_store(); | ||||
| return scope.get_variable(store, | |||||
| name, | |||||
| shape: shape, | |||||
| return scope.get_variable(store, | |||||
| name, | |||||
| shape: shape, | |||||
| dtype: dtype, | dtype: dtype, | ||||
| use_resource: use_resource, | use_resource: use_resource, | ||||
| validate_shape: validate_shape, | validate_shape: validate_shape, | ||||
| @@ -0,0 +1,32 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public static partial class Binding | |||||
| { | |||||
| public static class functools | |||||
| { | |||||
| public static PartialFunc<Tin, Tout> partial<Tin, Tout>(Func<Tin, Tout> func, Tin arg) | |||||
| => new PartialFunc<Tin, Tout> | |||||
| { | |||||
| args = arg, | |||||
| invoke = func | |||||
| }; | |||||
| public static Func<Tin1, Tin2, Tout> partial<Tin1, Tin2, Tout>(Func<Tin1, Tin2, Tout> func, (Tin1, Tin2) args) | |||||
| => (arg1, arg2) => func(args.Item1, args.Item2); | |||||
| } | |||||
| public class PartialFunc<Tin, Tout> | |||||
| { | |||||
| public Tin args { get; set; } | |||||
| public object[] keywords { get; set; } | |||||
| public Func<Tin, Tout> invoke { get; set; } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -115,6 +115,7 @@ namespace Tensorflow | |||||
| return instance; | return instance; | ||||
| } | } | ||||
| [DebuggerStepThrough] | |||||
| [DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception | [DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception | ||||
| public static void tf_with(IObjectLife py, Action<IObjectLife> action) | public static void tf_with(IObjectLife py, Action<IObjectLife> action) | ||||
| { | { | ||||
| @@ -273,6 +274,12 @@ namespace Tensorflow | |||||
| return sum; | return sum; | ||||
| } | } | ||||
| public static float sum(IEnumerable<float> enumerable) | |||||
| => enumerable.Sum(); | |||||
| public static int sum(IEnumerable<int> enumerable) | |||||
| => enumerable.Sum(); | |||||
| public static double sum<TKey, TValue>(Dictionary<TKey, TValue> values) | public static double sum<TKey, TValue>(Dictionary<TKey, TValue> values) | ||||
| { | { | ||||
| return sum(values.Keys); | return sum(values.Keys); | ||||
| @@ -339,15 +346,5 @@ namespace Tensorflow | |||||
| return true; | return true; | ||||
| return false; | return false; | ||||
| } | } | ||||
| public static Func<Tin1, Tout> partial<Tin1, Tout>(Func<Tin1, Tout> func, Tin1 args) | |||||
| { | |||||
| Func<Tin1, Tout> newfunc = (args1) => | |||||
| { | |||||
| return func(args1); | |||||
| }; | |||||
| return newfunc; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,6 @@ | |||||
| namespace Tensorflow.Data | |||||
| using System; | |||||
| namespace Tensorflow.Data | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Represents a potentially large set of elements. | /// Represents a potentially large set of elements. | ||||
| @@ -11,5 +13,9 @@ | |||||
| /// </summary> | /// </summary> | ||||
| public class DatasetV2 | public class DatasetV2 | ||||
| { | { | ||||
| public static DatasetV2 from_generator() | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,138 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Diagnostics; | |||||
| using System.IO; | |||||
| using System.Text; | |||||
| using Tensorflow.Data; | |||||
| using Tensorflow.Train; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Estimators | |||||
| { | |||||
| /// <summary> | |||||
| /// Estimator class to train and evaluate TensorFlow models. | |||||
| /// </summary> | |||||
| public class Estimator : IObjectLife | |||||
| { | |||||
| RunConfig _config; | |||||
| public RunConfig config => _config; | |||||
| ConfigProto _session_config; | |||||
| public ConfigProto session_config => _session_config; | |||||
| string _model_dir; | |||||
| Action _model_fn; | |||||
| public Estimator(Action model_fn, RunConfig config) | |||||
| { | |||||
| _config = config; | |||||
| _model_dir = _config.model_dir; | |||||
| _session_config = _config.session_config; | |||||
| _model_fn = model_fn; | |||||
| } | |||||
| public Estimator train(Func<DatasetV1Adapter> input_fn, int max_steps = 1, Action[] hooks = null, | |||||
| _NewCheckpointListenerForEvaluate[] saving_listeners = null) | |||||
| { | |||||
| if(max_steps > 0) | |||||
| { | |||||
| var start_step = _load_global_step_from_checkpoint_dir(_model_dir); | |||||
| if (max_steps <= start_step) | |||||
| { | |||||
| Console.WriteLine("Skipping training since max_steps has already saved."); | |||||
| return this; | |||||
| } | |||||
| } | |||||
| var loss = _train_model(input_fn); | |||||
| print($"Loss for final step: {loss}."); | |||||
| return this; | |||||
| } | |||||
| private int _load_global_step_from_checkpoint_dir(string checkpoint_dir) | |||||
| { | |||||
| // var cp = tf.train.latest_checkpoint(checkpoint_dir); | |||||
| // should use NewCheckpointReader (not implemented) | |||||
| var cp = tf.train.get_checkpoint_state(checkpoint_dir); | |||||
| return cp.AllModelCheckpointPaths.Count - 1; | |||||
| } | |||||
| private Tensor _train_model(Func<DatasetV1Adapter> input_fn) | |||||
| { | |||||
| return _train_model_default(input_fn); | |||||
| } | |||||
| private Tensor _train_model_default(Func<DatasetV1Adapter> input_fn) | |||||
| { | |||||
| using (var g = tf.Graph().as_default()) | |||||
| { | |||||
| var global_step_tensor = _create_and_assert_global_step(g); | |||||
| // Skip creating a read variable if _create_and_assert_global_step | |||||
| // returns None (e.g. tf.contrib.estimator.SavedModelEstimator). | |||||
| if (global_step_tensor != null) | |||||
| TrainingUtil._get_or_create_global_step_read(g); | |||||
| var (features, labels) = _get_features_and_labels_from_input_fn(input_fn, "train"); | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| private (Dictionary<string, Tensor>, Dictionary<string, Tensor>) _get_features_and_labels_from_input_fn(Func<DatasetV1Adapter> input_fn, string mode) | |||||
| { | |||||
| var result = _call_input_fn(input_fn, mode); | |||||
| return EstimatorUtil.parse_input_fn_result(result); | |||||
| } | |||||
| /// <summary> | |||||
| /// Calls the input function. | |||||
| /// </summary> | |||||
| /// <param name="input_fn"></param> | |||||
| /// <param name="mode"></param> | |||||
| private DatasetV1Adapter _call_input_fn(Func<DatasetV1Adapter> input_fn, string mode) | |||||
| { | |||||
| return input_fn(); | |||||
| } | |||||
| private Tensor _create_and_assert_global_step(Graph graph) | |||||
| { | |||||
| var step = _create_global_step(graph); | |||||
| Debug.Assert(step == tf.train.get_global_step(graph)); | |||||
| Debug.Assert(step.dtype.is_integer()); | |||||
| return step; | |||||
| } | |||||
| private RefVariable _create_global_step(Graph graph) | |||||
| { | |||||
| return tf.train.create_global_step(graph); | |||||
| } | |||||
| public void __init__() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public void __enter__() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public void __del__() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public void __exit__() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public void Dispose() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,15 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Data; | |||||
| namespace Tensorflow.Estimators | |||||
| { | |||||
| public class EstimatorUtil | |||||
| { | |||||
| public static (Dictionary<string, Tensor>, Dictionary<string, Tensor>) parse_input_fn_result(DatasetV1Adapter result) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,16 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Estimators | |||||
| { | |||||
| public class EvalSpec | |||||
| { | |||||
| string _name; | |||||
| public EvalSpec(string name, Action input_fn, FinalExporter exporters) | |||||
| { | |||||
| _name = name; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,11 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Estimators | |||||
| { | |||||
| public abstract class Exporter | |||||
| { | |||||
| public abstract void export(Estimator estimator, string export_path, string checkpoint_path); | |||||
| } | |||||
| } | |||||
| @@ -1,14 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Estimators | |||||
| { | |||||
| public class FinalExporter | |||||
| { | |||||
| public FinalExporter(string name, Action serving_input_receiver_fn, bool as_text = false) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,89 +0,0 @@ | |||||
| using System.IO; | |||||
| namespace Tensorflow.Estimators | |||||
| { | |||||
| public class HyperParams | |||||
| { | |||||
| /// <summary> | |||||
| /// root dir | |||||
| /// </summary> | |||||
| public string data_root_dir { get; set; } | |||||
| /// <summary> | |||||
| /// results dir | |||||
| /// </summary> | |||||
| public string result_dir { get; set; } = "results"; | |||||
| /// <summary> | |||||
| /// model dir | |||||
| /// </summary> | |||||
| public string model_dir { get; set; } = "model"; | |||||
| public string eval_dir { get; set; } = "eval"; | |||||
| public string test_dir { get; set; } = "test"; | |||||
| public int dim { get; set; } = 300; | |||||
| public float dropout { get; set; } = 0.5f; | |||||
| public int num_oov_buckets { get; set; } = 1; | |||||
| public int epochs { get; set; } = 25; | |||||
| public int epoch_no_imprv { get; set; } = 3; | |||||
| public int batch_size { get; set; } = 20; | |||||
| public int buffer { get; set; } = 15000; | |||||
| public int lstm_size { get; set; } = 100; | |||||
| public string lr_method { get; set; } = "adam"; | |||||
| public float lr { get; set; } = 0.001f; | |||||
| public float lr_decay { get; set; } = 0.9f; | |||||
| /// <summary> | |||||
| /// lstm on chars | |||||
| /// </summary> | |||||
| public int hidden_size_char { get; set; } = 100; | |||||
| /// <summary> | |||||
| /// lstm on word embeddings | |||||
| /// </summary> | |||||
| public int hidden_size_lstm { get; set; } = 300; | |||||
| /// <summary> | |||||
| /// is clipping | |||||
| /// </summary> | |||||
| public bool clip { get; set; } = false; | |||||
| public string filepath_dev { get; set; } | |||||
| public string filepath_test { get; set; } | |||||
| public string filepath_train { get; set; } | |||||
| public string filepath_words { get; set; } | |||||
| public string filepath_chars { get; set; } | |||||
| public string filepath_tags { get; set; } | |||||
| public string filepath_glove { get; set; } | |||||
| public HyperParams(string dataDir) | |||||
| { | |||||
| data_root_dir = dataDir; | |||||
| if (string.IsNullOrEmpty(data_root_dir)) | |||||
| throw new ValueError("Please specifiy the root data directory"); | |||||
| if (!Directory.Exists(data_root_dir)) | |||||
| Directory.CreateDirectory(data_root_dir); | |||||
| result_dir = Path.Combine(data_root_dir, result_dir); | |||||
| if (!Directory.Exists(result_dir)) | |||||
| Directory.CreateDirectory(result_dir); | |||||
| model_dir = Path.Combine(result_dir, model_dir); | |||||
| if (!Directory.Exists(model_dir)) | |||||
| Directory.CreateDirectory(model_dir); | |||||
| test_dir = Path.Combine(result_dir, test_dir); | |||||
| if (!Directory.Exists(test_dir)) | |||||
| Directory.CreateDirectory(test_dir); | |||||
| eval_dir = Path.Combine(result_dir, eval_dir); | |||||
| if (!Directory.Exists(eval_dir)) | |||||
| Directory.CreateDirectory(eval_dir); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,7 +0,0 @@ | |||||
| # TensorFlow Estimator | |||||
| TensorFlow Estimator is a high-level TensorFlow API that greatly simplifies machine learning programming. Estimators encapsulate training, evaluation, prediction, and exporting for your model. | |||||
| https://github.com/tensorflow/estimator | |||||
| @@ -1,101 +0,0 @@ | |||||
| using System; | |||||
| namespace Tensorflow.Estimators | |||||
| { | |||||
| public class RunConfig | |||||
| { | |||||
| // A list of the property names in RunConfig that the user is allowed to change. | |||||
| private static readonly string[] _DEFAULT_REPLACEABLE_LIST = new [] | |||||
| { | |||||
| "model_dir", | |||||
| "tf_random_seed", | |||||
| "save_summary_steps", | |||||
| "save_checkpoints_steps", | |||||
| "save_checkpoints_secs", | |||||
| "session_config", | |||||
| "keep_checkpoint_max", | |||||
| "keep_checkpoint_every_n_hours", | |||||
| "log_step_count_steps", | |||||
| "train_distribute", | |||||
| "device_fn", | |||||
| "protocol", | |||||
| "eval_distribute", | |||||
| "experimental_distribute", | |||||
| "experimental_max_worker_delay_secs", | |||||
| "session_creation_timeout_secs" | |||||
| }; | |||||
| #region const values | |||||
| private const string _SAVE_CKPT_ERR = "`save_checkpoints_steps` and `save_checkpoints_secs` cannot be both set."; | |||||
| private const string _TF_CONFIG_ENV = "TF_CONFIG"; | |||||
| private const string _TASK_ENV_KEY = "task"; | |||||
| private const string _TASK_TYPE_KEY = "type"; | |||||
| private const string _TASK_ID_KEY = "index"; | |||||
| private const string _CLUSTER_KEY = "cluster"; | |||||
| private const string _SERVICE_KEY = "service"; | |||||
| private const string _SESSION_MASTER_KEY = "session_master"; | |||||
| private const string _EVAL_SESSION_MASTER_KEY = "eval_session_master"; | |||||
| private const string _MODEL_DIR_KEY = "model_dir"; | |||||
| private const string _LOCAL_MASTER = ""; | |||||
| private const string _GRPC_SCHEME = "grpc://"; | |||||
| #endregion | |||||
| public string model_dir { get; set; } | |||||
| public ConfigProto session_config { get; set; } | |||||
| public int? tf_random_seed { get; set; } | |||||
| public int save_summary_steps { get; set; } = 100; | |||||
| public int save_checkpoints_steps { get; set; } | |||||
| public int save_checkpoints_secs { get; set; } = 600; | |||||
| public int keep_checkpoint_max { get; set; } = 5; | |||||
| public int keep_checkpoint_every_n_hours { get; set; } = 10000; | |||||
| public int log_step_count_steps{ get; set; } = 100; | |||||
| public object train_distribute { get; set; } | |||||
| public object device_fn { get; set; } | |||||
| public object protocol { get; set; } | |||||
| public object eval_distribute { get; set; } | |||||
| public object experimental_distribute { get; set; } | |||||
| public object experimental_max_worker_delay_secs { get; set; } | |||||
| public int session_creation_timeout_secs { get; set; } = 7200; | |||||
| public object service { get; set; } | |||||
| public RunConfig() | |||||
| { | |||||
| Initialize(); | |||||
| } | |||||
| public RunConfig(string model_dir) | |||||
| { | |||||
| this.model_dir = model_dir; | |||||
| Initialize(); | |||||
| } | |||||
| public RunConfig( | |||||
| string model_dir = null, | |||||
| int? tf_random_seed = null, | |||||
| int save_summary_steps=100, | |||||
| object save_checkpoints_steps = null, // _USE_DEFAULT | |||||
| object save_checkpoints_secs = null, // _USE_DEFAULT | |||||
| object session_config = null, | |||||
| int keep_checkpoint_max = 5, | |||||
| int keep_checkpoint_every_n_hours = 10000, | |||||
| int log_step_count_steps = 100, | |||||
| object train_distribute = null, | |||||
| object device_fn = null, | |||||
| object protocol = null, | |||||
| object eval_distribute = null, | |||||
| object experimental_distribute = null, | |||||
| object experimental_max_worker_delay_secs = null, | |||||
| int session_creation_timeout_secs = 7200) | |||||
| { | |||||
| this.model_dir = model_dir; | |||||
| Initialize(); | |||||
| } | |||||
| private void Initialize() | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,22 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Data; | |||||
| namespace Tensorflow.Estimators | |||||
| { | |||||
| public class TrainSpec | |||||
| { | |||||
| int _max_steps; | |||||
| public int max_steps => _max_steps; | |||||
| Func<DatasetV1Adapter> _input_fn; | |||||
| public Func<DatasetV1Adapter> input_fn => _input_fn; | |||||
| public TrainSpec(Func<DatasetV1Adapter> input_fn, int max_steps) | |||||
| { | |||||
| _max_steps = max_steps; | |||||
| _input_fn = input_fn; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,17 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Estimators | |||||
| { | |||||
| public class Training | |||||
| { | |||||
| public static void train_and_evaluate(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec) | |||||
| { | |||||
| var executor = new _TrainingExecutor(estimator: estimator, train_spec: train_spec, eval_spec: eval_spec); | |||||
| var config = estimator.config; | |||||
| executor.run(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,14 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Estimators | |||||
| { | |||||
| public class _Evaluator | |||||
| { | |||||
| public _Evaluator(Estimator estimator, EvalSpec eval_spec, int max_training_steps) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,16 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Estimators | |||||
| { | |||||
| public class _NewCheckpointListenerForEvaluate | |||||
| { | |||||
| _Evaluator _evaluator; | |||||
| public _NewCheckpointListenerForEvaluate(_Evaluator evaluator, int eval_throttle_secs) | |||||
| { | |||||
| _evaluator = evaluator; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,14 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Estimators | |||||
| { | |||||
| public class _SavedModelExporter : Exporter | |||||
| { | |||||
| public override void export(Estimator estimator, string export_path, string checkpoint_path) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,48 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Estimators | |||||
| { | |||||
| /// <summary> | |||||
| /// The executor to run `Estimator` training and evaluation. | |||||
| /// </summary> | |||||
| internal class _TrainingExecutor | |||||
| { | |||||
| Estimator _estimator; | |||||
| EvalSpec _eval_spec; | |||||
| TrainSpec _train_spec; | |||||
| public _TrainingExecutor(Estimator estimator, TrainSpec train_spec, EvalSpec eval_spec) | |||||
| { | |||||
| _estimator = estimator; | |||||
| _train_spec = train_spec; | |||||
| _eval_spec = eval_spec; | |||||
| } | |||||
| public void run() | |||||
| { | |||||
| var config = _estimator.config; | |||||
| Console.WriteLine("Running training and evaluation locally (non-distributed)."); | |||||
| run_local(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Runs training and evaluation locally (non-distributed). | |||||
| /// </summary> | |||||
| private void run_local() | |||||
| { | |||||
| var train_hooks = new Action[0]; | |||||
| Console.WriteLine("Start train and evaluate loop. The evaluate will happen " + | |||||
| "after every checkpoint. Checkpoint frequency is determined " + | |||||
| $"based on RunConfig arguments: save_checkpoints_steps {_estimator.config.save_checkpoints_steps} or " + | |||||
| $"save_checkpoints_secs {_estimator.config.save_checkpoints_secs}."); | |||||
| var evaluator = new _Evaluator(_estimator, _eval_spec, _train_spec.max_steps); | |||||
| var saving_listeners = new _NewCheckpointListenerForEvaluate[0]; | |||||
| _estimator.train(input_fn: _train_spec.input_fn, | |||||
| max_steps: _train_spec.max_steps, | |||||
| hooks: train_hooks, | |||||
| saving_listeners: saving_listeners); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -46,9 +46,9 @@ namespace Tensorflow | |||||
| if (!string.IsNullOrEmpty(unbound_inputs_col_name)) | if (!string.IsNullOrEmpty(unbound_inputs_col_name)) | ||||
| { | { | ||||
| foreach(var col in meta_graph_def.CollectionDef) | |||||
| foreach (var col in meta_graph_def.CollectionDef) | |||||
| { | { | ||||
| if(col.Key == unbound_inputs_col_name) | |||||
| if (col.Key == unbound_inputs_col_name) | |||||
| { | { | ||||
| throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | ||||
| } | } | ||||
| @@ -78,7 +78,7 @@ namespace Tensorflow | |||||
| // Restores all the other collections. | // Restores all the other collections. | ||||
| var variable_objects = new Dictionary<ByteString, VariableV1>(); | var variable_objects = new Dictionary<ByteString, VariableV1>(); | ||||
| foreach(var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key)) | |||||
| foreach (var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key)) | |||||
| { | { | ||||
| // Don't add unbound_inputs to the new graph. | // Don't add unbound_inputs to the new graph. | ||||
| if (col.Key == unbound_inputs_col_name) | if (col.Key == unbound_inputs_col_name) | ||||
| @@ -87,7 +87,7 @@ namespace Tensorflow | |||||
| switch (col.Value.KindCase) | switch (col.Value.KindCase) | ||||
| { | { | ||||
| case KindOneofCase.NodeList: | case KindOneofCase.NodeList: | ||||
| foreach(var value in col.Value.NodeList.Value) | |||||
| foreach (var value in col.Value.NodeList.Value) | |||||
| { | { | ||||
| var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names)); | var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names)); | ||||
| graph.add_to_collection(col.Key, col_op); | graph.add_to_collection(col.Key, col_op); | ||||
| @@ -115,7 +115,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| foreach(var value in col.Value.BytesList.Value) | |||||
| foreach (var value in col.Value.BytesList.Value) | |||||
| { | { | ||||
| switch (col.Key) | switch (col.Key) | ||||
| { | { | ||||
| @@ -139,7 +139,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| break; | break; | ||||
| default: | default: | ||||
| throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | ||||
| @@ -173,8 +173,8 @@ namespace Tensorflow | |||||
| string unbound_inputs_col_name = "unbound_inputs", | string unbound_inputs_col_name = "unbound_inputs", | ||||
| bool clear_devices = false, | bool clear_devices = false, | ||||
| SaverDef saver_def = null, | SaverDef saver_def = null, | ||||
| bool clear_extraneous_savers= false, | |||||
| bool strip_default_attrs= false, | |||||
| bool clear_extraneous_savers = false, | |||||
| bool strip_default_attrs = false, | |||||
| byte[] meta_info_def = null) | byte[] meta_info_def = null) | ||||
| { | { | ||||
| var graph = ops.get_default_graph(); | var graph = ops.get_default_graph(); | ||||
| @@ -236,12 +236,12 @@ namespace Tensorflow | |||||
| meta_graph_def.GraphDef = graph_def; | meta_graph_def.GraphDef = graph_def; | ||||
| // Fills in meta_info_def.stripped_op_list using the ops from graph_def. | // Fills in meta_info_def.stripped_op_list using the ops from graph_def. | ||||
| if (meta_graph_def.MetaInfoDef.StrippedOpList == null || | |||||
| if (meta_graph_def.MetaInfoDef.StrippedOpList == null || | |||||
| meta_graph_def.MetaInfoDef.StrippedOpList.Op.Count == 0) | meta_graph_def.MetaInfoDef.StrippedOpList.Op.Count == 0) | ||||
| meta_graph_def.MetaInfoDef.StrippedOpList = stripped_op_list_for_graph(meta_graph_def.GraphDef); | meta_graph_def.MetaInfoDef.StrippedOpList = stripped_op_list_for_graph(meta_graph_def.GraphDef); | ||||
| var clist = graph.get_all_collection_keys(); | var clist = graph.get_all_collection_keys(); | ||||
| foreach(var ctype in clist) | |||||
| foreach (var ctype in clist) | |||||
| { | { | ||||
| if (clear_extraneous_savers) | if (clear_extraneous_savers) | ||||
| { | { | ||||
| @@ -256,8 +256,8 @@ namespace Tensorflow | |||||
| return meta_graph_def; | return meta_graph_def; | ||||
| } | } | ||||
| private static void add_collection_def(MetaGraphDef meta_graph_def, | |||||
| string key, | |||||
| private static void add_collection_def(MetaGraphDef meta_graph_def, | |||||
| string key, | |||||
| Graph graph = null, | Graph graph = null, | ||||
| string export_scope = "") | string export_scope = "") | ||||
| { | { | ||||
| @@ -274,7 +274,7 @@ namespace Tensorflow | |||||
| var proto = x.to_proto(export_scope); | var proto = x.to_proto(export_scope); | ||||
| col_def.BytesList.Value.Add(proto.ToByteString()); | col_def.BytesList.Value.Add(proto.ToByteString()); | ||||
| } | } | ||||
| break; | break; | ||||
| case List<object> collection_list: | case List<object> collection_list: | ||||
| col_def.NodeList = new Types.NodeList(); | col_def.NodeList = new Types.NodeList(); | ||||
| @@ -75,9 +75,9 @@ namespace Tensorflow | |||||
| /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | ||||
| /// </summary> | /// </summary> | ||||
| /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> | /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> | ||||
| public partial class Graph : DisposableObject, | |||||
| public partial class Graph : DisposableObject | |||||
| #if !SERIALIZABLE | #if !SERIALIZABLE | ||||
| IEnumerable<Operation> | |||||
| , IEnumerable<Operation> | |||||
| #endif | #endif | ||||
| { | { | ||||
| private Dictionary<int, ITensorOrOperation> _nodes_by_id; | private Dictionary<int, ITensorOrOperation> _nodes_by_id; | ||||
| @@ -105,8 +105,18 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| private Dictionary<string, object> _collections = new Dictionary<string, object>(); | private Dictionary<string, object> _collections = new Dictionary<string, object>(); | ||||
| public bool building_function; | |||||
| public bool building_function; | |||||
| int _seed; | |||||
| public int seed | |||||
| { | |||||
| get => _seed; | |||||
| set | |||||
| { | |||||
| _seed = value; | |||||
| } | |||||
| } | |||||
| public Graph() | public Graph() | ||||
| { | { | ||||
| _handle = c_api.TF_NewGraph(); | _handle = c_api.TF_NewGraph(); | ||||
| @@ -230,10 +240,6 @@ namespace Tensorflow | |||||
| public void add_to_collection<T>(string name, T value) | public void add_to_collection<T>(string name, T value) | ||||
| { | { | ||||
| if(name == "update_ops") | |||||
| { | |||||
| } | |||||
| _check_not_finalized(); | _check_not_finalized(); | ||||
| if (_collections.ContainsKey(name)) | if (_collections.ContainsKey(name)) | ||||
| (_collections[name] as List<T>).Add(value); | (_collections[name] as List<T>).Add(value); | ||||
| @@ -404,7 +410,7 @@ namespace Tensorflow | |||||
| _names_in_use[name_key] = 1; | _names_in_use[name_key] = 1; | ||||
| // Return the new name with the original capitalization of the given name. | // Return the new name with the original capitalization of the given name. | ||||
| name = $"{name}_{i-1}"; | |||||
| name = $"{name}_{i - 1}"; | |||||
| } | } | ||||
| return name; | return name; | ||||
| } | } | ||||
| @@ -417,8 +423,8 @@ namespace Tensorflow | |||||
| TF_Output[] return_outputs = new TF_Output[num_return_outputs]; | TF_Output[] return_outputs = new TF_Output[num_return_outputs]; | ||||
| unsafe | unsafe | ||||
| { | { | ||||
| var tf_output_ptr = (TF_Output*) return_output_handle; | |||||
| for (int i = 0; i < num_return_outputs; i++) | |||||
| var tf_output_ptr = (TF_Output*)return_output_handle; | |||||
| for (int i = 0; i < num_return_outputs; i++) | |||||
| return_outputs[i] = *(tf_output_ptr + i); | return_outputs[i] = *(tf_output_ptr + i); | ||||
| return return_outputs; | return return_outputs; | ||||
| } | } | ||||
| @@ -519,7 +525,7 @@ namespace Tensorflow | |||||
| string debugString = string.Empty; | string debugString = string.Empty; | ||||
| public override string ToString() | public override string ToString() | ||||
| { | { | ||||
| return $"{graph_key}, ({_handle})"; | |||||
| return $"{graph_key}, ({_handle})"; | |||||
| /*if (string.IsNullOrEmpty(debugString)) | /*if (string.IsNullOrEmpty(debugString)) | ||||
| { | { | ||||
| int len = 0; | int len = 0; | ||||
| @@ -536,7 +542,7 @@ namespace Tensorflow | |||||
| IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator() | IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator() | ||||
| => GetEnumerable().GetEnumerator(); | => GetEnumerable().GetEnumerator(); | ||||
| IEnumerator IEnumerable.GetEnumerator() | |||||
| IEnumerator IEnumerable.GetEnumerator() | |||||
| => throw new NotImplementedException(); | => throw new NotImplementedException(); | ||||
| #endif | #endif | ||||
| @@ -27,7 +27,7 @@ namespace Tensorflow.Keras | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public IInitializer he_normal(int? seed = null) | public IInitializer he_normal(int? seed = null) | ||||
| { | { | ||||
| return new VarianceScaling(scale: 2.0f, mode: "fan_in", distribution: "truncated_normal", seed: seed); | |||||
| return new VarianceScaling(factor: 2.0f, mode: "fan_in", seed: seed); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,16 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| namespace Tensorflow.Keras.Optimizers | |||||
| { | |||||
| public class LearningRateSchedule | |||||
| { | |||||
| public LearningRateSchedule() | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,62 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Keras.Optimizers | |||||
| { | |||||
| /// <summary> | |||||
| /// A LearningRateSchedule that uses a polynomial decay schedule. | |||||
| /// </summary> | |||||
| public class PolynomialDecay : LearningRateSchedule | |||||
| { | |||||
| float initial_learning_rate; | |||||
| float decay_steps; | |||||
| float end_learning_rate; | |||||
| float power; | |||||
| bool cycle; | |||||
| string name; | |||||
| public PolynomialDecay(float initial_learning_rate, | |||||
| float decay_steps, | |||||
| float end_learning_rate = 0.0001f, | |||||
| float power = 1.0f, | |||||
| bool cycle = false, | |||||
| string name = null) : base() | |||||
| { | |||||
| this.initial_learning_rate = initial_learning_rate; | |||||
| this.decay_steps = decay_steps; | |||||
| this.end_learning_rate = end_learning_rate; | |||||
| this.power = power; | |||||
| this.cycle = cycle; | |||||
| this.name = name; | |||||
| } | |||||
| public Tensor __call__(RefVariable step) | |||||
| { | |||||
| tf_with(ops.name_scope(name ?? "PolynomialDecay"), scope => | |||||
| { | |||||
| name = scope; | |||||
| var initial_learning_rate_tensor = ops.convert_to_tensor(initial_learning_rate, name: "initial_learning_rate"); | |||||
| var dtype = initial_learning_rate_tensor.dtype; | |||||
| var end_learning_rate_tensor = math_ops.cast(end_learning_rate, dtype); | |||||
| var power_tensor = math_ops.cast(power, dtype); | |||||
| var global_step_recomp = math_ops.cast(step, dtype); | |||||
| var decay_steps_recomp = math_ops.cast(decay_steps, dtype); | |||||
| if(cycle) | |||||
| { | |||||
| throw new NotImplementedException("PolynomialDecay cycle"); | |||||
| } | |||||
| else | |||||
| { | |||||
| } | |||||
| }); | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -102,6 +102,14 @@ namespace Tensorflow.Operations.Activation | |||||
| } | } | ||||
| } | } | ||||
| public class swish : IActivation | |||||
| { | |||||
| public Tensor Activate(Tensor x, string name = null) | |||||
| { | |||||
| return tf.multiply(x, tf.nn.sigmoid(x)); | |||||
| } | |||||
| } | |||||
| public class linear : IActivation | public class linear : IActivation | ||||
| { | { | ||||
| public Tensor Activate(Tensor x, string name = null) | public Tensor Activate(Tensor x, string name = null) | ||||
| @@ -305,9 +305,7 @@ namespace Tensorflow.Operations | |||||
| } | } | ||||
| public virtual bool IsWhileContext() | public virtual bool IsWhileContext() | ||||
| { | |||||
| throw new NotImplementedException("IsWhileContext"); | |||||
| } | |||||
| => false; | |||||
| public virtual bool IsCondContext() | public virtual bool IsCondContext() | ||||
| => false; | => false; | ||||
| @@ -19,10 +19,12 @@ namespace Tensorflow.Operations.Initializers | |||||
| public class GlorotUniform : VarianceScaling | public class GlorotUniform : VarianceScaling | ||||
| { | { | ||||
| public GlorotUniform(float scale = 1.0f, | public GlorotUniform(float scale = 1.0f, | ||||
| string mode = "fan_avg", | |||||
| string distribution = "uniform", | |||||
| string mode = "FAN_AVG", | |||||
| int? seed = null, | int? seed = null, | ||||
| TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype) | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT) : base(factor: scale, | |||||
| mode: mode, | |||||
| seed: seed, | |||||
| dtype: dtype) | |||||
| { | { | ||||
| } | } | ||||
| @@ -33,7 +35,6 @@ namespace Tensorflow.Operations.Initializers | |||||
| { | { | ||||
| scale = _scale, | scale = _scale, | ||||
| mode = _mode, | mode = _mode, | ||||
| distribution = _distribution, | |||||
| seed = _seed, | seed = _seed, | ||||
| dtype = _dtype | dtype = _dtype | ||||
| }; | }; | ||||
| @@ -30,45 +30,51 @@ namespace Tensorflow.Operations.Initializers | |||||
| protected string _distribution; | protected string _distribution; | ||||
| protected int? _seed; | protected int? _seed; | ||||
| protected TF_DataType _dtype; | protected TF_DataType _dtype; | ||||
| protected bool _uniform; | |||||
| public VarianceScaling(float scale = 1.0f, | |||||
| string mode = "fan_in", | |||||
| string distribution = "truncated_normal", | |||||
| public VarianceScaling(float factor = 2.0f, | |||||
| string mode = "FAN_IN", | |||||
| bool uniform = false, | |||||
| int? seed = null, | int? seed = null, | ||||
| TF_DataType dtype = TF_DataType.TF_FLOAT) | TF_DataType dtype = TF_DataType.TF_FLOAT) | ||||
| { | { | ||||
| if (scale < 0) | |||||
| if (!dtype.is_floating()) | |||||
| throw new TypeError("Cannot create initializer for non-floating point type."); | |||||
| if (!new string[] { "FAN_IN", "FAN_OUT", "FAN_AVG" }.Contains(mode)) | |||||
| throw new TypeError($"Unknown {mode} %s [FAN_IN, FAN_OUT, FAN_AVG]"); | |||||
| if (factor < 0) | |||||
| throw new ValueError("`scale` must be positive float."); | throw new ValueError("`scale` must be positive float."); | ||||
| _scale = scale; | |||||
| _scale = factor; | |||||
| _mode = mode; | _mode = mode; | ||||
| _distribution = distribution; | |||||
| _seed = seed; | _seed = seed; | ||||
| _dtype = dtype; | _dtype = dtype; | ||||
| _uniform = uniform; | |||||
| } | } | ||||
| public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null) | public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null) | ||||
| { | { | ||||
| float n = 0; | |||||
| var (fan_in, fan_out) = _compute_fans(shape); | var (fan_in, fan_out) = _compute_fans(shape); | ||||
| if (_mode == "fan_in") | |||||
| _scale /= Math.Max(1, fan_in); | |||||
| else if (_mode == "fan_out") | |||||
| _scale /= Math.Max(1, fan_out); | |||||
| else | |||||
| _scale /= Math.Max(1, (fan_in + fan_out) / 2); | |||||
| if (_mode == "FAN_IN") | |||||
| n = fan_in; | |||||
| else if (_mode == "FAN_OUT") | |||||
| n = fan_out; | |||||
| else if(_mode == "FAN_AVG") | |||||
| n = (fan_in + fan_out) / 2.0f; | |||||
| if (_distribution == "normal" || _distribution == "truncated_normal") | |||||
| { | |||||
| float stddev = (float)Math.Sqrt(_scale) / .87962566103423978f; | |||||
| return random_ops.truncated_normal(shape, mean: 0.0f, stddev: stddev, dtype: dtype, seed: _seed); | |||||
| } | |||||
| else if (_distribution == "untruncated_normal") | |||||
| if(_uniform) | |||||
| { | { | ||||
| throw new NotImplementedException("truncated_normal"); | |||||
| var limit = Convert.ToSingle(Math.Sqrt(3.0f * _scale / n)); | |||||
| return random_ops.random_uniform(shape, -limit, limit, | |||||
| dtype, seed: _seed); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| var limit = Math.Sqrt(3.0f * _scale); | |||||
| return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed); | |||||
| var trunc_stddev = Convert.ToSingle(Math.Sqrt(1.3f * _scale / n)); | |||||
| return random_ops.truncated_normal(shape, 0.0f, trunc_stddev, dtype, | |||||
| seed: _seed); | |||||
| } | } | ||||
| } | } | ||||
| @@ -101,6 +107,7 @@ namespace Tensorflow.Operations.Initializers | |||||
| mode = _mode, | mode = _mode, | ||||
| distribution = _distribution, | distribution = _distribution, | ||||
| seed = _seed, | seed = _seed, | ||||
| uniform = _uniform, | |||||
| dtype = _dtype | dtype = _dtype | ||||
| }; | }; | ||||
| } | } | ||||
| @@ -60,6 +60,9 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// List this operation's output types. | /// List this operation's output types. | ||||
| /// </summary> | /// </summary> | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public TF_DataType[] _output_types | public TF_DataType[] _output_types | ||||
| { | { | ||||
| get | get | ||||
| @@ -78,7 +78,10 @@ namespace Tensorflow | |||||
| #if SERIALIZABLE | #if SERIALIZABLE | ||||
| [JsonIgnore] | [JsonIgnore] | ||||
| #endif | #endif | ||||
| bool _is_stateful; | |||||
| bool _is_stateful; | |||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public NodeDef node_def | public NodeDef node_def | ||||
| { | { | ||||
| get | get | ||||
| @@ -181,8 +184,8 @@ namespace Tensorflow | |||||
| // This will be set by self.inputs. | // This will be set by self.inputs. | ||||
| if (op_def == null) | if (op_def == null) | ||||
| op_def = g.GetOpDef(node_def.Op); | |||||
| op_def = g.GetOpDef(node_def.Op); | |||||
| var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | ||||
| _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | ||||
| _is_stateful = op_def.IsStateful; | _is_stateful = op_def.IsStateful; | ||||
| @@ -376,19 +376,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| return tf_with(ops.name_scope(name, "cond", new { pred }), delegate | return tf_with(ops.name_scope(name, "cond", new { pred }), delegate | ||||
| { | { | ||||
| // TODO: here a chunk of original code is missing | |||||
| /* | |||||
| with ops.name_scope(name, "cond", [pred]): | |||||
| if context.executing_eagerly(): | |||||
| if pred: | |||||
| return _UnpackIfSingleton(true_fn()) | |||||
| return _UnpackIfSingleton(false_fn()) | |||||
| */ | |||||
| // Add the Switch to the graph. | // Add the Switch to the graph. | ||||
| var switch_result= @switch(pred, pred); | var switch_result= @switch(pred, pred); | ||||
| var p_2=switch_result[0]; | |||||
| var p_1 = switch_result[1]; | |||||
| var (p_2, p_1 )= (switch_result[0], switch_result[1]); | |||||
| var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | ||||
| var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | ||||
| pred = array_ops.identity(pred, name: "pred_id"); | pred = array_ops.identity(pred, name: "pred_id"); | ||||
| @@ -405,6 +395,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| context_t.Enter(); | context_t.Enter(); | ||||
| (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | ||||
| context_t.ExitResult(new[] { res_t }); | |||||
| } | } | ||||
| finally | finally | ||||
| { | { | ||||
| @@ -418,46 +409,36 @@ namespace Tensorflow | |||||
| { | { | ||||
| context_f.Enter(); | context_f.Enter(); | ||||
| (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | ||||
| context_f.ExitResult(new[] { res_f }); | |||||
| } | } | ||||
| finally | finally | ||||
| { | { | ||||
| context_f.Exit(); | context_f.Exit(); | ||||
| } | } | ||||
| //TODO: missing original code | |||||
| //if not strict: | |||||
| // orig_res_t = _UnpackIfSingleton(orig_res_t) | |||||
| // orig_res_f = _UnpackIfSingleton(orig_res_f) | |||||
| /* | |||||
| # Check that the return values of the two branches have the same structure. | |||||
| try: | |||||
| nest.assert_same_structure(orig_res_t, orig_res_f) | |||||
| except TypeError as e: | |||||
| raise TypeError( | |||||
| "Incompatible return types of true_fn and false_fn: {}".format(e)) | |||||
| except ValueError as e: | |||||
| raise ValueError( | |||||
| "Incompatible return values of true_fn and false_fn: {}".format(e))*/ | |||||
| var res_t_flat = new Tensor[] { res_t }; | var res_t_flat = new Tensor[] { res_t }; | ||||
| var res_f_flat = new Tensor[] { res_f }; | var res_f_flat = new Tensor[] { res_f }; | ||||
| foreach(var (val_x, val_y) in zip(res_t_flat, res_f_flat)) | |||||
| { | |||||
| } | |||||
| var merges = zip(res_f_flat, res_t_flat) | var merges = zip(res_f_flat, res_t_flat) | ||||
| .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) | |||||
| .Select(m => (Tensor)m) | |||||
| .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })[0]) | |||||
| .ToArray(); | .ToArray(); | ||||
| var merges2 = _convert_flows_to_tensorarrays(new ITensorOrTensorArray[] { (Tensor)orig_res_t }, merges); | |||||
| if (orig_res_t is Tensor orig_res_tensor) | |||||
| merges = _convert_flows_to_tensorarrays(new[] { orig_res_tensor }, merges) | |||||
| .Select(x => x as Tensor) | |||||
| .ToArray(); | |||||
| else | |||||
| { | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||||
| } | |||||
| return new Tensor(IntPtr.Zero); | |||||
| if(context_t.outer_context == null) | |||||
| { | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||||
| } | |||||
| return merges[0]; | |||||
| }); | }); | ||||
| } | } | ||||
| @@ -485,28 +466,43 @@ namespace Tensorflow | |||||
| var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); | var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); | ||||
| context_t.Enter(); | context_t.Enter(); | ||||
| var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | ||||
| context_t.ExitResult(res_t); | |||||
| context_t.Exit(); | context_t.Exit(); | ||||
| // Build the graph for the false branch in a new context. | // Build the graph for the false branch in a new context. | ||||
| var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); | var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); | ||||
| context_f.Enter(); | context_f.Enter(); | ||||
| var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | ||||
| context_f.ExitResult(res_f); | |||||
| context_f.Exit(); | context_f.Exit(); | ||||
| var res_t_flat = res_t; | var res_t_flat = res_t; | ||||
| var res_f_flat = res_f; | var res_f_flat = res_f; | ||||
| var merges = zip(res_f_flat, res_t_flat) | var merges = zip(res_f_flat, res_t_flat) | ||||
| .Select(pair => merge(new [] { pair.Item1, pair.Item2 })) | |||||
| .Select(m => (Tensor)m) | |||||
| .Select(pair => merge(new [] { pair.Item1, pair.Item2 })[0]) | |||||
| .ToArray(); | .ToArray(); | ||||
| var merges2 = _convert_flows_to_tensorarrays(orig_res_t.Select(x => (ITensorOrTensorArray)x).ToArray(), merges); | |||||
| if (orig_res_t is Tensor[] orig_res_tensor) | |||||
| merges = _convert_flows_to_tensorarrays(orig_res_tensor, merges) | |||||
| .Select(x => x as Tensor) | |||||
| .ToArray(); | |||||
| else if (orig_res_t is float[] orig_res_float) | |||||
| { | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||||
| } | |||||
| else | |||||
| { | |||||
| } | |||||
| if(context_t.outer_context == null) | |||||
| { | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||||
| } | |||||
| return new[] { new Tensor(IntPtr.Zero) }; | |||||
| return merges; | |||||
| }); | }); | ||||
| } | } | ||||
| @@ -132,7 +132,18 @@ namespace Tensorflow | |||||
| if (while_ctxt == null) | if (while_ctxt == null) | ||||
| { | { | ||||
| throw new NotImplementedException("CheckInputFromValidContext"); | |||||
| // Neither op nor input_op is in a while loop, but one or both are in | |||||
| // conds. We allow this, although execution will fail if the branch | |||||
| // corresponding to input_op's cond context isn't taken. | |||||
| if (input_while_ctxt == null) | |||||
| valid = true; | |||||
| // Invalid if op isn't in a while loop and input_op is. Unless... | |||||
| if (IsLoopEnter(op)) | |||||
| // WhileContext._BuildLoop clears context for Enter nodes. | |||||
| valid = true; | |||||
| if (IsSwitch(op)) | |||||
| // CondContext.AddValue clears context for Switch nodes. | |||||
| valid = true; | |||||
| } | } | ||||
| else if (IsContainingContext(while_ctxt, input_while_ctxt)) | else if (IsContainingContext(while_ctxt, input_while_ctxt)) | ||||
| { | { | ||||
| @@ -383,7 +383,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("StopGradient", name, args: new { input = x, name }); | var _op = _op_def_lib._apply_op_helper("StopGradient", name, args: new { input = x, name }); | ||||
| return _op.outputs[0]; | |||||
| return _op.output; | |||||
| } | } | ||||
| public static Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides, | public static Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides, | ||||
| @@ -115,7 +115,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims }); | var _op = _op_def_lib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims }); | ||||
| return _op.outputs[0]; | |||||
| return _op.output; | |||||
| } | } | ||||
| public static Tensor prod<T1, T2>(T1 input, T2 axis, bool keep_dims = false, string name = null) | public static Tensor prod<T1, T2>(T1 input, T2 axis, bool keep_dims = false, string name = null) | ||||
| @@ -629,9 +629,9 @@ namespace Tensorflow | |||||
| public static Tensor _abs(Tensor x, string name = null) | public static Tensor _abs(Tensor x, string name = null) | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Abs", name, new { x }); | |||||
| var _op = _op_def_lib._apply_op_helper("Abs", name, args: new { x }); | |||||
| return _op.outputs[0]; | |||||
| return _op.output; | |||||
| } | } | ||||
| public static Tensor _any<Tx, Ty>(Tx input, Ty axis, bool keep_dims = false, string name = null) | public static Tensor _any<Tx, Ty>(Tx input, Ty axis, bool keep_dims = false, string name = null) | ||||
| @@ -662,14 +662,7 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static Tensor _sum(Tensor input, Tensor axis = null, bool keep_dims = false, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims }); | |||||
| return _op.outputs[0]; | |||||
| } | |||||
| public static Tensor _sum(Tensor input, int axis, bool keep_dims = false, string name = null) | |||||
| public static Tensor _sum<Tx, Ty>(Tx input, Ty axis = default, bool keep_dims = false, string name = null) | |||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims }); | var _op = _op_def_lib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims }); | ||||
| @@ -98,7 +98,8 @@ namespace Tensorflow | |||||
| /// <param name="seed2"></param> | /// <param name="seed2"></param> | ||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, string name = null) | |||||
| public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, | |||||
| string name = null) | |||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("RandomShuffle", | var _op = _op_def_lib._apply_op_helper("RandomShuffle", | ||||
| name: name, | name: name, | ||||
| @@ -116,7 +117,8 @@ namespace Tensorflow | |||||
| /// <param name="seed2"></param> | /// <param name="seed2"></param> | ||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0, int? seed2 = 0, string name = null) | |||||
| public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0, | |||||
| int? seed2 = 0, string name = null) | |||||
| { | { | ||||
| if (!seed.HasValue) | if (!seed.HasValue) | ||||
| seed = 0; | seed = 0; | ||||
| @@ -127,7 +129,24 @@ namespace Tensorflow | |||||
| name: name, | name: name, | ||||
| args: new { shape, dtype, seed, seed2 }); | args: new { shape, dtype, seed, seed2 }); | ||||
| return _op.outputs[0]; | |||||
| return _op.output; | |||||
| } | |||||
| public static Tensor multinomial(Tensor logits, int num_samples, int? seed = 0, | |||||
| int? seed2 = 0, TF_DataType output_dtype = TF_DataType.TF_INT64, string name = null) | |||||
| { | |||||
| if (!seed.HasValue) | |||||
| seed = 0; | |||||
| if (!seed2.HasValue) | |||||
| seed2 = 0; | |||||
| if (output_dtype == TF_DataType.DtInvalid) | |||||
| output_dtype = TF_DataType.TF_INT64; | |||||
| var _op = _op_def_lib._apply_op_helper("Multinomial", | |||||
| name: name, | |||||
| args: new { logits, num_samples, seed, seed2, output_dtype }); | |||||
| return _op.output; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -14,9 +14,11 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using NumSharp; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Operations; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -102,6 +104,27 @@ namespace Tensorflow | |||||
| }); | }); | ||||
| } | } | ||||
| internal static Tensor resize_images(Tensor images, Tensor size, ResizeMethod method, bool align_corners, bool preserve_aspect_ratio, string name) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public static Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method, float extrapolation_value, string name) | |||||
| { | |||||
| var _op = gen_nn_ops._op_def_lib._apply_op_helper("CropAndResize", name: name, args: new | |||||
| { | |||||
| image, | |||||
| boxes, | |||||
| box_ind, | |||||
| crop_size, | |||||
| method, | |||||
| extrapolation_value | |||||
| }); | |||||
| return _op.outputs[0]; | |||||
| } | |||||
| public static Tensor is_jpeg(Tensor contents, string name = null) | public static Tensor is_jpeg(Tensor contents, string name = null) | ||||
| { | { | ||||
| return tf_with(ops.name_scope(name, "is_jpeg"), scope => | return tf_with(ops.name_scope(name, "is_jpeg"), scope => | ||||
| @@ -129,22 +152,6 @@ namespace Tensorflow | |||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Resize `images` to `size` using the specified `method`. | |||||
| /// </summary> | |||||
| /// <param name="images"></param> | |||||
| /// <param name="size"></param> | |||||
| /// <param name="method"></param> | |||||
| /// <param name="align_corners"></param> | |||||
| /// <param name="preserve_aspect_ratio"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor resize_images(Tensor images, Tensor size, ResizeMethod method = ResizeMethod.BILINEAR, | |||||
| bool align_corners = false, bool preserve_aspect_ratio = false, string name = null) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Resize `images` to `size` using nearest neighbor interpolation. | /// Resize `images` to `size` using nearest neighbor interpolation. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -31,6 +31,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| return tf_with(ops.name_scope(name, "Abs", new { x }), scope => | return tf_with(ops.name_scope(name, "Abs", new { x }), scope => | ||||
| { | { | ||||
| name = scope; | |||||
| x = ops.convert_to_tensor(x, name: "x"); | x = ops.convert_to_tensor(x, name: "x"); | ||||
| if (x.dtype.is_complex()) | if (x.dtype.is_complex()) | ||||
| throw new NotImplementedException("math_ops.abs for dtype.is_complex"); | throw new NotImplementedException("math_ops.abs for dtype.is_complex"); | ||||
| @@ -80,6 +81,21 @@ namespace Tensorflow | |||||
| }); | }); | ||||
| } | } | ||||
| public static Tensor cast(float x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) | |||||
| { | |||||
| var base_type = dtype.as_base_dtype(); | |||||
| return tf_with(ops.name_scope(name, "Cast", new { x }), scope => | |||||
| { | |||||
| name = scope; | |||||
| var x_tensor = ops.convert_to_tensor(x, name: "x"); | |||||
| if (x_tensor.dtype.as_base_dtype() != base_type) | |||||
| x_tensor = gen_math_ops.cast(x_tensor, base_type, name: name); | |||||
| return x_tensor; | |||||
| }); | |||||
| } | |||||
| public static Tensor cumsum<T>(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null) | public static Tensor cumsum<T>(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null) | ||||
| { | { | ||||
| return tf_with(ops.name_scope(name, "Cumsum", new {x}), scope => | return tf_with(ops.name_scope(name, "Cumsum", new {x}), scope => | ||||
| @@ -203,6 +219,12 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| public static Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null) | |||||
| { | |||||
| var m = gen_math_ops.mean(input_tensors, axis, keepdims, name); | |||||
| return _may_reduce_to_scalar(keepdims, axis, m); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Computes the product of elements across dimensions of a tensor. | /// Computes the product of elements across dimensions of a tensor. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -379,6 +401,13 @@ namespace Tensorflow | |||||
| return _may_reduce_to_scalar(keepdims, axis, max); | return _may_reduce_to_scalar(keepdims, axis, max); | ||||
| } | } | ||||
| public static Tensor reduce_max(Tensor input_tensor, int axis, bool keepdims = false, string name = null) | |||||
| { | |||||
| var r = _ReductionDims(input_tensor, axis); | |||||
| var max = gen_math_ops._max(input_tensor, r, keepdims, name); | |||||
| return _may_reduce_to_scalar(keepdims, axis, max); | |||||
| } | |||||
| public static Tensor reduce_min(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) | public static Tensor reduce_min(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) | ||||
| { | { | ||||
| var r = _ReductionDims(input_tensor, axis); | var r = _ReductionDims(input_tensor, axis); | ||||
| @@ -434,15 +463,14 @@ namespace Tensorflow | |||||
| public static Tensor reduce_sum(Tensor input_tensor, int[] axis, bool keepdims = false, string name = null) | public static Tensor reduce_sum(Tensor input_tensor, int[] axis, bool keepdims = false, string name = null) | ||||
| { | { | ||||
| var r = _ReductionDims(input_tensor, axis); | |||||
| var m = gen_math_ops._sum(input_tensor, r, keep_dims: keepdims, name: name); | |||||
| var m = gen_math_ops._sum(input_tensor, axis, keep_dims: keepdims, name: name); | |||||
| return _may_reduce_to_scalar(keepdims, axis, m); | return _may_reduce_to_scalar(keepdims, axis, m); | ||||
| } | } | ||||
| public static Tensor reduce_sum(Tensor input_tensor, int axis, bool keepdims = false, string name = null) | public static Tensor reduce_sum(Tensor input_tensor, int axis, bool keepdims = false, string name = null) | ||||
| { | { | ||||
| var m = gen_math_ops._sum(input_tensor, axis, keep_dims: keepdims, name: name); | var m = gen_math_ops._sum(input_tensor, axis, keep_dims: keepdims, name: name); | ||||
| return _may_reduce_to_scalar(keepdims, new int[] { axis }, m); | |||||
| return _may_reduce_to_scalar(keepdims, axis, m); | |||||
| } | } | ||||
| private static Tensor _may_reduce_to_scalar(bool keepdims, Tensor axis, Tensor output) | private static Tensor _may_reduce_to_scalar(bool keepdims, Tensor axis, Tensor output) | ||||
| @@ -464,6 +492,11 @@ namespace Tensorflow | |||||
| return output; | return output; | ||||
| } | } | ||||
| private static Tensor _may_reduce_to_scalar(bool keepdims, int axis, Tensor output) | |||||
| { | |||||
| return output; | |||||
| } | |||||
| private static Tensor _ReductionDims(Tensor x, Tensor axis) | private static Tensor _ReductionDims(Tensor x, Tensor axis) | ||||
| { | { | ||||
| if (axis != null) | if (axis != null) | ||||
| @@ -477,6 +510,11 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| private static int _ReductionDims(Tensor x, int axis) | |||||
| { | |||||
| return axis; | |||||
| } | |||||
| private static Tensor _ReductionDims(Tensor x, int[] axis) | private static Tensor _ReductionDims(Tensor x, int[] axis) | ||||
| { | { | ||||
| if (axis != null) | if (axis != null) | ||||
| @@ -142,6 +142,35 @@ namespace Tensorflow | |||||
| { | { | ||||
| return ops.convert_to_tensor(shape, name: "shape"); | return ops.convert_to_tensor(shape, name: "shape"); | ||||
| } | } | ||||
| public static Tensor multinomial(Tensor logits, int num_samples, int? seed = null, | |||||
| string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid) | |||||
| { | |||||
| return tf_with(ops.name_scope(name, "multinomial", new { logits }), delegate | |||||
| { | |||||
| return multinomial_categorical_impl(logits, num_samples, output_dtype, seed); | |||||
| }); | |||||
| } | |||||
| /// <summary> | |||||
| /// Implementation for random.categorical (v1) and random.categorical (v2). | |||||
| /// </summary> | |||||
| /// <param name="logits"></param> | |||||
| /// <param name="num_samples"></param> | |||||
| /// <param name="output_dtype"></param> | |||||
| /// <param name="seed"></param> | |||||
| /// <returns></returns> | |||||
| private static Tensor multinomial_categorical_impl(Tensor logits, int num_samples, TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| int? seed = null) | |||||
| { | |||||
| logits = ops.convert_to_tensor(logits, name: "logits"); | |||||
| var (seed1, seed2) = random_seed.get_seed(seed); | |||||
| return gen_random_ops.multinomial(logits, | |||||
| num_samples, | |||||
| seed: seed1, | |||||
| seed2: seed2, | |||||
| output_dtype: dtype); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -37,7 +37,7 @@ namespace Tensorflow.Summaries | |||||
| return val; | return val; | ||||
| } | } | ||||
| public Tensor merge_all(string key = "summaries", string scope= null, string name= null) | |||||
| public Tensor merge_all(string key = "summaries", string scope = null, string name = null) | |||||
| { | { | ||||
| var summary_ops = ops.get_collection(key, scope: scope); | var summary_ops = ops.get_collection(key, scope: scope); | ||||
| if (summary_ops == null) | if (summary_ops == null) | ||||
| @@ -1,11 +1,11 @@ | |||||
| <Project Sdk="Microsoft.NET.Sdk"> | <Project Sdk="Microsoft.NET.Sdk"> | ||||
| <PropertyGroup> | <PropertyGroup> | ||||
| <TargetFramework>netstandard2.0</TargetFramework> | |||||
| <TargetFrameworks>net472;netstandard2.0</TargetFrameworks> | |||||
| <AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
| <RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
| <TargetTensorFlow>1.14.0</TargetTensorFlow> | |||||
| <Version>0.12.0</Version> | |||||
| <TargetTensorFlow>1.14.1</TargetTensorFlow> | |||||
| <Version>0.12.1</Version> | |||||
| <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | ||||
| <Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | ||||
| @@ -18,13 +18,14 @@ | |||||
| <Description>Google's TensorFlow full binding in .NET Standard. | <Description>Google's TensorFlow full binding in .NET Standard. | ||||
| Building, training and infering deep learning models. | Building, training and infering deep learning models. | ||||
| https://tensorflownet.readthedocs.io</Description> | https://tensorflownet.readthedocs.io</Description> | ||||
| <AssemblyVersion>0.12.0.0</AssemblyVersion> | |||||
| <AssemblyVersion>0.12.1.0</AssemblyVersion> | |||||
| <PackageReleaseNotes>Changes since v0.11.0: | <PackageReleaseNotes>Changes since v0.11.0: | ||||
| 1: Add ICanBeFlattened for nest.flatten2. | 1: Add ICanBeFlattened for nest.flatten2. | ||||
| 2: Complete the WhileContext. | 2: Complete the WhileContext. | ||||
| 3: Add tf.nn.rnn_cell.BasicRNNCell and tf.nn.dynamic_rnn.</PackageReleaseNotes> | |||||
| 3: Add tf.nn.rnn_cell.BasicRNNCell and tf.nn.dynamic_rnn. | |||||
| 4: Add EstimatorSpec.</PackageReleaseNotes> | |||||
| <LangVersion>7.3</LangVersion> | <LangVersion>7.3</LangVersion> | ||||
| <FileVersion>0.12.0.0</FileVersion> | |||||
| <FileVersion>0.12.1.0</FileVersion> | |||||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | <PackageLicenseFile>LICENSE</PackageLicenseFile> | ||||
| <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | ||||
| <SignAssembly>true</SignAssembly> | <SignAssembly>true</SignAssembly> | ||||
| @@ -41,8 +42,14 @@ https://tensorflownet.readthedocs.io</Description> | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <Compile Remove="Distribute\**" /> | |||||
| <Compile Remove="Models\**" /> | |||||
| <Compile Remove="runtimes\**" /> | <Compile Remove="runtimes\**" /> | ||||
| <EmbeddedResource Remove="Distribute\**" /> | |||||
| <EmbeddedResource Remove="Models\**" /> | |||||
| <EmbeddedResource Remove="runtimes\**" /> | <EmbeddedResource Remove="runtimes\**" /> | ||||
| <None Remove="Distribute\**" /> | |||||
| <None Remove="Models\**" /> | |||||
| <None Remove="runtimes\**" /> | <None Remove="runtimes\**" /> | ||||
| <None Include="..\..\LICENSE"> | <None Include="..\..\LICENSE"> | ||||
| <Pack>True</Pack> | <Pack>True</Pack> | ||||
| @@ -55,13 +62,11 @@ https://tensorflownet.readthedocs.io</Description> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Google.Protobuf" Version="3.10.0" /> | |||||
| <PackageReference Include="Google.Protobuf" Version="3.10.1" /> | |||||
| <PackageReference Include="NumSharp" Version="0.20.4" /> | <PackageReference Include="NumSharp" Version="0.20.4" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <Folder Include="Distribute\" /> | |||||
| <Folder Include="Keras\Initializers\" /> | <Folder Include="Keras\Initializers\" /> | ||||
| <Folder Include="Models\" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| </Project> | </Project> | ||||
| @@ -79,6 +79,8 @@ namespace Tensorflow | |||||
| } | } | ||||
| strides.Add(s.Step); | strides.Add(s.Step); | ||||
| if (s.IsIndex) | |||||
| shrink_axis_mask |= (1 << index); | |||||
| } | } | ||||
| index += 1; | index += 1; | ||||
| @@ -16,6 +16,7 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Numerics; | using System.Numerics; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -25,7 +26,7 @@ namespace Tensorflow | |||||
| public partial class Tensor | public partial class Tensor | ||||
| { | { | ||||
| #if _REGEN | #if _REGEN | ||||
| #region Compute | |||||
| #region Compute | |||||
| %operators = ["add", "sub", "mul", "div", "mod"] | %operators = ["add", "sub", "mul", "div", "mod"] | ||||
| %operators_sign = ["+", "-", "*", "/", "%"] | %operators_sign = ["+", "-", "*", "/", "%"] | ||||
| %operators_comparers = [">", "<", ">=", "<="] | %operators_comparers = [">", "<", ">=", "<="] | ||||
| @@ -49,11 +50,11 @@ namespace Tensorflow | |||||
| % | % | ||||
| % | % | ||||
| public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); | public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); | ||||
| #endregion | |||||
| #endregion | |||||
| #else | #else | ||||
| #region Compute | |||||
| #region Compute | |||||
| public static Tensor operator +(Tensor lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); | public static Tensor operator +(Tensor lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); | ||||
| public static Tensor operator +(Tensor lhs, NDArray rhs) => BinaryOpWrapper("add", lhs, rhs); | public static Tensor operator +(Tensor lhs, NDArray rhs) => BinaryOpWrapper("add", lhs, rhs); | ||||
| public static Tensor operator +(NDArray lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); | public static Tensor operator +(NDArray lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs); | ||||
| @@ -281,24 +282,43 @@ namespace Tensorflow | |||||
| public static Tensor operator <=(Tensor lhs, Complex rhs) => gen_math_ops.less_equal(lhs, rhs); | public static Tensor operator <=(Tensor lhs, Complex rhs) => gen_math_ops.less_equal(lhs, rhs); | ||||
| public static Tensor operator <=(Complex lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); | public static Tensor operator <=(Complex lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); | ||||
| public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); | public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); | ||||
| #endregion | |||||
| #endregion | |||||
| #endif | #endif | ||||
| private static readonly TF_DataType[] _intTfDataTypes = { | private static readonly TF_DataType[] _intTfDataTypes = { | ||||
| TF_DataType.TF_INT8, TF_DataType.TF_INT16, TF_DataType.TF_INT32, TF_DataType.TF_INT64, | TF_DataType.TF_INT8, TF_DataType.TF_INT16, TF_DataType.TF_INT32, TF_DataType.TF_INT64, | ||||
| TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, | TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, | ||||
| TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 | TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 | ||||
| }; | }; | ||||
| private static string div_or_truediv<Tx, Ty>(string name, Tx x, Ty y) | |||||
| { | |||||
| bool is_floating = false; | |||||
| var types = new List<bool>(); | |||||
| if (x is Tensor t1) | |||||
| types.add(t1.dtype.is_floating()); | |||||
| if (y is Tensor t2) | |||||
| types.add(t2.dtype.is_floating()); | |||||
| is_floating = types.Contains(true); | |||||
| return is_floating ? "truediv" : name; | |||||
| } | |||||
| private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y) | private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y) | ||||
| { | { | ||||
| TF_DataType dtype = TF_DataType.DtInvalid; | TF_DataType dtype = TF_DataType.DtInvalid; | ||||
| if (x is Tensor tl) | if (x is Tensor tl) | ||||
| dtype = tl.dtype.as_base_dtype(); | dtype = tl.dtype.as_base_dtype(); | ||||
| if (y is Tensor tr) | if (y is Tensor tr) | ||||
| dtype = tr.dtype.as_base_dtype(); | dtype = tr.dtype.as_base_dtype(); | ||||
| if (name == "div") | |||||
| name = div_or_truediv(name, x, y); | |||||
| return tf_with(ops.name_scope(null, name, new { x, y }), scope => | return tf_with(ops.name_scope(null, name, new { x, y }), scope => | ||||
| { | { | ||||
| Tensor result; | Tensor result; | ||||
| @@ -308,18 +328,16 @@ namespace Tensorflow | |||||
| switch (name.ToLowerInvariant()) | switch (name.ToLowerInvariant()) | ||||
| { | { | ||||
| case "add": | case "add": | ||||
| result = gen_math_ops.add(x1, y1, name: scope); | |||||
| result = math_ops.add(x1, y1, name: scope); | |||||
| break; | break; | ||||
| case "div": | case "div": | ||||
| result = _intTfDataTypes.Contains(x1.dtype) || _intTfDataTypes.Contains(y1.dtype) | |||||
| ? gen_math_ops.floor_div(x1, y1, name: scope) | |||||
| : gen_math_ops.real_div(x1, y1, name: scope); | |||||
| result = math_ops.div(x1, y1, name: scope); | |||||
| break; | break; | ||||
| case "floordiv": | case "floordiv": | ||||
| result = gen_math_ops.floor_div(x1, y1, name: scope); | result = gen_math_ops.floor_div(x1, y1, name: scope); | ||||
| break; | break; | ||||
| case "truediv": | case "truediv": | ||||
| result = gen_math_ops.real_div(x1, y1, name: scope); | |||||
| result = math_ops.truediv(x1, y1, name: scope); | |||||
| break; | break; | ||||
| case "mul": | case "mul": | ||||
| result = gen_math_ops.mul(x1, y1, name: scope); | result = gen_math_ops.mul(x1, y1, name: scope); | ||||
| @@ -102,6 +102,9 @@ namespace Tensorflow | |||||
| [JsonIgnore] | [JsonIgnore] | ||||
| #endif | #endif | ||||
| public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | ||||
| #if SERIALIZABLE | |||||
| [JsonIgnore] | |||||
| #endif | |||||
| public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | ||||
| public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | ||||
| #if SERIALIZABLE | #if SERIALIZABLE | ||||
| @@ -0,0 +1,15 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| namespace Tensorflow.Training | |||||
| { | |||||
| /// <summary> | |||||
| /// A coordinator for threads | |||||
| /// </summary> | |||||
| public class Coordinator | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -1,26 +1,26 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| namespace Tensorflow.Train | namespace Tensorflow.Train | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Optimizer that implements the gradient descent algorithm. | |||||
| /// <summary> | |||||
| /// Optimizer that implements the gradient descent algorithm. | |||||
| /// </summary> | /// </summary> | ||||
| public class GradientDescentOptimizer : Optimizer | public class GradientDescentOptimizer : Optimizer | ||||
| { | |||||
| { | |||||
| /// <summary> | /// <summary> | ||||
| /// Construct a new gradient descent optimizer. | /// Construct a new gradient descent optimizer. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -41,9 +41,9 @@ namespace Tensorflow.Train | |||||
| { | { | ||||
| _lr = learning_rate; | _lr = learning_rate; | ||||
| _useTensor = false; | _useTensor = false; | ||||
| } | |||||
| public GradientDescentOptimizer(Tensor learning_rate, bool use_locking = false, string name = "GradientDescent") | |||||
| } | |||||
| public GradientDescentOptimizer(Tensor learning_rate, bool use_locking = false, string name = "GradientDescent") | |||||
| : base(learning_rate, use_locking, name) | : base(learning_rate, use_locking, name) | ||||
| { | { | ||||
| _lr_t = learning_rate; | _lr_t = learning_rate; | ||||
| @@ -52,10 +52,10 @@ namespace Tensorflow.Train | |||||
| public override void _prepare() | public override void _prepare() | ||||
| { | { | ||||
| if(!_useTensor) | |||||
| { | |||||
| var lr = _call_if_callable(_lr); | |||||
| _lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); | |||||
| if(!_useTensor) | |||||
| { | |||||
| var lr = _call_if_callable(_lr); | |||||
| _lr_t = ops.convert_to_tensor(lr, name: "learning_rate"); | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,43 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| namespace Tensorflow.Training | |||||
| { | |||||
| public class SecondOrStepTimer : _HookTimer | |||||
| { | |||||
| int _every_secs = 60; | |||||
| int _every_steps = 0; | |||||
| int _last_triggered_step = 0; | |||||
| int _last_triggered_time = 0; | |||||
| public SecondOrStepTimer(int every_secs, int every_steps) | |||||
| { | |||||
| _every_secs = every_secs; | |||||
| _every_steps = every_steps; | |||||
| } | |||||
| public override void reset() | |||||
| { | |||||
| _last_triggered_step = 0; | |||||
| _last_triggered_time = 0; | |||||
| } | |||||
| public override int last_triggered_step() | |||||
| { | |||||
| return _last_triggered_step; | |||||
| } | |||||
| public override bool should_trigger_for_step(int step) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override void update_last_triggered_step(int step) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -2,7 +2,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Train | |||||
| namespace Tensorflow.Training | |||||
| { | { | ||||
| public class SessionRunArgs | public class SessionRunArgs | ||||
| { | { | ||||
| @@ -2,7 +2,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Train | |||||
| namespace Tensorflow.Training | |||||
| { | { | ||||
| public class SessionRunContext | public class SessionRunContext | ||||
| { | { | ||||
| @@ -0,0 +1,52 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| namespace Tensorflow.Training | |||||
| { | |||||
| /// <summary> | |||||
| /// Hook to extend calls to MonitoredSession.run(). | |||||
| /// </summary> | |||||
| public abstract class SessionRunHook | |||||
| { | |||||
| /// <summary> | |||||
| /// Called once before using the session. | |||||
| /// </summary> | |||||
| public virtual void begin() | |||||
| { | |||||
| } | |||||
| /// <summary> | |||||
| /// Called when new TensorFlow session is created. | |||||
| /// </summary> | |||||
| /// <param name="session"></param> | |||||
| /// <param name="coord"></param> | |||||
| public virtual void after_create_session(Session session, Coordinator coord) | |||||
| { | |||||
| } | |||||
| /// <summary> | |||||
| /// Called before each call to run(). | |||||
| /// </summary> | |||||
| /// <param name="run_context"></param> | |||||
| public virtual void before_run(SessionRunContext run_context) | |||||
| { | |||||
| } | |||||
| /// <summary> | |||||
| /// Called after each call to run(). | |||||
| /// </summary> | |||||
| public virtual void after_run(SessionRunContext run_context, SessionRunValues run_values) | |||||
| { | |||||
| } | |||||
| /// <summary> | |||||
| /// Called at the end of session. | |||||
| /// </summary> | |||||
| public virtual void end(Session session) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -2,7 +2,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Train | |||||
| namespace Tensorflow.Training | |||||
| { | { | ||||
| public class SessionRunValues | public class SessionRunValues | ||||
| { | { | ||||
| @@ -16,7 +16,7 @@ namespace Tensorflow.Train | |||||
| // Create in proper graph and base name_scope. | // Create in proper graph and base name_scope. | ||||
| var g = graph.as_default(); | var g = graph.as_default(); | ||||
| g.name_scope(null); | g.name_scope(null); | ||||
| var v = tf.get_variable(tf.GraphKeys.GLOBAL_STEP, new TensorShape(), dtype: dtypes.int64, | |||||
| var v = tf.get_variable(tf.GraphKeys.GLOBAL_STEP, new int[0], dtype: dtypes.int64, | |||||
| initializer: tf.zeros_initializer, | initializer: tf.zeros_initializer, | ||||
| trainable: false, | trainable: false, | ||||
| aggregation: VariableAggregation.OnlyFirstReplica, | aggregation: VariableAggregation.OnlyFirstReplica, | ||||
| @@ -0,0 +1,38 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| namespace Tensorflow.Training | |||||
| { | |||||
| /// <summary> | |||||
| /// Base timer for determining when Hooks should trigger. | |||||
| /// </summary> | |||||
| public abstract class _HookTimer | |||||
| { | |||||
| /// <summary> | |||||
| /// Resets the timer. | |||||
| /// </summary> | |||||
| public abstract void reset(); | |||||
| /// <summary> | |||||
| /// Return true if the timer should trigger for the specified step. | |||||
| /// </summary> | |||||
| /// <param name="step"></param> | |||||
| /// <returns></returns> | |||||
| public abstract bool should_trigger_for_step(int step); | |||||
| /// <summary> | |||||
| /// Update the last triggered time and step number. | |||||
| /// </summary> | |||||
| /// <param name="step"></param> | |||||
| public abstract void update_last_triggered_step(int step); | |||||
| /// <summary> | |||||
| /// Returns the last triggered time step or None if never triggered. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public abstract int last_triggered_step(); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,29 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| namespace Tensorflow.Training | |||||
| { | |||||
| public class learning_rate_decay | |||||
| { | |||||
| /// <summary> | |||||
| /// Applies a polynomial decay to the learning rate. | |||||
| /// </summary> | |||||
| /// <param name="learning_rate"></param> | |||||
| /// <param name="global_step"></param> | |||||
| /// <param name="decay_steps"></param> | |||||
| /// <param name="end_learning_rate"></param> | |||||
| /// <param name="power"></param> | |||||
| /// <param name="cycle"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor polynomial_decay(float learning_rate, RefVariable global_step, float decay_steps, | |||||
| float end_learning_rate = 0.0001f, float power = 1.0f, bool cycle = false, | |||||
| string name = null) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -43,7 +43,7 @@ namespace Tensorflow | |||||
| protected Graph _graph; | protected Graph _graph; | ||||
| bool _building_function; | bool _building_function; | ||||
| public variable_scope(string name, | |||||
| public variable_scope(string name, | |||||
| string default_name = "", | string default_name = "", | ||||
| Tensor[] values = null, | Tensor[] values = null, | ||||
| bool? reuse = null, | bool? reuse = null, | ||||
| @@ -113,7 +113,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| // Reenter the current name scope | // Reenter the current name scope | ||||
| string name_scope = ops.get_name_scope(); | string name_scope = ops.get_name_scope(); | ||||
| if(!string.IsNullOrEmpty(name_scope)) | |||||
| if (!string.IsNullOrEmpty(name_scope)) | |||||
| // Hack to reenter | // Hack to reenter | ||||
| name_scope += "/"; | name_scope += "/"; | ||||
| current_name_scope = ops.name_scope(name_scope); | current_name_scope = ops.name_scope(name_scope); | ||||
| @@ -128,8 +128,8 @@ namespace Tensorflow | |||||
| string current_name_scope_name = current_name_scope; | string current_name_scope_name = current_name_scope; | ||||
| _current_name_scope = current_name_scope; | _current_name_scope = current_name_scope; | ||||
| string old_name_scope = _scope == null ? current_name_scope_name : _scope.original_name_scope; | string old_name_scope = _scope == null ? current_name_scope_name : _scope.original_name_scope; | ||||
| if(_scope == null) | |||||
| if (_scope == null) | |||||
| pure_variable_scope = new PureVariableScope(_name, old_name_scope: old_name_scope); | pure_variable_scope = new PureVariableScope(_name, old_name_scope: old_name_scope); | ||||
| else | else | ||||
| pure_variable_scope = new PureVariableScope(_scope, old_name_scope: old_name_scope); | pure_variable_scope = new PureVariableScope(_scope, old_name_scope: old_name_scope); | ||||
| @@ -179,7 +179,7 @@ namespace Tensorflow | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| int[] shape = null, | int[] shape = null, | ||||
| bool validate_shape = false, | bool validate_shape = false, | ||||
| bool ? use_resource = null, | |||||
| bool? use_resource = null, | |||||
| VariableSynchronization synchronization = VariableSynchronization.Auto, | VariableSynchronization synchronization = VariableSynchronization.Auto, | ||||
| VariableAggregation aggregation = VariableAggregation.None) | VariableAggregation aggregation = VariableAggregation.None) | ||||
| { | { | ||||
| @@ -189,7 +189,7 @@ namespace Tensorflow | |||||
| use_resource = get_variable_scope().use_resource; | use_resource = get_variable_scope().use_resource; | ||||
| } | } | ||||
| if(!use_resource.HasValue) | |||||
| if (!use_resource.HasValue) | |||||
| use_resource = _DEFAULT_USE_RESOURCE; | use_resource = _DEFAULT_USE_RESOURCE; | ||||
| if (use_resource.Value) | if (use_resource.Value) | ||||
| @@ -204,7 +204,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| return new RefVariable(initial_value, | |||||
| return new RefVariable(initial_value, | |||||
| trainable: trainable.Value, | trainable: trainable.Value, | ||||
| validate_shape: validate_shape, | validate_shape: validate_shape, | ||||
| collections: collections, | collections: collections, | ||||
| @@ -251,7 +251,7 @@ namespace Tensorflow | |||||
| default: | default: | ||||
| throw new InvalidOperationException("get_variable_scope_store"); | throw new InvalidOperationException("get_variable_scope_store"); | ||||
| } | } | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| @@ -271,7 +271,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| trainable = true; | trainable = true; | ||||
| } | } | ||||
| return trainable.Value; | return trainable.Value; | ||||
| } | } | ||||
| @@ -294,7 +294,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| // TODO for Switch/Case | // TODO for Switch/Case | ||||
| public static RefVariable get_variable(string embeddingMatrix, IInitializer initializer, bool use_resource, | |||||
| public static RefVariable get_variable(string embeddingMatrix, IInitializer initializer, bool use_resource, | |||||
| TensorShape shape = null, | TensorShape shape = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| bool trainable = false, | bool trainable = false, | ||||
| @@ -305,12 +305,12 @@ namespace Tensorflow | |||||
| public void __init__() | public void __init__() | ||||
| { | { | ||||
| } | } | ||||
| public void __del__() | public void __del__() | ||||
| { | { | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -41,13 +41,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| var all = new List<VariableV1>(); | var all = new List<VariableV1>(); | ||||
| var collection = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); | |||||
| if(collection != null) | |||||
| all.AddRange(collection as List<VariableV1>); | |||||
| collection = ops.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS, scope); | |||||
| if (collection != null) | |||||
| all.AddRange(collection as List<VariableV1>); | |||||
| all.AddRange(ops.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, scope)); | |||||
| all.AddRange(ops.get_collection<VariableV1>(tf.GraphKeys.SAVEABLE_OBJECTS, scope)); | |||||
| return all.ToArray(); | return all.ToArray(); | ||||
| } | } | ||||
| @@ -65,9 +60,8 @@ namespace Tensorflow | |||||
| /// <returns>A list of `Variable` objects.</returns> | /// <returns>A list of `Variable` objects.</returns> | ||||
| public static List<VariableV1> global_variables(string scope = null) | public static List<VariableV1> global_variables(string scope = null) | ||||
| { | { | ||||
| var result = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); | |||||
| return ops.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, scope); | |||||
| return result == null ? new List<VariableV1>() : result as List<VariableV1>; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -1,24 +0,0 @@ | |||||
| using System; | |||||
| namespace TensorFlowDatasets | |||||
| { | |||||
| /// <summary> | |||||
| /// Abstract base class for all datasets. | |||||
| /// </summary> | |||||
| public class DatasetBuilder | |||||
| { | |||||
| /// <summary> | |||||
| /// Downloads and prepares dataset for reading. | |||||
| /// </summary> | |||||
| /// <param name="download_dir"> | |||||
| /// directory where downloaded files are stored. | |||||
| /// </param> | |||||
| /// <param name="download_config"> | |||||
| /// further configuration for downloading and preparing dataset. | |||||
| /// </param> | |||||
| public void download_and_prepare(string download_dir = null, DownloadConfig download_config = null) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,10 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace TensorFlowDatasets | |||||
| { | |||||
| public class DownloadConfig | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -1,21 +0,0 @@ | |||||
| <Project Sdk="Microsoft.NET.Sdk"> | |||||
| <PropertyGroup> | |||||
| <TargetFramework>netcoreapp2.2</TargetFramework> | |||||
| <PackageId>SciSharp.TensorFlowDatasets</PackageId> | |||||
| <Version>0.0.1</Version> | |||||
| <Authors>SciSharp Team</Authors> | |||||
| <Product>TensorFlow Datasets</Product> | |||||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||||
| <PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&v=4</PackageIconUrl> | |||||
| <PackageProjectUrl>http://scisharpstack.org</PackageProjectUrl> | |||||
| <Description>TensorFlow Datasets provides many public datasets as tf.data.Datasets.</Description> | |||||
| <RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl> | |||||
| <RepositoryType>git</RepositoryType> | |||||
| <PackageTags>SciSharp, Dataset, TensorFlow</PackageTags> | |||||
| <Copyright>Apache 2.0</Copyright> | |||||
| <RootNamespace>TensorFlow.Datasets</RootNamespace> | |||||
| <AssemblyName>TensorFlow.Datasets</AssemblyName> | |||||
| </PropertyGroup> | |||||
| </Project> | |||||
| @@ -1,13 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using NumSharp; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| public abstract class DataSetBase : IDataSet | |||||
| { | |||||
| public NDArray Data { get; protected set; } | |||||
| public NDArray Labels { get; protected set; } | |||||
| } | |||||
| } | |||||
| @@ -1,46 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using NumSharp; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| public class Datasets<TDataSet> where TDataSet : IDataSet | |||||
| { | |||||
| public TDataSet Train { get; private set; } | |||||
| public TDataSet Validation { get; private set; } | |||||
| public TDataSet Test { get; private set; } | |||||
| public Datasets(TDataSet train, TDataSet validation, TDataSet test) | |||||
| { | |||||
| Train = train; | |||||
| Validation = validation; | |||||
| Test = test; | |||||
| } | |||||
| public (NDArray, NDArray) Randomize(NDArray x, NDArray y) | |||||
| { | |||||
| var perm = np.random.permutation(y.shape[0]); | |||||
| np.random.shuffle(perm); | |||||
| return (x[perm], y[perm]); | |||||
| } | |||||
| /// <summary> | |||||
| /// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method) | |||||
| /// </summary> | |||||
| /// <param name="x"></param> | |||||
| /// <param name="y"></param> | |||||
| /// <param name="start"></param> | |||||
| /// <param name="end"></param> | |||||
| /// <returns></returns> | |||||
| public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end) | |||||
| { | |||||
| var slice = new Slice(start, end); | |||||
| var x_batch = x[slice]; | |||||
| var y_batch = y[slice]; | |||||
| return (x_batch, y_batch); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,13 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using NumSharp; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| public interface IDataSet | |||||
| { | |||||
| NDArray Data { get; } | |||||
| NDArray Labels { get; } | |||||
| } | |||||
| } | |||||
| @@ -1,14 +0,0 @@ | |||||
| using System; | |||||
| using System.Threading.Tasks; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using NumSharp; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| public interface IModelLoader<TDataSet> | |||||
| where TDataSet : IDataSet | |||||
| { | |||||
| Task<Datasets<TDataSet>> LoadAsync(ModelLoadSetting setting); | |||||
| } | |||||
| } | |||||