@@ -26,12 +26,12 @@ In comparison to other projects, like for instance [TensorFlowSharp](https://www

### How to use

| TensorFlow                | tf native 1.14 | tf native 1.15 | tf native 2.3 |
| ------------------------- | -------------- | -------------- | ------------- |
| tf.net 0.3x, tf.keras 0.2 |                |                | x             |
| tf.net 0.2x               |                | x              | x             |
| tf.net 0.15               | x              | x              |               |
| tf.net 0.14               | x              |                |               |

| TensorFlow                | tf native 1.14, cuda 10.0 | tf native 1.15, cuda 10.0 | tf native 2.3, cuda 10.1 | tf native 2.4, cuda 11 |
| ------------------------- | ------------------------- | ------------------------- | ------------------------ | ---------------------- |
| tf.net 0.3x, tf.keras 0.2 |                           |                           | x                        | not compatible         |
| tf.net 0.2x               |                           | x                         | x                        |                        |
| tf.net 0.15               | x                         | x                         |                          |                        |
| tf.net 0.14               | x                         |                           |                          |                        |

For troubleshooting when running the examples or installing the packages, please refer to [this guide](tensorflowlib/README.md).
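If you're unsure which native binary your build actually loaded, a quick sanity check (a minimal sketch; it assumes `TensorFlow.NET` and a matching `SciSharp.TensorFlow.Redist` package are referenced) is to print the native library version and compare it against the table above:

```csharp
using System;
using static Tensorflow.Binding;

class VersionCheck
{
    static void Main()
    {
        // Prints the version of the native libtensorflow that tf.net loaded,
        // e.g. "2.3.1" for the tf.net 0.3x / tf.keras 0.2 row in the table above.
        Console.WriteLine(tf.VERSION);
    }
}
```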
@@ -22,11 +22,19 @@ https://www.nuget.org/packages/SciSharp.TensorFlow.Redist

Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5ba61ad0e400623821236bd117cc24c6cb77).

#### Download pre-built package

[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.4.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.4.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.4.0.tar.gz), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.4.0.zip)

#### Pack and Deploy ####

On Windows, the tar command does not support extracting archives with symlinks, so when `dotnet pack` runs on Windows it only packages the Windows binaries.

1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` in the `src/SciSharp.TensorFlow.Redist` directory on Linux.
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.3.1.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.4.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`
@@ -8,7 +8,7 @@
  </PropertyGroup>
  <ItemGroup>
    <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" />
    <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.1" />
  </ItemGroup>
  <ItemGroup>
@@ -574,7 +574,7 @@ would not be rank 1.", tensor.op.get_attr("axis")));
                return string.Join(string.Empty, nd.ToArray<byte>()
                    .Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString()));
            case TF_DataType.TF_BOOL:
                return (nd.GetByte(0) > 0).ToString();
                return nd.GetBoolean(0).ToString();
            case TF_DataType.TF_VARIANT:
            case TF_DataType.TF_RESOURCE:
                return "<unprintable>";
@@ -37,19 +37,38 @@ namespace Tensorflow.Keras.Engine.DataAdapters
                _steps_per_execution_value = args.StepsPerExecution.numpy();
            }
            _adapter = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs
            if (args.Dataset == null)
            {
                X = args.X,
                Y = args.Y,
                BatchSize = args.BatchSize,
                Steps = args.StepsPerEpoch,
                Epochs = args.Epochs - args.InitialEpoch,
                Shuffle = args.Shuffle,
                MaxQueueSize = args.MaxQueueSize,
                Worker = args.Workers,
                UseMultiprocessing = args.UseMultiprocessing,
                Model = args.Model
            });
                _adapter = new TensorLikeDataAdapter(new DataAdapterArgs
                {
                    X = args.X,
                    Y = args.Y,
                    BatchSize = args.BatchSize,
                    Steps = args.StepsPerEpoch,
                    Epochs = args.Epochs - args.InitialEpoch,
                    Shuffle = args.Shuffle,
                    MaxQueueSize = args.MaxQueueSize,
                    Worker = args.Workers,
                    UseMultiprocessing = args.UseMultiprocessing,
                    Model = args.Model
                });
            }
            else
            {
                _adapter = new DatasetAdapter(new DataAdapterArgs
                {
                    Dataset = args.Dataset,
                    BatchSize = args.BatchSize,
                    Steps = args.StepsPerEpoch,
                    Epochs = args.Epochs - args.InitialEpoch,
                    Shuffle = args.Shuffle,
                    MaxQueueSize = args.MaxQueueSize,
                    Worker = args.Workers,
                    UseMultiprocessing = args.UseMultiprocessing,
                    Model = args.Model
                });
            }
            _dataset = _adapter.GetDataset();
            _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
            _current_step = 0;
@@ -66,7 +85,8 @@ namespace Tensorflow.Keras.Engine.DataAdapters
            if (adapter_steps > -1)
                return adapter_steps;
            throw new NotImplementedException("");
            var size = dataset.dataset_cardinality();
            return size.numpy();
        }
        public IEnumerable<(int, OwnedIterator)> enumerate_epochs()
@@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
namespace Tensorflow.Keras.Engine.DataAdapters
{
    public class DatasetAdapter : IDataAdapter
    {
        DataAdapterArgs args;
        IDatasetV2 _dataset => args.Dataset;
        public DatasetAdapter(DataAdapterArgs args)
        {
            this.args = args;
        }
        public bool CanHandle(Tensor x, Tensor y = null)
        {
            throw new NotImplementedException();
        }
        public IDatasetV2 GetDataset()
            => _dataset;
        public int GetSize()
            => -1;
        public (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
        {
            if (y.TensorShape.ndim == 1)
                y = array_ops.expand_dims(y, axis: -1);
            return (x, y);
        }
    }
}
@@ -9,14 +9,14 @@ namespace Tensorflow.Keras.Engine.DataAdapters
    /// </summary>
    public class TensorLikeDataAdapter : IDataAdapter
    {
        TensorLikeDataAdapterArgs args;
        DataAdapterArgs args;
        int _size;
        int _batch_size;
        int num_samples;
        int num_full_batches;
        IDatasetV2 _dataset;
        public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args)
        public TensorLikeDataAdapter(DataAdapterArgs args)
        {
            this.args = args;
            _process_tensorlike();
@@ -39,10 +39,12 @@ namespace Tensorflow.Keras.Engine
            _input_coordinates = new List<KerasHistory>();
            _output_coordinates = new List<KerasHistory>();
            tensor_usage_count = new Dictionary<int, int>();
            if (this is Sequential)
                return;
            _init_graph_network(inputs, outputs);
        }
        void _init_graph_network(Tensors inputs, Tensors outputs)
        protected void _init_graph_network(Tensors inputs, Tensors outputs)
        {
            _is_graph_network = true;
            this.inputs = inputs;
@@ -9,10 +9,6 @@ namespace Tensorflow.Keras.Engine
    {
        LossesContainer compiled_loss;
        MetricsContainer compiled_metrics;
        public void compile(string optimizerName, ILossFunc lossName)
        {
            throw new NotImplementedException("");
        }
        public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics)
        {
@@ -29,12 +25,12 @@ namespace Tensorflow.Keras.Engine
            this.loss = loss;
        }
        public void compile(string optimizerName, string lossName)
        public void compile(string optimizer, string loss, string[] metrics)
        {
            switch (optimizerName)
            switch (optimizer)
            {
                case "rmsprop":
                    optimizer = new RMSprop(new RMSpropArgs
                    this.optimizer = new RMSprop(new RMSpropArgs
                    {
                    });
@@ -68,5 +68,49 @@ namespace Tensorflow.Keras.Engine
                Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
            }
        }
        public void fit(IDatasetV2 dataset,
            IDatasetV2 validation_data = null,
            int batch_size = -1,
            int epochs = 1,
            int verbose = 1,
            float validation_split = 0f,
            bool shuffle = true,
            int initial_epoch = 0,
            int max_queue_size = 10,
            int workers = 1,
            bool use_multiprocessing = false)
        {
            data_handler = new DataHandler(new DataHandlerArgs
            {
                Dataset = dataset,
                BatchSize = batch_size,
                InitialEpoch = initial_epoch,
                Epochs = epochs,
                Shuffle = shuffle,
                MaxQueueSize = max_queue_size,
                Workers = workers,
                UseMultiprocessing = use_multiprocessing,
                Model = this,
                StepsPerExecution = _steps_per_execution
            });
            stop_training = false;
            _train_counter.assign(0);
            Console.WriteLine($"Training...");
            foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
            {
                // reset_metrics();
                // callbacks.on_epoch_begin(epoch)
                // data_handler.catch_stop_iteration();
                IEnumerable<(string, Tensor)> results = null;
                foreach (var step in data_handler.steps())
                {
                    // callbacks.on_train_batch_begin(step)
                    results = step_function(iterator);
                }
                Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
            }
        }
    }
}
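Taken together with the string-based `compile(optimizer, loss, metrics)` overload above, the new `fit(IDatasetV2, ...)` entry point can be driven roughly as follows. This is a hedged sketch, not code from this PR: `model` stands for any already compiled Keras model, and it assumes `tf.data.Dataset.from_tensor_slices(features, labels)` and `tf.ones`/`tf.zeros` are exposed by this tf.net build.

```csharp
using Tensorflow;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;

static class FitWithDatasetSketch
{
    // Hedged sketch: feeds the new Model.fit(IDatasetV2, ...) overload with a tf.data pipeline.
    public static void Run(Model model)
    {
        var x = tf.ones(new TensorShape(10, 4));   // 10 samples, 4 features
        var y = tf.zeros(new TensorShape(10, 1));  // matching labels
        var dataset = tf.data.Dataset.from_tensor_slices(x, y).batch(2);

        // DataHandler routes this through the new DatasetAdapter because Dataset != null,
        // and infers steps per epoch from dataset.dataset_cardinality().
        model.fit(dataset, epochs: 3);
    }
}
```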
@@ -35,7 +35,7 @@ namespace Tensorflow.Keras.Engine
        public int[] node_indices;
        public int[] tensor_indices;
        public Tensors input_tensors => args.InputTensors;
        public Tensors input_tensors => is_input ? Outputs : args.InputTensors;
        public Tensors Outputs => args.Outputs;
        public TensorShape[] input_shapes;
        public TensorShape[] output_shapes;
@@ -17,6 +17,7 @@
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Utils;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Engine
@@ -25,36 +26,40 @@ namespace Tensorflow.Keras.Engine
    /// `Sequential` groups a linear stack of layers into a `tf.keras.Model`.
    /// `Sequential` provides training and inference features on this model.
    /// </summary>
    public class Sequential : Model
    public class Sequential : Functional
    {
        SequentialArgs args;
        bool _is_graph_network;
        Tensor inputs;
        Tensor outputs;
        bool computeOutputAndMaskJointly;
        bool autoTrackSubLayers;
        TensorShape inferredInputShape;
        bool hasExplicitInputShape;
        TF_DataType inputDType;
        List<ILayer> layers => args.Layers;
        public TensorShape output_shape => outputs.TensorShape;
        Tensors inputs;
        Tensors outputs;
        bool _compute_output_and_mask_jointly;
        bool _auto_track_sub_layers;
        TensorShape _inferred_input_shape;
        bool _has_explicit_input_shape;
        TF_DataType _input_dtype;
        public TensorShape output_shape => outputs[0].TensorShape;
        bool built = false;
        public Sequential(SequentialArgs args)
            : base(new ModelArgs
            {
                Name = args.Name
            })
            : base(args.Inputs, args.Outputs, name: args.Name)
        {
            this.args = args;
            if (args.Layers == null)
                args.Layers = new List<ILayer>();
            // SupportsMasking = true;
            computeOutputAndMaskJointly = true;
            autoTrackSubLayers = false;
            hasExplicitInputShape = false;
            _compute_output_and_mask_jointly = true;
            _auto_track_sub_layers = false;
            _has_explicit_input_shape = false;
            _is_graph_network = false;
            // Add to the model any layers passed to the constructor.
            if (args.Layers != null)
            {
                foreach (var layer in args.Layers)
                    add(layer as Layer);
            }
        }
        public void add(Tensor tensor)
@@ -71,7 +76,7 @@ namespace Tensorflow.Keras.Engine
        {
            built = false;
            var set_inputs = false;
            if (layers.Count == 0)
            if (_layers.Count == 0)
            {
                if (layer is InputLayer)
                {
@@ -83,7 +88,7 @@ namespace Tensorflow.Keras.Engine
                {
                    // Instantiate an input layer.
                    var x = keras.Input(
                        shape: layer.BatchInputShape,
                        batch_input_shape: layer.BatchInputShape,
                        dtype: layer.DType,
                        name: layer.Name + "_input");
@@ -99,36 +104,26 @@ namespace Tensorflow.Keras.Engine
                {
                    // If an input layer (placeholder) is available.
                    outputs = layer.InboundNodes[^1].Outputs;
                    inputs = layer_utils.get_source_inputs(outputs[0]);
                    built = true;
                    _has_explicit_input_shape = true;
                }
            }
            else if (outputs != null)
            {
                outputs = layer.Apply(outputs);
                built = true;
            }
            if (set_inputs || _is_graph_network)
            {
                _init_graph_network(inputs, outputs);
                _is_graph_network = true;
            }
            else
            {
            }
        }
        void _init_graph_network(Tensor inputs, Tensor outputs)
        {
            _is_graph_network = true;
            this.inputs = inputs;
            this.outputs = outputs;
            built = true;
            _map_graph_network(inputs, outputs);
        }
        void _map_graph_network(Tensor inputs, Tensor outputs)
        {
            layers.add(outputs.KerasHistory.Layer);
        }
    }
}
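For context, a hedged usage sketch of the reworked `Sequential` (illustrative, not code from this PR; it assumes `keras.Sequential()` and `keras.layers.Dense(...)` are exposed by this tf.keras build, and relies on the `add(Tensor)` and string-based `compile` overloads shown in these hunks):

```csharp
using Tensorflow;
using static Tensorflow.KerasApi;

// Sequential now derives from Functional; once inputs are known, set_inputs flips
// and the inherited _init_graph_network wires the graph network.
var model = keras.Sequential();

// add(Tensor) with a keras.Input placeholder sets the model inputs explicitly.
model.add(keras.Input(shape: new TensorShape(16)));
model.add(keras.layers.Dense(8));
model.add(keras.layers.Dense(1));

// The string-based compile(optimizer, loss, metrics) overload added in this diff.
model.compile("rmsprop", "mse", new[] { "mae" });
```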
@@ -62,16 +62,21 @@ namespace Tensorflow.Keras
        /// <returns></returns>
        public Tensor Input(TensorShape shape = null,
            int batch_size = -1,
            TensorShape batch_input_shape = null,
            TF_DataType dtype = TF_DataType.DtInvalid,
            string name = null,
            bool sparse = false,
            bool ragged = false,
            Tensor tensor = null)
        {
            if (batch_input_shape != null)
                shape = batch_input_shape.dims[1..];
            var args = new InputLayerArgs
            {
                Name = name,
                InputShape = shape,
                BatchInputShape = batch_input_shape,
                BatchSize = batch_size,
                DType = dtype,
                Sparse = sparse,
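The new `batch_input_shape` parameter carries the batch dimension explicitly; `keras.Input` derives the per-sample shape from it via the `dims[1..]` slice above and passes the full shape through as `BatchInputShape`. A hedged usage sketch (illustrative only):

```csharp
using Tensorflow;
using static Tensorflow.KerasApi;

// (32, 10) here yields a per-sample shape of (10), while the full (32, 10)
// batch shape is forwarded to the InputLayer as BatchInputShape.
var tokens = keras.Input(batch_input_shape: new TensorShape(32, 10), name: "tokens");
```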
@@ -23,5 +23,10 @@ namespace Tensorflow.Keras.Layers
            offset = math_ops.cast(args.Offset, args.DType);
            return math_ops.cast(inputs, args.DType) * scale + offset;
        }
        public override TensorShape ComputeOutputShape(TensorShape input_shape)
        {
            return input_shape;
        }
    }
}
@@ -1,4 +1,5 @@
using System;
using Tensorflow.Framework;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
@@ -15,6 +16,7 @@ namespace Tensorflow.Keras.Layers
        public Flatten(FlattenArgs args)
            : base(args)
        {
            this.args = args;
            args.DataFormat = conv_utils.normalize_data_format(args.DataFormat);
            input_spec = new InputSpec(min_ndim: 1);
            _channels_first = args.DataFormat == "channels_first";
@@ -31,8 +33,29 @@ namespace Tensorflow.Keras.Layers
            {
                return array_ops.reshape(inputs, new[] { inputs.shape[0], -1 });
            }
            else
            {
                var input_shape = inputs.shape;
                var rank = inputs.shape.rank;
                if (rank == 1)
                    return array_ops.expand_dims(inputs, axis: 1);
                var batch_dim = tensor_shape.dimension_value(input_shape[0]);
                if (batch_dim != -1)
                {
                    return array_ops.reshape(inputs, new[] { batch_dim, -1 });
                }
                throw new NotImplementedException("");
                var non_batch_dims = ((int[])input_shape)[1..];
                var num = 1;
                if (non_batch_dims.Length > 0)
                {
                    for (var i = 0; i < non_batch_dims.Length; i++)
                    {
                        num *= non_batch_dims[i];
                    }
                }
                return array_ops.reshape(inputs, new[] { inputs.shape[0], num });
            }
        }
    }
}
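The new else-branch keeps the batch dimension and multiplies the remaining dimensions together. A hedged usage sketch of the layer (assuming `keras.layers.Flatten()` is exposed by this tf.keras build; `Apply` is the `Layer.Apply` used elsewhere in this diff):

```csharp
using Tensorflow;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

// A (32, 2, 3) input collapses to (32, 6): the batch dimension is kept and the
// non-batch dims are multiplied together, matching the reshape computed above.
var inputs = tf.ones(new TensorShape(32, 2, 3));
var flatten = keras.layers.Flatten();
Tensor outputs = flatten.Apply(inputs);
print(outputs.TensorShape);  // expected: (32, 6)
```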
@@ -31,7 +31,7 @@ namespace Tensorflow.Keras
            img = tf.image.decode_image(
                img, channels: num_channels, expand_animations: false);
            img = tf.image.resize_images_v2(img, image_size, method: interpolation);
            img.set_shape((image_size[0], image_size[1], num_channels));
            // img.set_shape((image_size[0], image_size[1], num_channels));
            return img;
        }
    }
@@ -187,5 +187,34 @@ namespace Tensorflow.Keras.Utils
            var total = weight_shapes.Select(p => (int)np.prod(p.dims)).Sum();
            return total;
        }
        public static Tensors get_source_inputs(Tensor tensor, ILayer layer = null, int node_index = -1)
        {
            if (layer == null)
                (layer, node_index, _) = tensor.KerasHistory;
            if (layer.InboundNodes == null || layer.InboundNodes.Count == 0)
                return tensor;
            else
            {
                var node = layer.InboundNodes[node_index];
                if (node.is_input)
                    return node.input_tensors;
                else
                {
                    var source_tensors = new List<Tensor>();
                    foreach (var _layer in node.iterate_inbound())
                    {
                        (layer, node_index, tensor) = (_layer.Item1, _layer.Item2, _layer.Item4);
                        var previous_sources = get_source_inputs(tensor, layer, node_index);
                        foreach (var x in previous_sources)
                        {
                            // TODO: should we check whether x is already present?
                            source_tensors.append(x);
                        }
                    }
                    return source_tensors;
                }
            }
        }
    }
}
@@ -24,7 +24,7 @@ More information about [System.Drawing on Linux](<https://www.hanselman.com/blog

Before running, verify that CUDA and cuDNN are installed (TensorFlow v1.15 is compatible with CUDA v10.0 and cuDNN v7.4; TensorFlow v2.x is compatible with CUDA v10.2 and cuDNN v7.65), and make sure the installed CUDA version matches the native TensorFlow build you are using.

#### Mac OS

There is no GPU support for macOS.
There is no GPU support for macOS; in the future, TensorFlow will support the [Apple M1 chip](https://github.com/apple/tensorflow_macos).
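After installing a matching CUDA/cuDNN pair, one quick way to confirm the GPU build is picked up is to list the visible devices. This is a hypothetical check: `tf.config.list_physical_devices` mirrors the Python `tf.config` API and may not be exposed in every tf.net release.

```csharp
using System;
using static Tensorflow.Binding;

class GpuCheck
{
    static void Main()
    {
        // Hypothetical: assumes tf.config.list_physical_devices("GPU") is available
        // in this tf.net build; it mirrors the Python tf.config API.
        var gpus = tf.config.list_physical_devices("GPU");
        Console.WriteLine($"GPUs visible: {gpus.Length}");
    }
}
```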
#### GPU for Windows

@@ -37,9 +37,11 @@ PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU
PM> Install-Package SciSharp.TensorFlow.Redist-Linux-GPU
```

Since NuGet limits package size to 250 MB, we can't ship the Linux GPU version on NuGet; you can download the library from [Google TensorFlow Storage](https://storage.googleapis.com/tensorflow) instead.

### Download prebuilt binary manually

Tensorflow packages are built nightly and uploaded to GCS for all supported platforms. They are uploaded to the [libtensorflow-nightly](https://www.tensorflow.org/install/lang_c) GCS bucket and are indexed by operating system and date built.
TensorFlow packages are built nightly for all supported platforms and uploaded to the [libtensorflow-nightly](https://www.tensorflow.org/install/lang_c) GCS bucket, indexed by operating system and build date.

### Build from source for Windows
@@ -28,7 +28,7 @@ namespace TensorFlowNET.UnitTest.Basics
        public void decode_image()
        {
            var img = tf.image.decode_image(contents);
            Assert.AreEqual(img.name, "decode_image/cond_jpeg/Merge:0");
            Assert.AreEqual(img.name, "decode_image/Identity:0");
        }
        [TestMethod]
@@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest.Keras
                { 2, 3, 4, 5 },
                { 3, 4, 5, 6 }
            });
            model.compile("rmsprop", "mse");
            // model.compile("rmsprop", "mse");
            var output_array = model.predict(input_array);
            Assert.AreEqual((32, 10, 64), output_array.TensorShape);
        }
@@ -48,10 +48,10 @@
  <ItemGroup>
    <PackageReference Include="FluentAssertions" Version="5.10.3" />
    <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
    <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.7.1" />
    <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.8.3" />
    <PackageReference Include="MSTest.TestAdapter" Version="2.1.2" />
    <PackageReference Include="MSTest.TestFramework" Version="2.1.2" />
    <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" />
    <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.1" />
  </ItemGroup>
  <ItemGroup>