@@ -15,20 +15,6 @@
English | [中文](docs/README-CN.md)

**=========================================================**

### [Voting: Naming Convention Approach of v1.0.0](https://github.com/SciSharp/TensorFlow.NET/issues/1074)

Dear all,

We invite you to take part in the upcoming vote on the naming convention for TensorFlow.NET v1.0.0 in [#1074](https://github.com/SciSharp/TensorFlow.NET/issues/1074). Your vote will help us choose the best approach for improving the naming convention used in previous versions.

Thank you,

TensorFlow.NET Authors

**=========================================================**

*The master branch and v0.100.x correspond to tensorflow v2.10, the v0.6x branch to tensorflow v2.6, and v0.15-tensorflow1.15 to tensorflow 1.15. Please add `https://www.myget.org/F/scisharp/api/v3/index.json` to your nuget sources to use the nightly release.*
@@ -75,9 +61,12 @@ PM> Install-Package TensorFlow.Keras

The second part is the computing support package. Only one of the following packages is needed, depending on your device and system.

```
### CPU version for Windows, Linux and Mac
### CPU version for Windows and Linux
PM> Install-Package SciSharp.TensorFlow.Redist

### CPU version for MacOS
PM> Install-Package SciSharp.TensorFlow.Redist-OSX

### GPU version for Windows (CUDA and cuDNN are required)
PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU
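```

After installing the binding plus one redist package, a minimal smoke test looks like the following sketch (hedged: it only assumes the standard `tf` entry point from `Tensorflow.Binding`):

```csharp
using static Tensorflow.Binding;

// Build and run a tiny computation eagerly to verify the native runtime loads.
var x = tf.constant(new[,] { { 1f, 2f }, { 3f, 4f } });
var y = tf.matmul(x, x); // 2x2 matrix product
print(y);
```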
@@ -39,6 +39,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Benchmark", "too
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Console", "tools\TensorFlowNET.Console\Tensorflow.Console.csproj", "{1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TensorFlow.Kernel.UnitTest", "test\TensorFlow.Kernel.UnitTest\TensorFlow.Kernel.UnitTest.csproj", "{654A027D-1364-4729-880B-144DFE1FF5BB}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU

@@ -322,6 +324,24 @@ Global
{1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|x64.Build.0 = Release|x64
{1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|x86.ActiveCfg = Release|Any CPU
{1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|x86.Build.0 = Release|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|Any CPU.Build.0 = Debug|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|x64.ActiveCfg = Debug|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|x64.Build.0 = Debug|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|x86.ActiveCfg = Debug|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|x86.Build.0 = Debug|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|Any CPU.ActiveCfg = Debug|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|Any CPU.Build.0 = Debug|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|x64.ActiveCfg = Debug|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|x64.Build.0 = Debug|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|x86.ActiveCfg = Debug|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|x86.Build.0 = Debug|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.Release|Any CPU.ActiveCfg = Release|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.Release|Any CPU.Build.0 = Release|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.Release|x64.ActiveCfg = Release|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.Release|x64.Build.0 = Release|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.Release|x86.ActiveCfg = Release|Any CPU
{654A027D-1364-4729-880B-144DFE1FF5BB}.Release|x86.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE

@@ -342,6 +362,7 @@ Global
{D24FCAA5-548C-4251-B226-A1B6535D0845} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
{C23563DB-FE21-48E7-A411-87A109E4A899} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
{1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
{654A027D-1364-4729-880B-144DFE1FF5BB} = {1B0918B9-65AD-4F34-A287-AF4597B27DBD}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {2DEAD3CC-486B-4918-A607-50B0DE7B114A}
@@ -8,10 +8,10 @@ namespace Tensorflow
public partial class c_api
{
[DllImport(TensorFlowLibName)]
public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status);
public static extern void TF_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status);

[DllImport(TensorFlowLibName)]
public static extern SafeBufferHandle TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
public static extern SafeBufferHandle TF_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);

[DllImport(TensorFlowLibName)]
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
public static extern void TF_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
}
}
@@ -140,6 +140,16 @@ namespace Tensorflow
public Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0)
=> array_ops.gather(@params, indices, name: name, axis: ops.convert_to_tensor(axis));

/// <summary>
/// Gather slices from `params` into a Tensor with shape specified by `indices`.
/// </summary>
/// <param name="params">The tensor from which to gather values.</param>
/// <param name="indices">An index tensor whose innermost dimension indexes into `params`.</param>
/// <param name="name">An optional name for the operation.</param>
/// <returns>A tensor containing the gathered slices.</returns>
public Tensor gather_nd(Tensor @params, Tensor indices, string name = null)
=> gen_array_ops.gather_nd(@params, indices, name: name);
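A short usage sketch of the new binding (assuming it surfaces as `tf.gather_nd` alongside the existing `tf.gather`; values are illustrative):

```csharp
using static Tensorflow.Binding;

// Each row of `indices` addresses one element of the 2x2 matrix.
var params_ = tf.constant(new[,] { { 1, 2 }, { 3, 4 } });
var indices = tf.constant(new[,] { { 0, 0 }, { 1, 1 } });
var picked = tf.gather_nd(params_, indices); // -> [1, 4]
```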
/// <summary>
/// Return the elements, either from `x` or `y`, depending on the `condition`.
/// </summary>

@@ -339,6 +339,13 @@ namespace Tensorflow
=> image_ops_impl.decode_image(contents, channels: channels, dtype: dtype,
name: name, expand_animations: expand_animations);

public Tensor encode_png(Tensor contents, string name = null)
=> image_ops_impl.encode_png(contents, name: name);

public Tensor encode_jpeg(Tensor contents, string name = null)
=> image_ops_impl.encode_jpeg(contents, name: name);

/// <summary>
/// Convenience function to check if the 'contents' encodes a JPEG image.
/// </summary>
@@ -16,6 +16,7 @@
using System.Collections.Generic;
using Tensorflow.IO;
using Tensorflow.Operations;

namespace Tensorflow
{

@@ -46,6 +47,12 @@ namespace Tensorflow
public Tensor[] restore_v2(Tensor prefix, string[] tensor_names,
string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
=> ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name);

public Operation write_file(string filename, Tensor contents, string name = null)
=> write_file(Tensorflow.ops.convert_to_tensor(filename, TF_DataType.TF_STRING), contents, name);

public Operation write_file(Tensor filename, Tensor contents, string name = null)
=> gen_ops.write_file(filename, contents, name);
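Combined with the `encode_png` binding above, saving a tensor as an image might look like this sketch (hedged: it assumes these members surface as `tf.image.encode_png` and `tf.io.write_file`, and the filename is a placeholder):

```csharp
using Tensorflow;
using static Tensorflow.Binding;

// Encode an all-black 64x64 RGB image and write the PNG bytes to disk.
var image = tf.zeros(new Shape(64, 64, 3), dtype: TF_DataType.TF_UINT8); // HWC uint8
var png = tf.image.encode_png(image); // scalar string tensor holding PNG bytes
tf.io.write_file("sample.png", png);  // placeholder filename
```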
}

public GFile gfile = new GFile();

@@ -101,6 +101,8 @@ namespace Tensorflow
name: name);

public IActivation relu() => new relu();
public IActivation swish() => new swish();
public IActivation tanh() => new tanh();

@@ -111,6 +113,9 @@ namespace Tensorflow
public Tensor relu(Tensor features, string name = null)
=> gen_nn_ops.relu(features, name);

public Tensor relu6(Tensor features, string name = null)
=> gen_nn_ops.relu6(features, name);

public Tensor[] fused_batch_norm(Tensor x,
Tensor scale,
Tensor offset,
@@ -80,6 +80,11 @@ namespace Tensorflow.Eager
Tensor[] op_outputs)
=> (out_grads, unneeded_gradients) =>
{
if (!ops.gradientFunctions.ContainsKey(op_name))
{
throw new Exception($"No gradient function registered for op: {op_name}");
}

if (ops.gradientFunctions[op_name] == null)
return new Tensor[op_inputs.Length];
@@ -381,5 +381,48 @@ namespace Tensorflow.Gradients
var axis = op.inputs[1];
return new Tensor[] { array_ops.reverse(grad, axis), null };
}

[RegisterGradient("Tile")]
public static Tensor[] _TileGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
var input_shape = array_ops.shape(op.inputs[0], out_type: op.inputs[1].dtype);
var split_shape = array_ops.reshape(array_ops.transpose(array_ops.stack(new Tensor[] { op.inputs[1], input_shape })), new Shape(-1));
var axes = math_ops.range(0, array_ops.size(split_shape), 2);

//# Sum reduces grad along the first dimension for IndexedSlices
//if isinstance(grad, indexed_slices_lib.IndexedSlices):
//input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype)
//grad = math_ops.unsorted_segment_sum(
//    grad.values, math_ops.mod(grad.indices, input_shape_0), input_shape_0)
//split_shape = array_ops.concat([[1], split_shape[1:]], axis = 0)

var input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes);
if (!tf.Context.executing_eagerly())
{
input_grad.set_shape(op.inputs[0].GetShape());
}
return new Tensor[] { input_grad, null };
}

[RegisterGradient("GatherNd")]
public static Tensor[] _GatherNdGrad(Operation op, Tensor[] grads)
{
var @ref = op.inputs[0];
var indices = op.inputs[1];
var grad = grads[0];
var ref_shape = array_ops.shape(@ref, out_type: indices.dtype);
Tensor ref_grad = null;
if (indices.shape.ndim == 2 && indices.shape.dims[indices.shape.Length - 1] == 1)
{
ref_grad = (Tensor)new IndexedSlices(grad, array_ops.squeeze(indices, axis: -1), ref_shape);
}
else
{
ref_grad = gen_array_ops.scatter_nd(indices, grad, ref_shape);
}
return new Tensor[] { ref_grad, null };
}
}
}
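A hedged sketch that exercises the newly registered `Tile` gradient end to end through the eager tape (it assumes the usual `tf.GradientTape` surface in TensorFlow.NET; the `GatherNd` gradient is reached the same way via `tf.gather_nd`):

```csharp
using static Tensorflow.Binding;

var x = tf.constant(new[,] { { 1f, 2f } });
using var tape = tf.GradientTape();
tape.watch(x);
var tiled = tf.tile(x, tf.constant(new[] { 2, 1 })); // shape (2, 2)
var loss = tf.reduce_sum(tiled);
var grad = tape.gradient(loss, x); // each element is used twice -> [[2, 2]]
```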
@@ -229,6 +229,37 @@ namespace Tensorflow.Gradients
};
}

/// <summary>
/// Gradient function for DepthwiseConv2dNative.
/// </summary>
/// <param name="op">The forward DepthwiseConv2dNative operation.</param>
/// <param name="grads">The upstream gradients with respect to the op output.</param>
/// <returns>The gradients with respect to the input and the filter.</returns>
[RegisterGradient("DepthwiseConv2dNative")]
public static Tensor[] _DepthwiseConv2DGrad(Operation op, Tensor[] grads)
{
var dilations = op.get_attr_list<int>("dilations");
var strides = op.get_attr_list<int>("strides");
var padding = op.get_attr<string>("padding");
var explicit_paddings = op.get_attr_list<int>("explicit_paddings");
var data_format = op.get_attr<string>("data_format");
var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] });

return new Tensor[]
{
gen_nn_ops.depthwise_conv2d_native_backprop_input(
shape[0], op.inputs[1], grads[0],
strides, padding, explicit_paddings,
dilations: dilations,
data_format: data_format),
gen_nn_ops.depthwise_conv2d_native_backprop_filter(op.inputs[0], shape[1], grads[0],
strides, padding,
dilations: dilations,
explicit_paddings: explicit_paddings,
data_format: data_format)
};
}

[RegisterGradient("FusedBatchNorm")]
public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads)
=> _BaseFusedBatchNormGrad(op, 0, grads);
@@ -32,6 +32,7 @@ namespace Tensorflow.Keras
Activation Linear { get; }
Activation Relu { get; }
Activation Relu6 { get; }
Activation Sigmoid { get; }

@@ -1,5 +1,6 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition
{

@@ -16,5 +17,7 @@ namespace Tensorflow.Keras.ArgsDefinition
public int Worker { get; set; }
public bool UseMultiprocessing { get; set; }
public IModel Model { get; set; }
public Dictionary<int, float> ClassWeight = null;
public NDArray SampleWeight = null;
}
}

@@ -1,5 +1,6 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition
{

@@ -18,5 +19,7 @@ namespace Tensorflow.Keras.ArgsDefinition
public bool UseMultiprocessing { get; set; } = false;
public IModel Model { get; set; }
public IVariableV1 StepsPerExecution { get; set; }
public Dictionary<int, float> ClassWeight = null;
public NDArray SampleWeight = null;
}
}
@@ -1,13 +1,15 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
// TODO: complete the implementation
public class MergeArgs : LayerArgs
public class MergeArgs : AutoSerializeLayerArgs
{
public Tensors Inputs { get; set; }
[JsonProperty("axis")]
public int Axis { get; set; }
}
}

@@ -4,10 +4,8 @@ using System.Text;
namespace Tensorflow.Keras.ArgsDefinition
{
public class GRUOptionalArgs
public class GRUOptionalArgs : RnnOptionalArgs
{
public string Identifier => "GRU";

public Tensor Mask { get; set; } = null;
}
}

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class LSTMOptionalArgs : RnnOptionalArgs
{
public string Identifier => "LSTM";
}
}

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class SimpleRNNOptionalArgs : RnnOptionalArgs
{
public string Identifier => "SimpleRNN";
}
}
@@ -3,6 +3,7 @@ using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Util;

namespace Tensorflow.Keras.Engine;

@@ -22,8 +23,11 @@ public interface IModel : ILayer
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(NDArray val_x, NDArray val_y)? validation_data = null,
ValidationDataPack validation_data = null,
int validation_step = 10,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,

@@ -35,8 +39,24 @@ public interface IModel : ILayer
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
ValidationDataPack validation_data = null,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false);

public ICallback fit(IDatasetV2 dataset,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
IDatasetV2 validation_data = null,
int validation_step = 10, // how many steps pass between validation runs
bool shuffle = true,
Dictionary<int, float> class_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,

@@ -63,6 +83,8 @@ public interface IModel : ILayer
Dictionary<string, float> evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
NDArray sample_weight = null,
int steps = -1,
int max_queue_size = 10,
int workers = 1,

@@ -78,6 +100,14 @@ public interface IModel : ILayer
int workers = 1,
bool use_multiprocessing = false);

public Tensors predict(IDatasetV2 dataset,
int batch_size = -1,
int verbose = 0,
int steps = -1,
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false);

void summary(int line_length = -1, float[] positions = null);

IKerasConfig get_config();
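A hedged sketch of calling the extended `fit` overload with the new per-class and per-sample weights (`model`, `x`, and `y` stand in for a compiled Keras model and its NDArray training data):

```csharp
using System.Collections.Generic;
using Tensorflow.NumPy;

// Up-weight class 1 and give every sample an explicit (here uniform) weight.
var classWeight = new Dictionary<int, float> { [0] = 1.0f, [1] = 2.5f };
var sampleWeight = np.ones(new Shape(x.dims[0]), np.float32);

model.fit(x, y,
    batch_size: 32,
    epochs: 5,
    validation_split: 0.2f,
    class_weight: classWeight,
    sample_weight: sampleWeight);
```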
@@ -55,6 +55,12 @@ namespace Tensorflow.Keras.Layers
string kernel_initializer = "glorot_uniform",
string bias_initializer = "zeros");

public ILayer Conv2D(int filters,
Shape kernel_size = null,
Shape strides = null,
string padding = "valid"
);

public ILayer Conv2D(int filters,
Shape kernel_size = null,
Shape strides = null,

@@ -95,6 +101,19 @@ namespace Tensorflow.Keras.Layers
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string bias_initializer = "zeros");

public ILayer DepthwiseConv2D(Shape kernel_size = null,
Shape strides = null,
string padding = "valid",
string data_format = null,
Shape dilation_rate = null,
int groups = 1,
int depth_multiplier = 1,
string activation = null,
bool use_bias = false,
string kernel_initializer = "glorot_uniform",
string bias_initializer = "zeros",
string depthwise_initializer = "glorot_uniform"
);

public ILayer Dense(int units);
public ILayer Dense(int units,

@@ -161,6 +180,9 @@ namespace Tensorflow.Keras.Layers
public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false);
public ILayer LeakyReLU(float alpha = 0.3f);

public ILayer ReLU6();

public IRnnCell LSTMCell(int uints,
string activation = "tanh",
string recurrent_activation = "sigmoid",
@@ -30,6 +30,15 @@ namespace Tensorflow.NumPy
[AutoNumPy]
public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays));

[AutoNumPy]
public static NDArray stack(NDArray[] arrays, int axis = 0) => new NDArray(array_ops.stack(arrays, axis));

[AutoNumPy]
public static NDArray stack((NDArray, NDArray) tuple, int axis = 0) => new NDArray(array_ops.stack(new[] { tuple.Item1, tuple.Item2 }, axis));

[AutoNumPy]
public static NDArray stack((NDArray, NDArray, NDArray) tuple, int axis = 0) => new NDArray(array_ops.stack(new[] { tuple.Item1, tuple.Item2, tuple.Item3 }, axis));

[AutoNumPy]
public static NDArray moveaxis(NDArray array, Axis source, Axis destination) => new NDArray(array_ops.moveaxis(array, source, destination));
}
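A short usage sketch of the new overloads (illustrative values):

```csharp
using Tensorflow.NumPy;

var a = np.array(new[] { 1, 2 });
var b = np.array(new[] { 3, 4 });
var rows = np.stack((a, b));            // axis 0 -> [[1, 2], [3, 4]]
var cols = np.stack(new[] { a, b }, 1); // axis 1 -> [[1, 3], [2, 4]]
```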
@@ -437,7 +437,7 @@ namespace Tensorflow
internal void _set_attr_with_buf(string attr_name, Buffer attr_buf)
{
Status status = new();
c_api.TFC_SetAttr(graph, _handle, attr_name, attr_buf, status);
c_api.TF_SetAttr(graph, _handle, attr_name, attr_buf, status);
status.Check(true);
}
}

@@ -166,6 +166,11 @@ namespace Tensorflow
throw new ValueError("mask cannot be scalar.");
var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], ops.convert_to_tensor(new[] { 0 }));
if (leading_size.rank == 0)
{
leading_size = expand_dims(leading_size, 0);
}
var shape1 = concat(new[]
{
shape(tensor_tensor)[$":{axis}"],

@@ -185,7 +190,7 @@ namespace Tensorflow
private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0)
{
var indices = squeeze(where(mask), axis: new[] { 1 });
var indices = squeeze(where_v2(mask), axis: new[] { 1 });
return gather(reshaped_tensor, indices, axis: ops.convert_to_tensor(axis));
}

@@ -829,7 +834,7 @@ namespace Tensorflow
/// <returns>A `Tensor`. Has the same type as `input`.
/// Contains the same data as `input`, but has one or more dimensions of
/// size 1 removed.</returns>
public static Tensor squeeze(Tensor input, int[] axis = null, string name = null)
public static Tensor squeeze(Tensor input, Axis axis = null, string name = null)
=> gen_array_ops.squeeze(input, axis, name);

public static Tensor identity(Tensor input, string name = null)
@@ -990,7 +995,7 @@ namespace Tensorflow
return @params.sparse_read(indices, name);
}

public static Tensor transpose<T1>(T1 a, Axis perm, string name = "transpose", bool conjugate = false)
public static Tensor transpose<T1>(T1 a, Axis perm = null, string name = "transpose", bool conjugate = false)
{
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{

@@ -1139,5 +1144,18 @@ namespace Tensorflow
var _op = tf.OpDefLib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape });
return _op.output;
}

public static int get_positive_axis(int axis, int ndims = -100, string axis_name = "axis", string ndims_name = "ndims")
{
if (ndims != -100)
{
if (axis >= 0 && axis < ndims) return axis;
else if (-ndims <= axis && axis < 0) return axis + ndims;
else throw new ValueError($"{axis_name}={axis} out of bounds: expected {-ndims} <= {axis_name} < {ndims}");
}
else if (axis < 0) throw new ValueError($"{axis_name}={axis} may only be negative if {ndims_name} is statically known.");
return axis;
}
}
}
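Worked examples of the axis normalization above (a sketch; `-100` is the sentinel meaning "ndims unknown"):

```csharp
array_ops.get_positive_axis(-1, ndims: 3); // -> 2 (counted from the end)
array_ops.get_positive_axis(1, ndims: 3);  // -> 1 (already positive)
array_ops.get_positive_axis(3, ndims: 3);  // throws ValueError: out of bounds
array_ops.get_positive_axis(-1);           // throws ValueError: ndims not statically known
```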
@@ -51,7 +51,7 @@ namespace Tensorflow.Operations
}
Status status = new();
var proto = handle_data.ToByteArray();
c_api.TFC_SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), proto, proto.Length, status);
c_api.TF_SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), proto, proto.Length, status);
status.Check(true);
}

@@ -102,7 +102,10 @@ namespace Tensorflow
{
throw new ValueError("\'image\' must be fully defined.");
}
var dims = image_shape["-3:"];
var dims = new Shape(new[] {
image_shape.dims[image_shape.dims.Length - 3],
image_shape.dims[image_shape.dims.Length - 2],
image_shape.dims[image_shape.dims.Length - 1]});
foreach (var dim in dims.dims)
{
if (dim == 0)

@@ -112,16 +115,18 @@ namespace Tensorflow
}
var image_shape_last_three_elements = new Shape(new[] {
image_shape.dims[image_shape.dims.Length - 1],
image_shape.dims[image_shape.dims.Length - 3],
image_shape.dims[image_shape.dims.Length - 2],
image_shape.dims[image_shape.dims.Length - 3]});
image_shape.dims[image_shape.dims.Length - 1]});
if (!image_shape_last_three_elements.IsFullyDefined)
{
Tensor image_shape_ = array_ops.shape(image);
var image_shape_return = tf.constant(new[] {
image_shape_.dims[image_shape.dims.Length - 1],
image_shape_.dims[image_shape.dims.Length - 2],
image_shape_.dims[image_shape.dims.Length - 3]});
var image_shape_return = tf.slice(image_shape_, new[] { Math.Max(image_shape.dims.Length - 3, 0) }, new[] { 3 });
//var image_shape_return = tf.constant(new[] {
//    image_shape_.dims[image_shape_.dims.Length - 3],
//    image_shape_.dims[image_shape_.dims.Length - 2],
//    image_shape_.dims[image_shape_.dims.Length - 1]});
return new Operation[] {
check_ops.assert_positive(

@@ -209,10 +214,10 @@ namespace Tensorflow
}
public static Tensor flip_left_right(Tensor image)
=> _flip(image, 0, "flip_left_right");
=> _flip(image, 1, "flip_left_right");

public static Tensor flip_up_down(Tensor image)
=> _flip(image, 1, "flip_up_down");
=> _flip(image, 0, "flip_up_down");
internal static Tensor _flip(Tensor image, int flip_index, string scope_name)
{

@@ -223,11 +228,11 @@ namespace Tensorflow
Shape shape = image.shape;
if (shape.ndim == 3 || shape.ndim == Unknown)
{
return fix_image_flip_shape(image, gen_array_ops.reverse(image, ops.convert_to_tensor(new int[] { flip_index })));
return fix_image_flip_shape(image, gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new int[] { flip_index })));
}
else if (shape.ndim == 4)
{
return gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new[] { (flip_index + 1) % 2 }));
return gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new[] { flip_index + 1 }));
}
else
{

@@ -2047,6 +2052,22 @@ new_height, new_width");
});
}

public static Tensor encode_jpeg(Tensor contents, string name = null)
{
return tf_with(ops.name_scope(name, "encode_jpeg"), scope =>
{
return gen_ops.encode_jpeg(contents, name: name);
});
}

public static Tensor encode_png(Tensor contents, string name = null)
{
return tf_with(ops.name_scope(name, "encode_png"), scope =>
{
return gen_ops.encode_png(contents, name: name);
});
}

public static Tensor is_jpeg(Tensor contents, string name = null)
{
return tf_with(ops.name_scope(name, "is_jpeg"), scope =>
@@ -4,8 +4,8 @@
<TargetFrameworks>netstandard2.0;net6.0</TargetFrameworks>
<AssemblyName>Tensorflow.Binding</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>2.11.0</TargetTensorFlow>
<Version>0.110.3</Version>
<TargetTensorFlow>2.15.0</TargetTensorFlow>
<Version>0.150.0</Version>
<LangVersion>10.0</LangVersion>
<Nullable>enable</Nullable>
<Authors>Haiping Chen, Eli Belash, Yaohui Liu, Meinrad Recheis</Authors>

@@ -20,12 +20,16 @@
<Description>Google's TensorFlow full binding in .NET Standard.
Building, training and inferring deep learning models.
https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.110.3.0</AssemblyVersion>
<AssemblyVersion>0.150.0.0</AssemblyVersion>
<PackageReleaseNotes>
tf.net 0.150.x and above are based on tensorflow native 2.15.0
* Support BERT model.

tf.net 0.110.x and above are based on tensorflow native 2.11.0
* Support RNN, LSTM model.
* Support Transformer model.
* Added IMDB dataset.

tf.net 0.100.x and above are based on tensorflow native 2.10.0
* Eager Mode is added finally.

@@ -42,8 +46,9 @@ https://tensorflownet.readthedocs.io</Description>
tf.net 0.7x.x aligns with TensorFlow v2.7.x native library.
tf.net 0.10x.x aligns with TensorFlow v2.10.x native library.
tf.net 0.11x.x aligns with TensorFlow v2.11.x native library.
tf.net 0.15x.x aligns with TensorFlow v2.15.x native library.
</PackageReleaseNotes>
<FileVersion>0.110.3.0</FileVersion>
<FileVersion>0.150.0.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<PackageOutputPath>packages</PackageOutputPath>

@@ -174,8 +179,8 @@ https://tensorflownet.readthedocs.io</Description>
<ItemGroup>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.149" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="OneOf" Version="3.0.255" />
<PackageReference Include="Protobuf.Text" Version="0.7.1" />
<PackageReference Include="OneOf" Version="3.0.263" />
<PackageReference Include="Protobuf.Text" Version="0.7.2" />
<PackageReference Include="Razorvine.Pickle" Version="1.4.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
</ItemGroup>
@@ -163,5 +163,38 @@ namespace Tensorflow
{
return tensor.Tag as RaggedTensor;
}

public Tensor nrows(TF_DataType out_type, string name = null)
{
return tf_with(ops.name_scope(name, "RaggedNRows"), scope =>
    math_ops.cast(this._row_partition.nrows(), dtype: out_type));
}

public RaggedTensor row_lengths(int axis = -1, string name = null)
{
if (axis == 0) return this._row_partition.nrows();
if (axis == 1) return this._row_partition.row_lengths();
var values = (RaggedTensor)this._values;
axis = array_ops.get_positive_axis(
    axis, this.shape.rank, ndims_name: "rank(this)");
if (axis == 0) return this.nrows(this._row_partition.GetDataType());
else if (axis == 1)
{
var splits = this._row_partition.row_splits;
return splits[new Slice(start: 1)] - splits[new Slice(stop: -1)];
}
else if (this._values is RaggedTensor)
{
return values.row_lengths(axis - 1);
}
else
{
var shape = array_ops.shape(values, out_type: this._row_partition.GetDataType());
return array_ops.ones(shape[new Slice(stop: axis - 1)], this._row_partition.GetDataType()) *
    shape[axis - 1];
}
}
}
}
@@ -14,10 +14,15 @@
limitations under the License.
******************************************************************************/

using Serilog.Debugging;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
//using System.ComponentModel.DataAnnotations;
using System.Text;
using System.Xml.Linq;
using Tensorflow.Framework;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

namespace Tensorflow

@@ -99,5 +104,55 @@ namespace Tensorflow
return new RowPartition(row_splits);
});
}

public static RowPartition from_row_lengths(Tensor row_lengths,
bool validate = true,
TF_DataType dtype = TF_DataType.TF_INT32,
TF_DataType dtype_hint = TF_DataType.TF_INT32)
{
row_lengths = _convert_row_partition(
    row_lengths, "row_lengths", dtype_hint: dtype_hint, dtype: dtype);
Tensor row_limits = math_ops.cumsum<Tensor>(row_lengths, tf.constant(-1));
Tensor row_splits = array_ops.concat(new Tensor[] { tf.convert_to_tensor(np.array(new int[] { 0 }, TF_DataType.TF_INT64)), row_limits }, axis: 0);
return new RowPartition(row_splits: row_splits, row_lengths: row_lengths);
}

public static Tensor _convert_row_partition(Tensor partition, string name, TF_DataType dtype,
TF_DataType dtype_hint = TF_DataType.TF_INT64)
{
if (partition is NDArray && partition.GetDataType() == np.int32) partition = ops.convert_to_tensor(partition, name: name);
if (partition.GetDataType() != np.int32 && partition.GetDataType() != np.int64) throw new ValueError($"{name} must have dtype int32 or int64");
return partition;
}

public Tensor nrows()
{
// Returns the number of rows created by this `RowPartition`.
if (this._nrows != null) return this._nrows;
var nsplits = tensor_shape.dimension_at_index(this._row_splits.shape, 0);
if (nsplits == null) return array_ops.shape(this._row_splits, out_type: this.row_splits.dtype)[0] - 1;
else return constant_op.constant(nsplits.value - 1, dtype: this.row_splits.dtype);
}

public Tensor row_lengths()
{
if (this._row_splits != null)
{
int nrows_plus_one = tensor_shape.dimension_value(this._row_splits.shape[0]);
return tf.constant(nrows_plus_one - 1);
}
if (this._row_lengths != null)
{
var nrows = tensor_shape.dimension_value(this._row_lengths.shape[0]);
return tf.constant(nrows);
}
if (this._nrows != null)
{
return tensor_util.constant_value(this._nrows);
}
return tf.constant(-1);
}
}
}
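A small worked example of `from_row_lengths` (a hedged sketch; the arithmetic just follows the cumsum/concat steps above):

```csharp
using static Tensorflow.Binding;

// row_lengths = [2, 0, 3]
// row_limits  = cumsum(row_lengths)       = [2, 2, 5]
// row_splits  = concat([[0], row_limits]) = [0, 2, 2, 5]
var rp = RowPartition.from_row_lengths(tf.constant(new[] { 2, 0, 3 }));
print(rp.nrows()); // 3 rows
```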
@@ -249,6 +249,9 @@ namespace Tensorflow
case sbyte val:
tensor_proto.IntVal.AddRange(new[] { (int)val });
break;
case byte val:
tensor_proto.IntVal.AddRange(new[] { (int)val });
break;
case int val:
tensor_proto.IntVal.AddRange(new[] { val });
break;

@@ -262,7 +265,7 @@ namespace Tensorflow
tensor_proto.DoubleVal.AddRange(new[] { val });
break;
default:
throw new Exception("make_tensor_proto Not Implemented");
throw new Exception($"make_tensor_proto Not Implemented {values.GetType().Name}");
}
}
@@ -0,0 +1,66 @@
using System.Collections.Generic;
using System.Linq;
using Tensorflow.NumPy;

namespace Tensorflow.Util
{
/// <summary>
/// ValidationDataPack is used to pass validation data to the fit method.
/// It can receive a tuple `(x_val, y_val)` or `(x_val, y_val, sample_weight_val)` of NumPy arrays.
/// </summary>
public class ValidationDataPack
{
public NDArray val_x;
public NDArray val_y;
public NDArray val_sample_weight = null;

public ValidationDataPack((NDArray, NDArray) validation_data)
{
this.val_x = validation_data.Item1;
this.val_y = validation_data.Item2;
}

public ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
{
this.val_x = validation_data.Item1;
this.val_y = validation_data.Item2;
this.val_sample_weight = validation_data.Item3;
}

public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
{
this.val_x = validation_data.Item1.ToArray()[0];
this.val_y = validation_data.Item2;
}

public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
{
this.val_x = validation_data.Item1.ToArray()[0];
this.val_y = validation_data.Item2;
this.val_sample_weight = validation_data.Item3;
}

public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data)
=> new ValidationDataPack(validation_data);

public static implicit operator ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
=> new ValidationDataPack(validation_data);

public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
=> new ValidationDataPack(validation_data);

public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
=> new ValidationDataPack(validation_data);

public void Deconstruct(out NDArray val_x, out NDArray val_y)
{
val_x = this.val_x;
val_y = this.val_y;
}

public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight)
{
val_x = this.val_x;
val_y = this.val_y;
val_sample_weight = this.val_sample_weight;
}
}
}
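Because of the implicit conversions above, callers can pass plain tuples wherever a `ValidationDataPack` is expected (a sketch; `model` and the arrays are placeholders):

```csharp
ValidationDataPack pack = (valX, valY);              // (x, y) pair
ValidationDataPack packW = (valX, valY, valWeights); // with per-sample weights

// Or pass the tuple straight to fit:
model.fit(x, y, epochs: 3, validation_data: (valX, valY));
```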
@@ -590,7 +590,7 @@ namespace Tensorflow
public static HandleData get_resource_handle_data(Tensor graph_op)
{
var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
var handle_data = c_api.TF_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
try {
var handle_str = c_api.ByteStringPiece(handle_data.DangerousGetHandle() == IntPtr.Zero ? null : new Buffer(handle_data));
return HandleData.Parser.ParseFrom(handle_str);
@@ -20,6 +20,11 @@ namespace Tensorflow.Keras
Name = "relu",
ActivationFunction = (features, name) => tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features))
};

private static Activation _relu6 = new Activation()
{
Name = "relu6",
ActivationFunction = (features, name) => tf.Context.ExecuteOp("Relu6", name, new ExecuteOpArgs(features))
};

private static Activation _sigmoid = new Activation()
{
Name = "sigmoid",

@@ -55,6 +60,7 @@ namespace Tensorflow.Keras
_nameActivationMap = new Dictionary<string, Activation>();

RegisterActivation(_relu);
RegisterActivation(_relu6);
RegisterActivation(_linear);
RegisterActivation(_sigmoid);
RegisterActivation(_softmax);

@@ -65,6 +71,7 @@ namespace Tensorflow.Keras
public Activation Linear => _linear;
public Activation Relu => _relu;
public Activation Relu6 => _relu6;
public Activation Sigmoid => _sigmoid;
@@ -112,35 +112,39 @@ namespace Tensorflow.Keras.Datasets
if (start_char != null)
{
int[,] new_x_train_array = new int[x_train_array.GetLength(0), x_train_array.GetLength(1) + 1];
for (var i = 0; i < x_train_array.GetLength(0); i++)
var (d1, d2) = (x_train_array.GetLength(0), x_train_array.GetLength(1));
int[,] new_x_train_array = new int[d1, d2 + 1];
for (var i = 0; i < d1; i++)
{
new_x_train_array[i, 0] = (int)start_char;
Array.Copy(x_train_array, i * x_train_array.GetLength(1), new_x_train_array, i * new_x_train_array.GetLength(1) + 1, x_train_array.GetLength(1));
Array.Copy(x_train_array, i * d2, new_x_train_array, i * (d2 + 1) + 1, d2);
}
int[,] new_x_test_array = new int[x_test_array.GetLength(0), x_test_array.GetLength(1) + 1];
for (var i = 0; i < x_test_array.GetLength(0); i++)
(d1, d2) = (x_test_array.GetLength(0), x_test_array.GetLength(1));
int[,] new_x_test_array = new int[d1, d2 + 1];
for (var i = 0; i < d1; i++)
{
new_x_test_array[i, 0] = (int)start_char;
Array.Copy(x_test_array, i * x_test_array.GetLength(1), new_x_test_array, i * new_x_test_array.GetLength(1) + 1, x_test_array.GetLength(1));
Array.Copy(x_test_array, i * d2, new_x_test_array, i * (d2 + 1) + 1, d2);
}
x_train_array = new_x_train_array;
x_test_array = new_x_test_array;
}
else if (index_from != 0)
{
for (var i = 0; i < x_train_array.GetLength(0); i++)
var (d1, d2) = (x_train_array.GetLength(0), x_train_array.GetLength(1));
for (var i = 0; i < d1; i++)
{
for (var j = 0; j < x_train_array.GetLength(1); j++)
for (var j = 0; j < d2; j++)
{
if (x_train_array[i, j] == 0)
break;
x_train_array[i, j] += index_from;
}
}
for (var i = 0; i < x_test_array.GetLength(0); i++)
(d1, d2) = (x_test_array.GetLength(0), x_test_array.GetLength(1));
for (var i = 0; i < d1; i++)
{
for (var j = 0; j < x_test_array.GetLength(1); j++)
for (var j = 0; j < d2; j++)
{
if (x_test_array[i, j] == 0)
break;

@@ -169,9 +173,10 @@ namespace Tensorflow.Keras.Datasets
if (num_words == null)
{
var (d1, d2) = (xs_array.GetLength(0), xs_array.GetLength(1));
num_words = 0;
for (var i = 0; i < xs_array.GetLength(0); i++)
for (var j = 0; j < xs_array.GetLength(1); j++)
for (var i = 0; i < d1; i++)
for (var j = 0; j < d2; j++)
num_words = max((int)num_words, (int)xs_array[i, j]);
}
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Util;

namespace Tensorflow.Keras.Engine.DataAdapters
{

@@ -34,9 +35,67 @@ namespace Tensorflow.Keras.Engine.DataAdapters
return (x, y);
}

public virtual (Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight)
{
for (int i = 0; i < x.Length; i++)
{
if (x[i].shape.ndim == 1)
    x[i] = array_ops.expand_dims(x[i], axis: -1);
}
for (int i = 0; i < y.Length; i++)
{
if (y[i].shape.ndim == 1)
    y[i] = array_ops.expand_dims(y[i], axis: -1);
}
for (int i = 0; i < sample_weight.Length; i++)
{
if (sample_weight[i].shape.ndim == 1)
    sample_weight[i] = array_ops.expand_dims(sample_weight[i], axis: -1);
}
return (x, y, sample_weight);
}

public virtual bool ShouldRecreateIterator()
{
return true;
}

public static ((NDArray, NDArray, NDArray), ValidationDataPack) train_validation_split((NDArray, NDArray, NDArray) x_y_sample_weight, float validation_split)
{
var x = x_y_sample_weight.Item1;
var y = x_y_sample_weight.Item2;
var sample_weight = x_y_sample_weight.Item3;
int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
var train_x = x[new Slice(0, train_count)];
var train_y = y[new Slice(0, train_count)];
ValidationDataPack validation_data;
if (sample_weight != null)
{
validation_data = (x[new Slice(train_count)], y[new Slice(train_count)], sample_weight[new Slice(train_count)]);
sample_weight = sample_weight[new Slice(0, train_count)];
}
else
{
validation_data = (x[new Slice(train_count)], y[new Slice(train_count)]);
}
return ((train_x, train_y, sample_weight), validation_data);
}

public static ((IEnumerable<NDArray>, NDArray, NDArray), ValidationDataPack) train_validation_split((IEnumerable<NDArray>, NDArray, NDArray) x_y_sample_weight, float validation_split)
{
var x = x_y_sample_weight.Item1;
var y = x_y_sample_weight.Item2;
var sample_weight = x_y_sample_weight.Item3;
int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
var train_x = x.Select(arr => arr[new Slice(0, train_count)] as NDArray);
var train_y = y[new Slice(0, train_count)];
var val_x = x.Select(arr => arr[new Slice(train_count)] as NDArray);
var val_y = y[new Slice(train_count)];
NDArray tmp_sample_weight = sample_weight;
sample_weight = sample_weight[new Slice(0, train_count)];
ValidationDataPack validation_data = (val_x, val_y, tmp_sample_weight[new Slice(train_count)]);
return ((train_x, train_y, sample_weight), validation_data);
}
}
}
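A usage sketch of the split helper (hedged; names are illustrative, and `sampleWeight` may be null, in which case the two-tuple path is taken):

```csharp
// 80/20 train/validation split that also partitions the sample weights.
var ((trainX, trainY, trainW), valPack) =
    DataAdapter.train_validation_split((x, y, sampleWeight), validation_split: 0.2f);
var (valX, valY, valW) = valPack; // ValidationDataPack deconstructs back into arrays
```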
@@ -2,6 +2,9 @@
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;
using Tensorflow.Keras.Utils;
using Tensorflow.Util;
using Tensorflow.Framework;

namespace Tensorflow.Keras.Engine.DataAdapters
{

@@ -23,11 +26,13 @@ namespace Tensorflow.Keras.Engine.DataAdapters
long _steps_per_execution_value;
int _initial_epoch => args.InitialEpoch;
int _epochs => args.Epochs;
NDArray _sample_weight => args.SampleWeight;
IVariableV1 _steps_per_execution;

public DataHandler(DataHandlerArgs args)
{
this.args = args;
if (args.StepsPerExecution == null)
{
_steps_per_execution = tf.Variable(1L);

@@ -48,6 +53,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
BatchSize = args.BatchSize,
Steps = args.StepsPerEpoch,
Epochs = args.Epochs - args.InitialEpoch,
SampleWeight = args.SampleWeight,
Shuffle = args.Shuffle,
MaxQueueSize = args.MaxQueueSize,
Worker = args.Workers,

@@ -72,10 +78,75 @@ namespace Tensorflow.Keras.Engine.DataAdapters
}
_dataset = _adapter.GetDataset();
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
_current_step = 0;
_step_increment = _steps_per_execution_value - 1;
_insufficient_data = false;
_configure_dataset_and_inferred_steps(args.X, args.ClassWeight);
}

void _configure_dataset_and_inferred_steps(Tensors x, Dictionary<int, float> class_weight)
{
if (_dataset == null)
{
_dataset = _adapter.GetDataset();
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
}
if (class_weight != null)
{
_dataset = _dataset.map(_make_class_weight_map_fn(class_weight));
}
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
}

Func<Tensors, Tensors> _make_class_weight_map_fn(Dictionary<int, float> class_weight)
{
var class_ids = class_weight.Keys.OrderBy(key => key).ToList();
var expected_class_ids = range(class_ids[0], class_ids[class_ids.Count - 1] + 1);
if (!class_ids.SequenceEqual(expected_class_ids))
{
throw new ValueError("Expected `class_weight` to be a dict with keys from 0 to one less " +
    $"than the number of classes, found {class_weight}");
}

var class_weight_list = new List<float>();
foreach (var class_id in class_ids)
{
class_weight_list.Add(class_weight[class_id]);
}
var class_weight_tensor = tf.convert_to_tensor(class_weight_list.ToArray());

Func<Tensors, Tensors> _class_weight_map_fn = (Tensors data) =>
{
var x = data[0];
var y = data[1];
var sw = _sample_weight == null ? null : ops.convert_to_tensor(_sample_weight);

if (y.shape.rank > 2)
{
throw new ValueError("`class_weight` not supported for 3+ dimensional targets.");
}

var y_classes = smart_module.smart_cond(
    y.shape.rank == 2 && y.shape[1] > 1,
    () => math_ops.argmax(y, dimension: 1),
    () => math_ops.cast(tf.reshape(y, (-1)), TF_DataType.TF_INT64));

var cw = array_ops.gather(class_weight_tensor, y_classes);
if (sw != null)
{
cw = tf.cast(cw, sw.dtype);
cw *= sw;
}
else
{
sw = cw;
}
return new Tensors { x, y, sw };
};
return _class_weight_map_fn;
}

long _infer_steps(int steps_per_epoch, IDatasetV2 dataset)

@@ -17,6 +17,8 @@
IDatasetV2 GetDataset();
int GetSize();
(Tensors, Tensors) Expand1d(Tensors x, Tensors y);
(Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight);
bool ShouldRecreateIterator();
}
}
@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
public TensorLikeDataAdapter(DataAdapterArgs args)
{
this.args = args;
_process_tensorlike();
Tensor sample_weight_tensor = args.SampleWeight != null ? _process_tensorlike(args.SampleWeight) : null;
num_samples = (int)args.X.shape[0];
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
_batch_size = batch_size;

@@ -37,6 +37,8 @@ namespace Tensorflow.Keras.Engine.DataAdapters
inputs.AddRange(args.X);
if (args.Y != null)
inputs.AddRange(args.Y);
if (sample_weight_tensor != null)
inputs.Add(sample_weight_tensor);
dataset = slice_inputs(indices_dataset, inputs);
dataset.FirstInputTensorCount = args.X.Length;
}

@@ -94,8 +96,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters
public override bool ShouldRecreateIterator() => false;

void _process_tensorlike()
Tensor _process_tensorlike(NDArray sample_weights)
{
return tf.convert_to_tensor(sample_weights);
}
}
}
@@ -30,7 +30,7 @@ namespace Tensorflow.Keras.Engine
created_layers = created_layers ?? new Dictionary<string, ILayer>();
var node_index_map = new Dictionary<(string, int), int>();
var node_count_by_layer = new Dictionary<ILayer, int>();
var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>();
var unprocessed_nodes = new Dictionary<ILayer, List<NodeConfig>>();
// First, we create all layers and enqueue nodes to be processed
foreach (var layer_data in config.Layers)
process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer);

@@ -79,7 +79,7 @@ namespace Tensorflow.Keras.Engine
static void process_layer(Dictionary<string, ILayer> created_layers,
LayerConfig layer_data,
Dictionary<ILayer, NodeConfig> unprocessed_nodes,
Dictionary<ILayer, List<NodeConfig>> unprocessed_nodes,
Dictionary<ILayer, int> node_count_by_layer)
{
ILayer layer = null;

@@ -92,32 +92,38 @@ namespace Tensorflow.Keras.Engine
created_layers[layer_name] = layer;
}
node_count_by_layer[layer] = _should_skip_first_node(layer) ? 1 : 0;
node_count_by_layer[layer] = layer_data.InboundNodes.Count - (_should_skip_first_node(layer) ? 1 : 0);

var inbound_nodes_data = layer_data.InboundNodes;
foreach (var node_data in inbound_nodes_data)
{
if (!unprocessed_nodes.ContainsKey(layer))
    unprocessed_nodes[layer] = node_data;
    unprocessed_nodes[layer] = new List<NodeConfig>() { node_data };
else
    unprocessed_nodes.Add(layer, node_data);
    unprocessed_nodes[layer].Add(node_data);
}
}

static void process_node(ILayer layer,
NodeConfig node_data,
List<NodeConfig> nodes_data,
Dictionary<string, ILayer> created_layers,
Dictionary<ILayer, int> node_count_by_layer,
Dictionary<(string, int), int> node_index_map)
{
var input_tensors = new List<Tensor>();
var inbound_layer_name = node_data.Name;
var inbound_node_index = node_data.NodeIndex;
var inbound_tensor_index = node_data.TensorIndex;
var inbound_layer = created_layers[inbound_layer_name];
var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
input_tensors.Add(inbound_node.Outputs[inbound_node_index]);
for (int i = 0; i < nodes_data.Count; i++)
{
var node_data = nodes_data[i];
var inbound_layer_name = node_data.Name;
var inbound_node_index = node_data.NodeIndex;
var inbound_tensor_index = node_data.TensorIndex;
var inbound_layer = created_layers[inbound_layer_name];
var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
input_tensors.Add(inbound_node.Outputs[inbound_node_index]);
}

var output_tensors = layer.Apply(input_tensors);

@@ -27,6 +27,6 @@ public abstract partial class Layer
children = new Dictionary<string, Trackable>();
}

return children.Concat(base._trackable_children(save_type, cache)).ToDictionary(x => x.Key, x => x.Value);
return children.Concat(base._trackable_children(save_type, cache)).GroupBy(x => x.Key).Select(g => g.First()).ToDictionary(x => x.Key, x => x.Value);
}
}
@@ -26,11 +26,11 @@ namespace Tensorflow.Keras.Engine
/// </summary>
/// <param name="y_true"></param>
/// <param name="y_pred"></param>
public Tensor Call(Tensor y_true, Tensor y_pred)
public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
if (!_built)
Build(y_pred);
var loss_value = _losses.Call(y_true, y_pred);
var loss_value = _losses.Call(y_true, y_pred, sample_weight: sample_weight);
var loss_metric_value = loss_value;
var batch_dim = array_ops.shape(y_true)[0];
@@ -30,6 +30,7 @@ namespace Tensorflow.Keras.Engine
public Dictionary<string, float> evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
NDArray sample_weight = null,
int steps = -1,
int max_queue_size = 10,
int workers = 1,

@@ -51,6 +52,7 @@ namespace Tensorflow.Keras.Engine
StepsPerEpoch = steps,
InitialEpoch = 0,
Epochs = 1,
SampleWeight = sample_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,

@@ -130,6 +132,7 @@ namespace Tensorflow.Keras.Engine
var end_step = step + data_handler.StepIncrement;
if (!is_val)
callbacks.on_test_batch_end(end_step, logs);
GC.Collect();
}
}
callbacks.on_test_end(logs);

@@ -140,7 +143,8 @@ namespace Tensorflow.Keras.Engine
Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
{
var data = iterator.next();
var outputs = test_step(data_handler, data[0], data[1]);
var outputs = data.Length == 2 ? test_step(data_handler, data[0], data[1]) :
    test_step(data_handler, data[0], data[1], data[2]);
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}

@@ -149,7 +153,13 @@ namespace Tensorflow.Keras.Engine
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = test_step(data_handler, data.Take(x_size).ToArray(), data.Skip(x_size).ToArray());
var outputs = data.Length == 2 ?
    test_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) :
    test_step(
        data_handler,
        new Tensors(data.Take(x_size).ToArray()),
        new Tensors(data.Skip(x_size).Take(x_size).ToArray()),
        new Tensors(data.Skip(2 * x_size).ToArray()));
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}

@@ -157,11 +167,22 @@ namespace Tensorflow.Keras.Engine
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
(x,y) = data_handler.DataAdapter.Expand1d(x, y);
var y_pred = Apply(x, training: false);
var loss = compiled_loss.Call(y, y_pred);
compiled_metrics.update_state(y, y_pred);
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
}

Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight)
{
(x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
var y_pred = Apply(x, training: false);
var loss = compiled_loss.Call(y, y_pred, sample_weight: sample_weight);
compiled_metrics.update_state(y, y_pred);
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
}
}
}
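A hedged sketch of the extended `evaluate` call (`model`, `xTest`, `yTest`, and `testWeights` are placeholders):

```csharp
// Per-sample weights now flow into the reported loss during evaluation.
var results = model.evaluate(xTest, yTest,
    batch_size: 32,
    sample_weight: testWeights); // returns Dictionary<string, float>
foreach (var kv in results)
    print($"{kv.Key}: {kv.Value}");
```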
| @@ -6,10 +6,12 @@ using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine.DataAdapters; | |||
| using System.Diagnostics; | |||
| using Tensorflow.Keras.Callbacks; | |||
| using System.Data; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| public partial class Model | |||
| { | |||
| /// <summary> | |||
| @@ -19,19 +21,30 @@ namespace Tensorflow.Keras.Engine | |||
| /// <param name="y"></param> | |||
| /// <param name="batch_size"></param> | |||
| /// <param name="epochs"></param> | |||
| /// <param name="callbacks"></param> | |||
| /// <param name="verbose"></param> | |||
| /// <param name="callbacks"></param> | |||
| /// <param name="validation_split"></param> | |||
| /// <param name="validation_data"></param> | |||
| /// <param name="shuffle"></param> | |||
| /// <param name="class_weight"></param> | |||
| /// <param name="sample_weight"></param> | |||
| /// <param name="initial_epoch"></param> | |||
| /// <param name="max_queue_size"></param> | |||
| /// <param name="workers"></param> | |||
| /// <param name="use_multiprocessing"></param> | |||
| /// <returns></returns> | |||
| /// <exception cref="InvalidArgumentError"></exception> | |||
| public ICallback fit(NDArray x, NDArray y, | |||
| int batch_size = -1, | |||
| int epochs = 1, | |||
| int verbose = 1, | |||
| List<ICallback> callbacks = null, | |||
| float validation_split = 0f, | |||
| (NDArray val_x, NDArray val_y)? validation_data = null, | |||
| ValidationDataPack validation_data = null, | |||
| int validation_step = 10, | |||
| bool shuffle = true, | |||
| Dictionary<int, float> class_weight = null, | |||
| NDArray sample_weight = null, | |||
| int initial_epoch = 0, | |||
| int max_queue_size = 10, | |||
| int workers = 1, | |||
| @@ -43,25 +56,24 @@ namespace Tensorflow.Keras.Engine | |||
| $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}"); | |||
| } | |||
| var train_x = x; | |||
| var train_y = y; | |||
// The default dtype in NDArray is double, so we need to cast sample_weight to float to multiply it with the loss, whose dtype is float.
| sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT); | |||
| if (validation_split != 0f && validation_data == null) | |||
| { | |||
| int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split)); | |||
| train_x = x[new Slice(0, train_count)]; | |||
| train_y = y[new Slice(0, train_count)]; | |||
| validation_data = (val_x: x[new Slice(train_count)], val_y: y[new Slice(train_count)]); | |||
| ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split); | |||
| } | |||
| var data_handler = new DataHandler(new DataHandlerArgs | |||
| { | |||
| X = train_x, | |||
| Y = train_y, | |||
| X = x, | |||
| Y = y, | |||
| SampleWeight = sample_weight, | |||
| BatchSize = batch_size, | |||
| InitialEpoch = initial_epoch, | |||
| Epochs = epochs, | |||
| Shuffle = shuffle, | |||
| ClassWeight = class_weight, | |||
| MaxQueueSize = max_queue_size, | |||
| Workers = workers, | |||
| UseMultiprocessing = use_multiprocessing, | |||
| @@ -73,14 +85,17 @@ namespace Tensorflow.Keras.Engine | |||
| train_step_func: train_step_function); | |||
| } | |||
| public ICallback fit(IEnumerable<NDArray> x, NDArray y, | |||
| int batch_size = -1, | |||
| int epochs = 1, | |||
| int verbose = 1, | |||
| List<ICallback> callbacks = null, | |||
| float validation_split = 0f, | |||
| (IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null, | |||
| ValidationDataPack validation_data = null, | |||
| bool shuffle = true, | |||
| Dictionary<int, float> class_weight = null, | |||
| NDArray sample_weight = null, | |||
| int initial_epoch = 0, | |||
| int max_queue_size = 10, | |||
| int workers = 1, | |||
| @@ -95,27 +110,24 @@ namespace Tensorflow.Keras.Engine | |||
| } | |||
| } | |||
| var train_x = x; | |||
| var train_y = y; | |||
| sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT); | |||
| if (validation_split != 0f && validation_data == null) | |||
| { | |||
| int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split)); | |||
| train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray); | |||
| train_y = y[new Slice(0, train_count)]; | |||
| var val_x = x.Select(x => x[new Slice(train_count)] as NDArray); | |||
| var val_y = y[new Slice(train_count)]; | |||
| validation_data = (val_x, val_y); | |||
| ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split); | |||
| } | |||
| var data_handler = new DataHandler(new DataHandlerArgs | |||
| { | |||
| X = new Tensors(train_x.ToArray()), | |||
| Y = train_y, | |||
| X = new Tensors(x.ToArray()), | |||
| Y = y, | |||
| SampleWeight = sample_weight, | |||
| BatchSize = batch_size, | |||
| InitialEpoch = initial_epoch, | |||
| Epochs = epochs, | |||
| Shuffle = shuffle, | |||
| ClassWeight = class_weight, | |||
| MaxQueueSize = max_queue_size, | |||
| Workers = workers, | |||
| UseMultiprocessing = use_multiprocessing, | |||
| @@ -136,14 +148,15 @@ namespace Tensorflow.Keras.Engine | |||
| } | |||
| } | |||
| public History fit(IDatasetV2 dataset, | |||
| public ICallback fit(IDatasetV2 dataset, | |||
| int batch_size = -1, | |||
| int epochs = 1, | |||
| int verbose = 1, | |||
| List<ICallback> callbacks = null, | |||
| IDatasetV2 validation_data = null, | |||
int validation_step = 10, // how many epochs between validation runs
| int validation_step = 10, | |||
| bool shuffle = true, | |||
| Dictionary<int, float> class_weight = null, | |||
| int initial_epoch = 0, | |||
| int max_queue_size = 10, | |||
| int workers = 1, | |||
| @@ -157,6 +170,7 @@ namespace Tensorflow.Keras.Engine | |||
| InitialEpoch = initial_epoch, | |||
| Epochs = epochs, | |||
| Shuffle = shuffle, | |||
| ClassWeight = class_weight, | |||
| MaxQueueSize = max_queue_size, | |||
| Workers = workers, | |||
| UseMultiprocessing = use_multiprocessing, | |||
| @@ -204,13 +218,14 @@ namespace Tensorflow.Keras.Engine | |||
| var end_step = step + data_handler.StepIncrement; | |||
| End_step = end_step; | |||
| callbacks.on_train_batch_end(end_step, logs); | |||
| GC.Collect(); | |||
| } | |||
| if (validation_data != null) | |||
| { | |||
| if (validation_step > 0 && epoch ==0 || (epoch) % validation_step != 0) | |||
| continue; | |||
| var val_logs = evaluate(validation_data); | |||
| foreach(var log in val_logs) | |||
| { | |||
| @@ -219,11 +234,10 @@ namespace Tensorflow.Keras.Engine | |||
| callbacks.on_train_batch_end(End_step, logs); | |||
| } | |||
| GC.Collect(); | |||
| callbacks.on_epoch_end(epoch, logs); | |||
| GC.Collect(); | |||
| GC.WaitForPendingFinalizers(); | |||
| if (stop_training) | |||
| { | |||
| break; | |||
| @@ -233,7 +247,7 @@ namespace Tensorflow.Keras.Engine | |||
| return callbacks.History; | |||
| } | |||
| History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (NDArray, NDArray)? validation_data, | |||
| History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, ValidationDataPack validation_data, | |||
| Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) | |||
| { | |||
| stop_training = false; | |||
| @@ -268,13 +282,15 @@ namespace Tensorflow.Keras.Engine | |||
| var end_step = step + data_handler.StepIncrement; | |||
| End_step = end_step; | |||
| callbacks.on_train_batch_end(end_step, logs); | |||
| GC.Collect(); | |||
| } | |||
| if (validation_data != null) | |||
| { | |||
// Because evaluate calls on_test_batch_end, it interferes with our output on the screen,
// so we pass an is_val parameter to suppress on_test_batch_end during validation
| var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true); | |||
| var (val_x, val_y, val_sample_weight) = validation_data; | |||
| var val_logs = evaluate(val_x, val_y, sample_weight:val_sample_weight, is_val:true); | |||
| foreach (var log in val_logs) | |||
| { | |||
| logs["val_" + log.Key] = log.Value; | |||
| @@ -286,7 +302,6 @@ namespace Tensorflow.Keras.Engine | |||
| callbacks.on_epoch_end(epoch, logs); | |||
| GC.Collect(); | |||
| GC.WaitForPendingFinalizers(); | |||
| if (stop_training) | |||
| { | |||
| break; | |||
| @@ -296,64 +311,5 @@ namespace Tensorflow.Keras.Engine | |||
| return callbacks.History; | |||
| } | |||
| History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (IEnumerable<Tensor>, NDArray)? validation_data, | |||
| Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) | |||
| { | |||
| stop_training = false; | |||
| _train_counter.assign(0); | |||
| var callbacks = new CallbackList(new CallbackParams | |||
| { | |||
| Model = this, | |||
| Verbose = verbose, | |||
| Epochs = epochs, | |||
| Steps = data_handler.Inferredsteps | |||
| }); | |||
| if (callbackList != null) | |||
| { | |||
| foreach (var callback in callbackList) | |||
| callbacks.callbacks.add(callback); | |||
| } | |||
| callbacks.on_train_begin(); | |||
| foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||
| { | |||
| reset_metrics(); | |||
| callbacks.on_epoch_begin(epoch); | |||
| // data_handler.catch_stop_iteration(); | |||
| var logs = new Dictionary<string, float>(); | |||
| long End_step = 0; | |||
| foreach (var step in data_handler.steps()) | |||
| { | |||
| callbacks.on_train_batch_begin(step); | |||
| logs = train_step_func(data_handler, iterator); | |||
| var end_step = step + data_handler.StepIncrement; | |||
| End_step = end_step; | |||
| callbacks.on_train_batch_end(end_step, logs); | |||
| } | |||
| if (validation_data != null) | |||
| { | |||
| var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2); | |||
| foreach (var log in val_logs) | |||
| { | |||
| logs["val_" + log.Key] = log.Value; | |||
| callbacks.on_train_batch_end(End_step, logs); | |||
| } | |||
| } | |||
| callbacks.on_epoch_end(epoch, logs); | |||
| GC.Collect(); | |||
| GC.WaitForPendingFinalizers(); | |||
| if (stop_training) | |||
| { | |||
| break; | |||
| } | |||
| } | |||
| return callbacks.History; | |||
| } | |||
| } | |||
| } | |||
| @@ -102,9 +102,9 @@ namespace Tensorflow.Keras.Engine | |||
| for (int i = 0; i < batch_outputs.Length; i++) | |||
| batch_outputs[i] = tf.concat(new Tensor[] { batch_outputs[i], tmp_batch_outputs[i] }, axis: 0); | |||
| } | |||
| var end_step = step + data_handler.StepIncrement; | |||
| callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } }); | |||
| GC.Collect(); | |||
| } | |||
| } | |||
| @@ -12,7 +12,9 @@ namespace Tensorflow.Keras.Engine | |||
| Dictionary<string, float> train_step_function(DataHandler data_handler, OwnedIterator iterator) | |||
| { | |||
| var data = iterator.next(); | |||
| var outputs = train_step(data_handler, data[0], data[1]); | |||
// check whether the data includes sample_weight
| var outputs = data.Length == 2 ? train_step(data_handler, data[0], data[1]) : | |||
| train_step(data_handler, data[0], data[1], data[2]); | |||
| tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | |||
| return outputs; | |||
| } | |||
| @@ -21,7 +23,13 @@ namespace Tensorflow.Keras.Engine | |||
| { | |||
| var data = iterator.next(); | |||
| var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; | |||
| var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())); | |||
| var outputs = data.Length == 2 ? | |||
| train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) : | |||
| train_step( | |||
| data_handler, | |||
| new Tensors(data.Take(x_size).ToArray()), | |||
| new Tensors(data.Skip(x_size).Take(x_size).ToArray()), | |||
| new Tensors(data.Skip(2 * x_size).ToArray())); | |||
| tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | |||
| return outputs; | |||
| } | |||
| @@ -61,6 +69,34 @@ namespace Tensorflow.Keras.Engine | |||
| }); | |||
| return dict; | |||
| } | |||
| Dictionary<string, float> train_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null) | |||
| { | |||
| (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight); | |||
| using var tape = tf.GradientTape(); | |||
| var y_pred = Apply(x, training: true); | |||
| var loss = compiled_loss.Call(y, y_pred, sample_weight:sample_weight); | |||
| // For custom training steps, users can just write: | |||
| // trainable_variables = self.trainable_variables | |||
| // gradients = tape.gradient(loss, trainable_variables) | |||
| // self.optimizer.apply_gradients(zip(gradients, trainable_variables)) | |||
| // The _minimize call does a few extra steps unnecessary in most cases, | |||
| // such as loss scaling and gradient clipping. | |||
| _minimize(tape, optimizer, loss, TrainableVariables); | |||
| compiled_metrics.update_state(y, y_pred); | |||
| var dict = new Dictionary<string, float>(); | |||
| metrics.ToList().ForEach(x => | |||
| { | |||
| var r = x.result(); | |||
| if (r.ndim > 0) | |||
| { | |||
| r = tf.reduce_mean(r); | |||
| } | |||
| dict[x.Name] = (float)r; | |||
| }); | |||
| return dict; | |||
| } | |||
| void _minimize(GradientTape tape, IOptimizer optimizer, Tensor loss, List<IVariableV1> trainable_variables) | |||
| { | |||
| @@ -10,8 +10,38 @@ namespace Tensorflow.Keras.Engine | |||
| { | |||
| public partial class Model | |||
| { | |||
| static Dictionary<string, List<(string, NDArray)>> weightsCache | |||
| = new Dictionary<string, List<(string, NDArray)>>(); | |||
| public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null) | |||
| { | |||
| // Get from cache | |||
| if (weightsCache.ContainsKey(filepath)) | |||
| { | |||
| var filtered_layers = new List<ILayer>(); | |||
| foreach (var layer in Layers) | |||
| { | |||
| var weights = hdf5_format._legacy_weights(layer); | |||
| if (weights.Count > 0) | |||
| filtered_layers.append(layer); | |||
| } | |||
| var weight_value_tuples = new List<(IVariableV1, NDArray)>(); | |||
| filtered_layers.Select((layer, i) => | |||
| { | |||
| var symbolic_weights = hdf5_format._legacy_weights(layer); | |||
| foreach(var weight in symbolic_weights) | |||
| { | |||
| var weight_value = weightsCache[filepath].First(x => x.Item1 == weight.Name).Item2; | |||
| weight_value_tuples.Add((weight, weight_value)); | |||
| } | |||
| return layer; | |||
| }).ToList(); | |||
| keras.backend.batch_set_value(weight_value_tuples); | |||
| return; | |||
| } | |||
| long fileId = Hdf5.OpenFile(filepath, true); | |||
| if(fileId < 0) | |||
| { | |||
| @@ -29,8 +59,11 @@ namespace Tensorflow.Keras.Engine | |||
| throw new NotImplementedException(""); | |||
| else | |||
| { | |||
| hdf5_format.load_weights_from_hdf5_group(fileId, Layers); | |||
| var weight_value_tuples = hdf5_format.load_weights_from_hdf5_group(fileId, Layers); | |||
| Hdf5.CloseFile(fileId); | |||
| weightsCache[filepath] = weight_value_tuples.Select(x => (x.Item1.Name, x.Item2)).ToList(); | |||
| keras.backend.batch_set_value(weight_value_tuples); | |||
| } | |||
| } | |||
| @@ -0,0 +1,25 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Common.Types; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| /// <summary> | |||
/// Rectified Linear Unit activation capped at 6: min(max(x, 0), 6).
| /// </summary> | |||
| public class ReLu6 : Layer | |||
| { | |||
| public ReLu6() : base(new LayerArgs { }) | |||
| { | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| return tf.nn.relu6(inputs); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,167 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Operations; | |||
| using Newtonsoft.Json; | |||
| using System.Security.Cryptography; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| public class DepthwiseConv2DArgs: Conv2DArgs | |||
| { | |||
| /// <summary> | |||
/// depth_multiplier: The number of depthwise convolution output channels for
/// each input channel. The total number of depthwise convolution output
/// channels will be equal to `filters_in * depth_multiplier`.
| /// </summary> | |||
| [JsonProperty("depth_multiplier")] | |||
| public int DepthMultiplier { get; set; } = 1; | |||
| [JsonProperty("depthwise_initializer")] | |||
| public IInitializer DepthwiseInitializer { get; set; } | |||
| } | |||
| public class DepthwiseConv2D : Conv2D | |||
| { | |||
| /// <summary> | |||
/// depth_multiplier: The number of depthwise convolution output channels for
/// each input channel. The total number of depthwise convolution output
/// channels will be equal to `filters_in * depth_multiplier`.
| /// </summary> | |||
| int DepthMultiplier = 1; | |||
| IInitializer DepthwiseInitializer; | |||
| int[] strides; | |||
| int[] dilation_rate; | |||
| string getDataFormat() | |||
| { | |||
| return data_format == "channels_first" ? "NCHW" : "NHWC"; | |||
| } | |||
| static int _id = 1; | |||
| public DepthwiseConv2D(DepthwiseConv2DArgs args):base(args) | |||
| { | |||
| args.Padding = args.Padding.ToUpper(); | |||
| if(string.IsNullOrEmpty(args.Name)) | |||
| name = "DepthwiseConv2D_" + _id; | |||
| this.DepthMultiplier = args.DepthMultiplier; | |||
| this.DepthwiseInitializer = args.DepthwiseInitializer; | |||
| } | |||
| public override void build(KerasShapesWrapper input_shape) | |||
| { | |||
| //base.build(input_shape); | |||
| var shape = input_shape.ToSingleShape(); | |||
| int channel_axis = data_format == "channels_first" ? 1 : -1; | |||
| var input_channel = channel_axis < 0 ? | |||
| shape.dims[shape.ndim + channel_axis] : | |||
| shape.dims[channel_axis]; | |||
| var arg = args as DepthwiseConv2DArgs; | |||
| if (arg.Strides.ndim != shape.ndim) | |||
| { | |||
| if (arg.Strides.ndim == 2) | |||
| { | |||
| this.strides = new int[] { 1, (int)arg.Strides[0], (int)arg.Strides[1], 1 }; | |||
| } | |||
| else | |||
| { | |||
| this.strides = conv_utils.normalize_tuple(new int[] { (int)arg.Strides[0] }, shape.ndim, "strides"); | |||
| } | |||
| } | |||
| else | |||
| { | |||
this.strides = arg.Strides.dims.Select(o => (int)o).ToArray();
| } | |||
| if (arg.DilationRate.ndim != shape.ndim) | |||
| { | |||
| this.dilation_rate = conv_utils.normalize_tuple(new int[] { (int)arg.DilationRate[0] }, shape.ndim, "dilation_rate"); | |||
| } | |||
| long channel_data = data_format == "channels_first" ? shape[0] : shape[shape.Length - 1]; | |||
| var depthwise_kernel_shape = this.kernel_size.dims.concat(new long[] { | |||
| channel_data, | |||
| this.DepthMultiplier | |||
| }); | |||
| this.kernel = this.add_weight( | |||
| shape: depthwise_kernel_shape, | |||
| initializer: this.DepthwiseInitializer != null ? this.DepthwiseInitializer : this.kernel_initializer, | |||
| name: "depthwise_kernel", | |||
| trainable: true, | |||
| dtype: DType, | |||
| regularizer: this.kernel_regularizer | |||
| ); | |||
| var axes = new Dictionary<int, int>(); | |||
| axes.Add(-1, (int)input_channel); | |||
| inputSpec = new InputSpec(min_ndim: rank + 2, axes: axes); | |||
| if (use_bias) | |||
| { | |||
| bias = add_weight(name: "bias", | |||
| shape: ((int)channel_data), | |||
| initializer: bias_initializer, | |||
| trainable: true, | |||
| dtype: DType); | |||
| } | |||
| built = true; | |||
| _buildInputShape = input_shape; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, | |||
| bool? training = false, IOptionalArgs? optional_args = null) | |||
| { | |||
| Tensor outputs = null; | |||
| outputs = gen_nn_ops.depthwise_conv2d_native( | |||
| inputs, | |||
| filter: this.kernel.AsTensor(), | |||
| strides: this.strides, | |||
| padding: this.padding, | |||
| dilations: this.dilation_rate, | |||
| data_format: this.getDataFormat(), | |||
| name: name | |||
| ); | |||
| if (use_bias) | |||
| { | |||
| if (data_format == "channels_first") | |||
| { | |||
| throw new NotImplementedException("call channels_first"); | |||
| } | |||
| else | |||
| { | |||
| outputs = gen_nn_ops.bias_add(outputs, ops.convert_to_tensor(bias), | |||
| data_format: this.getDataFormat(), name: name); | |||
| } | |||
| } | |||
| if (activation != null) | |||
| outputs = activation.Apply(outputs); | |||
| return outputs; | |||
| } | |||
| } | |||
| } | |||
| @@ -112,7 +112,28 @@ namespace Tensorflow.Keras.Layers | |||
| KernelInitializer = GetInitializerByName(kernel_initializer), | |||
| BiasInitializer = GetInitializerByName(bias_initializer) | |||
| }); | |||
| public ILayer Conv2D(int filters, | |||
| Shape kernel_size = null, | |||
| Shape strides = null, | |||
| string padding = "valid") | |||
| => new Conv2D(new Conv2DArgs | |||
| { | |||
| Rank = 2, | |||
| Filters = filters, | |||
| KernelSize = (kernel_size == null) ? (5, 5) : kernel_size, | |||
| Strides = strides == null ? (1, 1) : strides, | |||
| Padding = padding, | |||
| DataFormat = null, | |||
| DilationRate = (1, 1), | |||
| Groups = 1, | |||
| UseBias = false, | |||
| KernelRegularizer = null, | |||
KernelInitializer = tf.glorot_uniform_initializer,
| BiasInitializer = tf.zeros_initializer, | |||
| BiasRegularizer = null, | |||
| ActivityRegularizer = null, | |||
| Activation = keras.activations.Linear, | |||
| }); | |||
| /// <summary> | |||
| /// 2D convolution layer (e.g. spatial convolution over images). | |||
| /// This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. | |||
| @@ -210,6 +231,38 @@ namespace Tensorflow.Keras.Layers | |||
| Activation = keras.activations.GetActivationFromName(activation) | |||
| }); | |||
| public ILayer DepthwiseConv2D(Shape kernel_size = null, | |||
| Shape strides = null, | |||
| string padding = "valid", | |||
| string data_format = null, | |||
| Shape dilation_rate = null, | |||
| int groups = 1, | |||
| int depth_multiplier = 1, | |||
| string activation = null, | |||
| bool use_bias = false, | |||
| string kernel_initializer = "glorot_uniform", | |||
| string bias_initializer = "zeros", | |||
| string depthwise_initializer = "glorot_uniform" | |||
| ) | |||
| => new DepthwiseConv2D(new DepthwiseConv2DArgs | |||
| { | |||
| Rank = 2, | |||
| Filters = 1, | |||
| KernelSize = (kernel_size == null) ? (5, 5) : kernel_size, | |||
| Strides = strides == null ? (1) : strides, | |||
| Padding = padding, | |||
| DepthMultiplier = depth_multiplier, | |||
| DataFormat = data_format, | |||
| DilationRate = dilation_rate == null ? (1) : dilation_rate, | |||
| Groups = groups, | |||
| UseBias = use_bias, | |||
| KernelInitializer = GetInitializerByName(kernel_initializer), | |||
| DepthwiseInitializer = GetInitializerByName(depthwise_initializer == null ? kernel_initializer : depthwise_initializer), | |||
| BiasInitializer = GetInitializerByName(bias_initializer), | |||
| Activation = keras.activations.GetActivationFromName(activation), | |||
| }); | |||
| /// <summary> | |||
| /// Transposed convolution layer (sometimes called Deconvolution). | |||
| /// </summary> | |||
| @@ -682,6 +735,15 @@ namespace Tensorflow.Keras.Layers | |||
| }); | |||
| /// <summary> | |||
/// Rectified Linear Unit activation capped at 6: min(max(x, 0), 6).
/// </summary>
/// <returns></returns>
| public ILayer ReLU6() | |||
| => new ReLu6(); | |||
| public IRnnCell SimpleRNNCell( | |||
| int units, | |||
| string activation = "tanh", | |||
| @@ -39,6 +39,7 @@ namespace Tensorflow.Keras.Layers | |||
| shape_set.Add(shape); | |||
| }*/ | |||
| _buildInputShape = input_shape; | |||
| built = true; | |||
| } | |||
| protected override Tensors _merge_function(Tensors inputs) | |||
| @@ -82,7 +82,7 @@ namespace Tensorflow.Keras.Saving | |||
| } | |||
| public static void load_weights_from_hdf5_group(long f, List<ILayer> layers) | |||
| public static List<(IVariableV1, NDArray)> load_weights_from_hdf5_group(long f, List<ILayer> layers) | |||
| { | |||
| string original_keras_version = "2.5.0"; | |||
| string original_backend = null; | |||
| @@ -152,7 +152,7 @@ namespace Tensorflow.Keras.Saving | |||
| weight_value_tuples.AddRange(zip(symbolic_weights, weight_values)); | |||
| } | |||
| keras.backend.batch_set_value(weight_value_tuples); | |||
| return weight_value_tuples; | |||
| } | |||
| public static void toarrayf4(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false) | |||
| @@ -7,7 +7,7 @@ | |||
| <Nullable>enable</Nullable> | |||
| <RootNamespace>Tensorflow.Keras</RootNamespace> | |||
| <Platforms>AnyCPU;x64</Platforms> | |||
| <Version>0.11.3</Version> | |||
| <Version>0.15.0</Version> | |||
| <Authors>Haiping Chen</Authors> | |||
| <Product>Keras for .NET</Product> | |||
| <Copyright>Apache 2.0, Haiping Chen since 2018</Copyright> | |||
| @@ -30,6 +30,7 @@ | |||
| * Fixed memory leak for YOLOv3 model. | |||
| * Support RNN and LSTM models | |||
| * Support Transformer model | |||
| * Support BERT model | |||
| </PackageReleaseNotes> | |||
| <Description>Keras for .NET | |||
| @@ -42,8 +43,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||
| <RepositoryType>Git</RepositoryType> | |||
| <SignAssembly>False</SignAssembly> | |||
| <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | |||
| <AssemblyVersion>0.11.3.0</AssemblyVersion> | |||
| <FileVersion>0.11.3.0</FileVersion> | |||
| <AssemblyVersion>0.15.0.0</AssemblyVersion> | |||
| <FileVersion>0.15.0.0</FileVersion> | |||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
| <Configurations>Debug;Release;GPU</Configurations> | |||
| </PropertyGroup> | |||
| @@ -143,7 +144,7 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="HDF5-CSharp" Version="1.18.0" /> | |||
| <PackageReference Include="HDF5-CSharp" Version="1.19.0" /> | |||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.149" /> | |||
| <PackageReference Include="SharpZipLib" Version="1.4.2" /> | |||
| </ItemGroup> | |||
| @@ -53,15 +53,17 @@ namespace Tensorflow.Keras.Utils | |||
| new_seq, new_label: shortened lists for `seq` and `label`. | |||
| */ | |||
| var nRow = seq.GetLength(0); | |||
| var nCol = seq.GetLength(1); | |||
| List<int[]> new_seq = new List<int[]>(); | |||
| List<long> new_label = new List<long>(); | |||
| for (var i = 0; i < seq.GetLength(0); i++) | |||
| for (var i = 0; i < nRow; i++) | |||
| { | |||
| if (maxlen < seq.GetLength(1) && seq[i, maxlen] != 0) | |||
| if (maxlen < nCol && seq[i, maxlen] != 0) | |||
| continue; | |||
| int[] sentence = new int[maxlen]; | |||
| for (var j = 0; j < maxlen && j < seq.GetLength(1); j++) | |||
| for (var j = 0; j < maxlen && j < nCol; j++) | |||
| { | |||
| sentence[j] = seq[i, j]; | |||
| } | |||
| @@ -112,12 +112,23 @@ namespace Tensorflow.Keras.Utils | |||
| foreach (var token in layersToken) | |||
| { | |||
| var args = deserialize_layer_args(token["class_name"].ToObject<string>(), token["config"]); | |||
| List<NodeConfig> nodeConfig = null; //python tensorflow sometimes exports inbound nodes in an extra nested array | |||
| if (token["inbound_nodes"].Count() > 0 && token["inbound_nodes"][0].Count() > 0 && token["inbound_nodes"][0][0].Count() > 0) | |||
| { | |||
| nodeConfig = token["inbound_nodes"].ToObject<List<List<NodeConfig>>>().FirstOrDefault() ?? new List<NodeConfig>(); | |||
| } | |||
| else | |||
| { | |||
| nodeConfig = token["inbound_nodes"].ToObject<List<NodeConfig>>(); | |||
| } | |||
| config.Layers.Add(new LayerConfig() | |||
| { | |||
| Config = args, | |||
| Name = token["name"].ToObject<string>(), | |||
| ClassName = token["class_name"].ToObject<string>(), | |||
| InboundNodes = token["inbound_nodes"].ToObject<List<NodeConfig>>() | |||
| InboundNodes = nodeConfig, | |||
| }); | |||
| } | |||
| config.InputLayers = json["input_layers"].ToObject<List<NodeConfig>>(); | |||
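For reference, the two `inbound_nodes` shapes the branch above distinguishes (illustrative JSON, layer names hypothetical):

```csharp
// Extra-nested, as some Python TensorFlow exports produce:
//   "inbound_nodes": [[["conv1", 0, 0, {}], ["conv2", 0, 0, {}]]]
// Flat, as previously assumed:
//   "inbound_nodes": [["conv1", 0, 0, {}], ["conv2", 0, 0, {}]]
```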
| @@ -26,7 +26,7 @@ | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="SharpCompress" Version="0.33.0" /> | |||
| <PackageReference Include="SharpCompress" Version="0.34.1" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| @@ -0,0 +1,24 @@ | |||
| <Project Sdk="Microsoft.NET.Sdk"> | |||
| <PropertyGroup> | |||
| <TargetFramework>net6.0</TargetFramework> | |||
| <ImplicitUsings>enable</ImplicitUsings> | |||
| <Nullable>enable</Nullable> | |||
| <IsPackable>false</IsPackable> | |||
| <IsTestProject>true</IsTestProject> | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.5.0" /> | |||
| <PackageReference Include="MSTest.TestAdapter" Version="2.2.10" /> | |||
| <PackageReference Include="MSTest.TestFramework" Version="2.2.10" /> | |||
| <PackageReference Include="coverlet.collector" Version="3.2.0" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <ProjectReference Include="..\..\src\TensorFlowNET.Keras\Tensorflow.Keras.csproj" /> | |||
| <ProjectReference Include="..\..\tools\Tensorflow.UnitTest.RedistHolder\Tensorflow.UnitTest.RedistHolder.csproj" /> | |||
| </ItemGroup> | |||
| </Project> | |||
| @@ -0,0 +1,63 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| using Tensorflow.NumPy; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlow.Kernel.UnitTest | |||
| { | |||
| [TestClass] | |||
| public class concat_op_test | |||
| { | |||
| [TestMethod] | |||
| public void testConcatEmpty() | |||
| { | |||
| var t1 = tf.constant(new int[] { }); | |||
| var t2 = tf.constant(new int[] { }); | |||
| var c = array_ops.concat(new[] { t1, t2 }, 0); | |||
| var expected = np.array(new int[] { }); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray<int>(), c.numpy().ToArray<int>())); | |||
| } | |||
| [TestMethod] | |||
| public void testConcatNegativeAxis() | |||
| { | |||
| var t1 = tf.constant(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }); | |||
| var t2 = tf.constant(new int[,] { { 7, 8, 9 }, { 10, 11, 12 } }); | |||
| var c = array_ops.concat(new[] { t1, t2 }, -2); | |||
| var expected = np.array(new int[,,] { { { 1, 2, 3 }, { 4, 5, 6 } }, { { 7, 8, 9 }, { 10, 11, 12 } } }); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray<int>(), c.numpy().ToArray<int>())); | |||
| c = array_ops.concat(new[] { t1, t2 }, -1); | |||
| expected = np.array(new int[,] { { 1, 2, 3, 7, 8, 9 }, { 4, 5, 6, 10, 11, 12 } }); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray<int>(), c.numpy().ToArray<int>())); | |||
| } | |||
| [TestMethod] | |||
| [DataRow(TF_DataType.TF_INT32)] | |||
| [DataRow(TF_DataType.TF_INT64)] | |||
| [DataRow(TF_DataType.TF_UINT32)] | |||
| [DataRow(TF_DataType.TF_UINT64)] | |||
| public void testConcatDtype(TF_DataType dtype) | |||
| { | |||
| var t1 = tf.constant(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }, dtype: dtype); | |||
| var t2 = tf.constant(new int[,] { { 7, 8, 9 }, { 10, 11, 12 } }, dtype: dtype); | |||
| var c = array_ops.concat(new[] { t1, t2 }, 1); | |||
| var expected = np.array(new int[,] { { 1, 2, 3, 7, 8, 9 }, { 4, 5, 6, 10, 11, 12 } }); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray<int>(), tf.cast(c, TF_DataType.TF_INT32).numpy().ToArray<int>())); | |||
| } | |||
| [TestMethod] | |||
| [DataRow(TF_DataType.TF_INT32)] | |||
| [DataRow(TF_DataType.TF_INT64)] | |||
| public void testConcatAxisType(TF_DataType dtype) | |||
| { | |||
| var t1 = tf.constant(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }); | |||
| var t2 = tf.constant(new int[,] { { 7, 8, 9 }, { 10, 11, 12 } }); | |||
| var c = array_ops.concat(new[] { t1, t2 }, tf.constant(1, dtype: dtype)); | |||
| var expected = np.array(new int[,] { { 1, 2, 3, 7, 8, 9 }, { 4, 5, 6, 10, 11, 12 } }); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray<int>(), tf.cast(c, TF_DataType.TF_INT32).numpy().ToArray<int>())); | |||
| } | |||
| } | |||
| } | |||
| @@ -3,6 +3,7 @@ using Tensorflow.NumPy; | |||
| using System; | |||
| using System.Linq; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow; | |||
| namespace TensorFlowNET.UnitTest.Basics | |||
| { | |||
| @@ -60,14 +61,14 @@ namespace TensorFlowNET.UnitTest.Basics | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>())); | |||
| } | |||
| [TestMethod, Ignore] | |||
| [TestMethod] | |||
| public void boolean_mask() | |||
| { | |||
| if (!tf.executing_eagerly()) | |||
| tf.enable_eager_execution(); | |||
| var tensor = new[] { 0, 1, 2, 3 }; | |||
| var mask = np.array(new[] { true, false, true, false }); | |||
| var masked = tf.boolean_mask(tensor, mask); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>())); | |||
| } | |||
| } | |||
| @@ -4,6 +4,7 @@ using System.Linq; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| using System; | |||
| using System.IO; | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| @@ -164,5 +165,94 @@ namespace TensorFlowNET.UnitTest | |||
| Assert.AreEqual(result.size, 16ul); | |||
| Assert.AreEqual(result[0, 0, 0, 0], 12f); | |||
| } | |||
| [TestMethod] | |||
| public void ImageSaveTest() | |||
| { | |||
| var imgPath = TestHelper.GetFullPathFromDataDir("img001.bmp"); | |||
| var jpegImgPath = TestHelper.GetFullPathFromDataDir("img001.jpeg"); | |||
| var pngImgPath = TestHelper.GetFullPathFromDataDir("img001.png"); | |||
| File.Delete(jpegImgPath); | |||
| File.Delete(pngImgPath); | |||
| var contents = tf.io.read_file(imgPath); | |||
| var bmp = tf.image.decode_image(contents); | |||
| Assert.AreEqual(bmp.name, "decode_image/DecodeImage:0"); | |||
| var jpeg = tf.image.encode_jpeg(bmp); | |||
| var op1 = tf.io.write_file(jpegImgPath, jpeg); | |||
| var png = tf.image.encode_png(bmp); | |||
| var op2 = tf.io.write_file(pngImgPath, png); | |||
| this.session().run(op1); | |||
| this.session().run(op2); | |||
| Assert.IsTrue(File.Exists(jpegImgPath), "not find file:" + jpegImgPath); | |||
| Assert.IsTrue(File.Exists(pngImgPath), "not find file:" + pngImgPath); | |||
// To verify the image contents manually, comment out the two File.Delete lines below.
| File.Delete(jpegImgPath); | |||
| File.Delete(pngImgPath); | |||
| } | |||
| [TestMethod] | |||
| public void ImageFlipTest() | |||
| { | |||
| var imgPath = TestHelper.GetFullPathFromDataDir("img001.bmp"); | |||
| var contents = tf.io.read_file(imgPath); | |||
| var bmp = tf.image.decode_image(contents); | |||
// flip left-right
| var lrImgPath = TestHelper.GetFullPathFromDataDir("img001_lr.png"); | |||
| File.Delete(lrImgPath); | |||
| var lr = tf.image.flip_left_right(bmp); | |||
| var png = tf.image.encode_png(lr); | |||
| var op = tf.io.write_file(lrImgPath, png); | |||
| this.session().run(op); | |||
| Assert.IsTrue(File.Exists(lrImgPath), "not find file:" + lrImgPath); | |||
// flip up-down
| var updownImgPath = TestHelper.GetFullPathFromDataDir("img001_updown.png"); | |||
| File.Delete(updownImgPath); | |||
| var updown = tf.image.flip_up_down(bmp); | |||
| var pngupdown = tf.image.encode_png(updown); | |||
| var op2 = tf.io.write_file(updownImgPath, pngupdown); | |||
| this.session().run(op2); | |||
| Assert.IsTrue(File.Exists(updownImgPath)); | |||
// For now the flips are verified by visual inspection; delete the two File.Delete lines below to inspect the images.
| File.Delete(lrImgPath); | |||
| File.Delete(updownImgPath); | |||
// flip a batch of images
// the shape is currently taken directly from bmp; a default image size is used to build the batch here
| var mImg = tf.stack(new[] { bmp, lr }, axis:0); | |||
| print(mImg.shape); | |||
| var up2 = tf.image.flip_up_down(mImg); | |||
var updownImgPath_m1 = TestHelper.GetFullPathFromDataDir("img001_m_ud.png"); // flipped up-down directly
| File.Delete(updownImgPath_m1); | |||
var img001_updown_m2 = TestHelper.GetFullPathFromDataDir("img001_m_lr_ud.png"); // flipped left-right, then up-down
| File.Delete(img001_updown_m2); | |||
| var png2 = tf.image.encode_png(up2[0]); | |||
| tf.io.write_file(updownImgPath_m1, png2); | |||
| png2 = tf.image.encode_png(up2[1]); | |||
| tf.io.write_file(img001_updown_m2, png2); | |||
// To verify the image contents manually, comment out the two File.Delete lines below.
| File.Delete(updownImgPath_m1); | |||
| File.Delete(img001_updown_m2); | |||
| } | |||
| } | |||
| } | |||
| @@ -33,6 +33,40 @@ namespace Tensorflow.Keras.UnitTest | |||
| return ret; | |||
| } | |||
public void AssertArray(int[] f1, int[] f2)
{
    // guard against mismatched lengths and empty arrays before comparing
    bool ret = f1.Length == f2.Length;
    for (var i = 0; ret && i < f1.Length; i++)
    {
        ret = f1[i] == f2[i];
    }
    if (!ret)
    {
        Assert.Fail($"Array not Equal:[{string.Join(",", f1)}] [{string.Join(",", f2)}]");
    }
}
public void AssertArray(float[] f1, float[] f2)
{
    var tolerance = .00001f;
    // guard against mismatched lengths and empty arrays before comparing
    bool ret = f1.Length == f2.Length;
    for (var i = 0; ret && i < f1.Length; i++)
    {
        ret = Math.Abs(f1[i] - f2[i]) <= tolerance;
    }
    if (!ret)
    {
        Assert.Fail($"Array float not Equal:[{string.Join(",", f1)}] [{string.Join(",", f2)}]");
    }
}
| public bool Equal(double[] d1, double[] d2) | |||
| { | |||
| bool ret = false; | |||
| @@ -1,6 +1,8 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System.Linq; | |||
| using Tensorflow.NumPy; | |||
| using static Tensorflow.KerasApi; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.UnitTest.Layers | |||
| { | |||
| @@ -193,5 +195,128 @@ namespace Tensorflow.Keras.UnitTest.Layers | |||
| Assert.AreEqual(x.dims[2], y.shape[2]); | |||
| Assert.AreEqual(filters, y.shape[3]); | |||
| } | |||
| [TestMethod] | |||
| public void BasicDepthwiseConv2D() | |||
| { | |||
| var conv = keras.layers.DepthwiseConv2D(kernel_size:3, strides:1, activation: null, | |||
| padding:"same", depthwise_initializer: "ones"); | |||
var x = np.arange(2 * 9 * 9 * 3).reshape((2, 9, 9, 3));
| var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT); | |||
| var y = conv.Apply(x2); | |||
| print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}"); | |||
| Assert.AreEqual(4, y.shape.ndim); | |||
| var arr = y.numpy().reshape((2, 9, 9, 3)); | |||
| AssertArray(x[new int[] { 1, 1, 1 }].ToArray<int>(), new int[] { 273, 274, 275 }); | |||
| AssertArray(arr[new int[] { 1, 1, 1 }].ToArray<float>(), new float[] { 2457f, 2466f, 2475f }); | |||
| var bn = keras.layers.BatchNormalization(); | |||
| var y2 = bn.Apply(y); | |||
| arr = y2.numpy().ToArray<float>(); | |||
double delta = 0.0001; // tolerance
| Assert.AreEqual(arr[0], 59.97002f, delta); | |||
| Assert.AreEqual(arr[1], 63.96802f, delta); | |||
| } | |||
| [TestMethod] | |||
| public void BasicDepthwiseConv2D_strides_2() | |||
| { | |||
| var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: (1, 2, 2, 1), activation: null, | |||
| padding: "same", depthwise_initializer: "ones"); | |||
| var x = np.arange(2 * 9 * 9 * 3).reshape((2, 9, 9, 3)); | |||
| var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT); | |||
| var y = conv.Apply(x2); | |||
| print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}"); | |||
| Assert.AreEqual(4, y.shape.ndim); | |||
| var arr = y.numpy().reshape((2, 5, 5, 3)); | |||
| AssertArray(x[new int[] { 1, 1, 1 }].ToArray<int>(), new int[] { 273, 274, 275 }); | |||
| AssertArray(arr[new int[] { 1, 1, 1 }].ToArray<float>(), new float[] { 2727f, 2736f, 2745f }); | |||
| var bn = keras.layers.BatchNormalization(); | |||
| var y2 = bn.Apply(y); | |||
| arr = y2.numpy().ToArray<float>(); | |||
double delta = 0.0001; // tolerance
| Assert.AreEqual(arr[0], 59.97002f, delta); | |||
| Assert.AreEqual(arr[1], 63.96802f, delta); | |||
| } | |||
| [TestMethod] | |||
| public void BasicDepthwiseConv2D_strides_3() | |||
| { | |||
| var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: 3, activation: null, | |||
| padding: "same", depthwise_initializer: "ones"); | |||
| var x = np.arange(2 * 9 * 9 * 3).reshape((2, 9, 9, 3)); | |||
| var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT); | |||
| var y = conv.Apply(x2); | |||
| print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}"); | |||
| Assert.AreEqual(4, y.shape.ndim); | |||
| var arr = y.numpy().reshape((2, 3, 3, 3)); | |||
| AssertArray(x[new int[] { 1, 1, 1 }].ToArray<int>(), new int[] { 273, 274, 275 }); | |||
| AssertArray(arr[new int[] { 1, 1, 1 }].ToArray<float>(), new float[] { 3267f, 3276f, 3285f }); | |||
| var bn = keras.layers.BatchNormalization(); | |||
| var y2 = bn.Apply(y); | |||
| arr = y2.numpy().ToArray<float>(); | |||
double delta = 0.0001; // tolerance
| Assert.AreEqual(arr[0], 269.86508f, delta); | |||
| Assert.AreEqual(arr[1], 278.8606f, delta); | |||
| } | |||
| [TestMethod] | |||
| public void BasicDepthwiseConv2D_UseBias() | |||
| { | |||
| var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: 1, activation: null, | |||
| use_bias: true, padding: "same", | |||
| depthwise_initializer: "ones", | |||
| bias_initializer:"ones" | |||
| ); | |||
| var weight = conv.get_weights(); | |||
| var x = np.arange(9 * 9 * 3).reshape((1, 9, 9, 3)); | |||
| var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT); | |||
| var y = conv.Apply(x2); | |||
| Assert.AreEqual(4, y.shape.ndim); | |||
| var arr = y.numpy().ToArray<float>(); | |||
| Assert.AreEqual(arr[0], 61f); | |||
| Assert.AreEqual(arr[1], 65f); | |||
| var bn = keras.layers.BatchNormalization(); | |||
| var y2 = bn.Apply(y); | |||
| arr = y2.numpy().ToArray<float>(); | |||
double delta = 0.0001; // tolerance
| Assert.AreEqual(arr[0], 60.96952f, delta); | |||
| Assert.AreEqual(arr[1], 64.96752f, delta); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.NumPy; | |||
| using static Tensorflow.KerasApi; | |||
| @@ -8,12 +9,16 @@ namespace Tensorflow.Keras.UnitTest.Layers | |||
| public class LayersMergingTest : EagerModeTestBase | |||
| { | |||
| [TestMethod] | |||
| public void Concatenate() | |||
| [DataRow(1, 4, 1, 5)] | |||
| [DataRow(2, 2, 2, 5)] | |||
| [DataRow(3, 2, 1, 10)] | |||
| public void Concatenate(int axis, int shapeA, int shapeB, int shapeC) | |||
| { | |||
| var x = np.arange(20).reshape((2, 2, 5)); | |||
| var y = np.arange(20, 30).reshape((2, 1, 5)); | |||
| var z = keras.layers.Concatenate(axis: 1).Apply(new Tensors(x, y)); | |||
| Assert.AreEqual((2, 3, 5), z.shape); | |||
| var x = np.arange(10).reshape((1, 2, 1, 5)); | |||
| var y = np.arange(10, 20).reshape((1, 2, 1, 5)); | |||
| var z = keras.layers.Concatenate(axis: axis).Apply(new Tensors(x, y)); | |||
| Assert.AreEqual((1, shapeA, shapeB, shapeC), z.shape); | |||
| } | |||
| } | |||
| } | |||
| @@ -74,8 +74,8 @@ namespace Tensorflow.Keras.UnitTest.Layers | |||
| OneHot = true, | |||
| ValidationSize = 55000, | |||
| }).Result; | |||
| model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1); | |||
| var sample_weight = np.ones(((int)dataset.Train.Data.shape[0])); | |||
| model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1, sample_weight:sample_weight); | |||
| } | |||
| [TestMethod] | |||
| @@ -1,10 +1,13 @@ | |||
| using Microsoft.VisualStudio.TestPlatform.Utilities; | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Newtonsoft.Json.Linq; | |||
| using System.Linq; | |||
| using System.Xml.Linq; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Optimizers; | |||
| using Tensorflow.Keras.UnitTest.Helpers; | |||
| using Tensorflow.NumPy; | |||
| using static HDF.PInvoke.H5Z; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| @@ -124,4 +127,44 @@ public class ModelLoadTest | |||
| var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model; | |||
| model.summary(); | |||
| } | |||
| [TestMethod] | |||
| public void CreateConcatenateModelSaveAndLoad() | |||
| { | |||
// A small demo model, just here to check that the axis value of the Concatenate layer survives a save/load round trip.
| var input_layer = tf.keras.layers.Input((8, 8, 5)); | |||
| var conv1 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_1"*/).Apply(input_layer); | |||
| conv1.Name = "conv1"; | |||
| var conv2 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_2"*/).Apply(input_layer); | |||
| conv2.Name = "conv2"; | |||
| var concat1 = tf.keras.layers.Concatenate(axis: 3).Apply((conv1, conv2)); | |||
| concat1.Name = "concat1"; | |||
| var model = tf.keras.Model(input_layer, concat1); | |||
| model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy()); | |||
| model.save(@"Assets/concat_axis3_model"); | |||
| var tensorInput = np.arange(320).reshape((1, 8, 8, 5)).astype(TF_DataType.TF_FLOAT); | |||
| var tensors1 = model.predict(tensorInput); | |||
| Assert.AreEqual((1, 8, 8, 4), tensors1.shape); | |||
| model = null; | |||
| keras.backend.clear_session(); | |||
| var model2 = tf.keras.models.load_model(@"Assets/concat_axis3_model"); | |||
| var tensors2 = model2.predict(tensorInput); | |||
| Assert.AreEqual(tensors1.shape, tensors2.shape); | |||
| } | |||
| } | |||
| @@ -20,6 +20,20 @@ namespace TensorFlowNET.UnitTest | |||
| return Math.Abs(f1 - f2) <= tolerance; | |||
| } | |||
| public bool Equal(long[] l1, long[] l2) | |||
| { | |||
| if (l1.Length != l2.Length) | |||
| return false; | |||
| for (var i = 0; i < l1.Length; i++) | |||
| { | |||
| if (l1[i] != l2[i]) | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| public bool Equal(float[] f1, float[] f2) | |||
| { | |||
| bool ret = false; | |||
| @@ -62,7 +62,7 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
// Calculate the gradient of (x1-x2)^2
| // by Automatic Differentiation in Eager mode | |||
| // Expected is 2*(abs(x1-x2)) | |||
| Tensor x1 = new NDArray( new float[] { 1, 3, 5, 21, 19, 17 }); | |||
| Tensor x1 = new NDArray(new float[] { 1, 3, 5, 21, 19, 17 }); | |||
| Tensor x2 = new NDArray(new float[] { 29, 27, 23, 7, 11, 13 }); | |||
| float[] expected = new float[] | |||
| { | |||
| @@ -173,5 +173,34 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
| var result = grad(x, 4); | |||
| Assert.AreEqual((float)result, 4.0f); | |||
| } | |||
| [TestMethod] | |||
| public void Tile() | |||
| { | |||
| var a = tf.constant(new int[] { 1 }, TF_DataType.TF_FLOAT); | |||
| var b = tf.constant(new int[] { 2 }); | |||
| using (var tape = tf.GradientTape()) | |||
| { | |||
| tape.watch(a); | |||
| var y = tf.tile(a, b); | |||
| var grad = tape.gradient(y, a); | |||
| Assert.AreEqual((float)grad.numpy(), 2.0f); | |||
| } | |||
| } | |||
| [TestMethod] | |||
| public void GatherNdTest() | |||
| { | |||
| var x = tf.constant(new float[,] { { 1.0f, 2.0f, 3.0f }, { 1.0f, 2.0f, 3.0f }, { 1.0f, 2.0f, 3.0f } }, dtype: TF_DataType.TF_FLOAT); | |||
| var indices = tf.constant(new int[,] { { 0, 1 }, { 1, 1 }, { 2, 1 } }, dtype: TF_DataType.TF_INT32); | |||
| using (var tape = tf.GradientTape()) | |||
| { | |||
| tape.watch(x); | |||
| var res = tf.gather_nd(x, indices); | |||
| var grad = tape.gradient(res, x); | |||
| var expected = np.array(new float[,] { { 0f, 1f, 0f }, { 0f, 1f, 0f }, { 0f, 1f, 0f } }); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(grad.ToArray<float>(), expected.ToArray<float>())); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -3,6 +3,7 @@ using Tensorflow.NumPy; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| using System.Linq; | |||
| using Tensorflow.Operations; | |||
| namespace TensorFlowNET.UnitTest.ManagedAPI | |||
| { | |||
| @@ -105,5 +106,321 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
| Assert.IsTrue(Equal(a[0].ToArray<float>().Reverse().ToArray(), b[0].ToArray<float>())); | |||
| Assert.IsTrue(Equal(a[1].ToArray<float>().Reverse().ToArray(), b[1].ToArray<float>())); | |||
| } | |||
| [TestMethod] | |||
| public void ReverseImgArray3D() | |||
| { | |||
// build the sourceImg array
| var sourceImgArray = new float[,,] { | |||
| { | |||
| { 237, 28, 36 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| } | |||
| }; | |||
| var sourceImg = ops.convert_to_tensor(sourceImgArray); | |||
// build the lrImg (left-right flipped) array
| var lrImgArray = new float[,,] { | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 237, 28, 36 } | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| } | |||
| }; | |||
| var lrImg = ops.convert_to_tensor(lrImgArray); | |||
| var lr = tf.image.flip_left_right(sourceImg); | |||
| Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr.numpy().ToArray<float>()), "tf.image.flip_left_right fail."); | |||
| var lr2 = tf.reverse(sourceImg, 1); | |||
| Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail."); | |||
| var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 })); | |||
| Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail."); | |||
// build the udImg (up-down flipped) array
| var udImgArray = new float[,,] { | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| }, | |||
| { | |||
| { 237, 28, 36 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| } | |||
| }; | |||
| var udImg = ops.convert_to_tensor(udImgArray); | |||
| var ud = tf.image.flip_up_down(sourceImg); | |||
| Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud.numpy().ToArray<float>()), "tf.image.flip_up_down fail."); | |||
| var ud2 = tf.reverse(sourceImg, new Axis(0)); | |||
| Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud2.numpy().ToArray<float>()), "tf.reverse (axis=0) fail."); | |||
| var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 })); | |||
| Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=0 fail."); | |||
| } | |||
| [TestMethod] | |||
| public void ReverseImgArray4D() | |||
| { | |||
// the original image (marker in the top-left corner) plus a left-right flipped copy
| var m = new float[,,,] { | |||
| { | |||
| { | |||
| { 237, 28, 36 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| } | |||
| }, | |||
| { | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 237, 28, 36 } | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| } | |||
| } | |||
| }; | |||
| var sourceImg = ops.convert_to_tensor(m); | |||
| var lrArray = new float[,,,] { | |||
| { | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 237, 28, 36 }, | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| } | |||
| }, | |||
| { | |||
| { | |||
| { 237, 28, 36 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| } | |||
| } | |||
| }; | |||
| var lrImg = ops.convert_to_tensor(lrArray); | |||
// build the ud (up-down flipped) array
| var udArray = new float[,,,] { | |||
| { | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| }, | |||
| { | |||
| { 237, 28, 36 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| } | |||
| }, | |||
| { | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 237, 28, 36 } | |||
| } | |||
| } | |||
| }; | |||
| var udImg = ops.convert_to_tensor(udArray); | |||
| var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 })); | |||
| Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail."); | |||
| var ud2 = tf.reverse(sourceImg, new Axis(1)); | |||
| Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail."); | |||
| var ud = tf.image.flip_up_down(sourceImg); | |||
| Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud.numpy().ToArray<float>()), "tf.image.flip_up_down fail."); | |||
// flip left-right
| var lr = tf.image.flip_left_right(sourceImg); | |||
| Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr.numpy().ToArray<float>()), "tf.image.flip_left_right fail."); | |||
| var lr2 = tf.reverse(sourceImg, 0); | |||
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr2.numpy().ToArray<float>()), "tf.reverse (axis=0) fail.");
| var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 })); | |||
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=0 fail.");
| } | |||
| [TestMethod] | |||
| public void ReverseImgArray4D_3x3() | |||
| { | |||
// the original image (marker in the top-left corner) plus a left-right flipped copy
| var m = new float[,,,] { | |||
| { | |||
| { | |||
| { 237, 28, 36 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| } | |||
| }, | |||
| { | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 237, 28, 36 } | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| } | |||
| } | |||
| }; | |||
| var sourceImg = ops.convert_to_tensor(m); | |||
| var lrArray = new float[,,,] { | |||
| { | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 237, 28, 36 }, | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| } | |||
| }, | |||
| { | |||
| { | |||
| { 237, 28, 36 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| } | |||
| } | |||
| }; | |||
| var lrImg = ops.convert_to_tensor(lrArray); | |||
// build the ud (up-down flipped) array
| var udArray = new float[,,,] { | |||
| { | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| }, | |||
| { | |||
| { 237, 28, 36 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| } | |||
| }, | |||
| { { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 } | |||
| }, | |||
| { | |||
| { 255, 255, 255 }, | |||
| { 255, 255, 255 }, | |||
| { 237, 28, 36 } | |||
| } | |||
| } | |||
| }; | |||
| var udImg = ops.convert_to_tensor(udArray); | |||
| var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 })); | |||
| Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail."); | |||
| var ud2 = tf.reverse(sourceImg, new Axis(1)); | |||
| Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail."); | |||
| var ud = tf.image.flip_up_down(sourceImg); | |||
| Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud.numpy().ToArray<float>()), "tf.image.flip_up_down fail."); | |||
// flip left-right
| var lr = tf.image.flip_left_right(sourceImg); | |||
| Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr.numpy().ToArray<float>()), "tf.image.flip_left_right fail."); | |||
| var lr2 = tf.reverse(sourceImg, new Axis(2)); | |||
| Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr2.numpy().ToArray<float>()), "tf.reverse (axis=2) fail."); | |||
| var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 2 })); | |||
| Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=2 fail."); | |||
| } | |||
| } | |||
| } | |||
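| For reference, the flip helpers exercised above map onto plain axis reversals of an NHWC batch: flip_up_down reverses the height axis (1) and flip_left_right reverses the width axis (2). A minimal sketch of that equivalence, assuming the same eager-mode TF.NET API the tests use: | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| // Shape (N, H, W, C): axis 0 = batch, 1 = height, 2 = width, 3 = channels. | |||
| var batch = tf.zeros(new Shape(2, 3, 3, 3)); | |||
| var upDown = tf.reverse(batch, new Axis(1)); // equivalent to tf.image.flip_up_down(batch) | |||
| var leftRight = tf.reverse(batch, new Axis(2)); // equivalent to tf.image.flip_left_right(batch) | |||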
| @@ -0,0 +1,26 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| using Tensorflow.NumPy; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.ManagedAPI | |||
| { | |||
| [TestClass] | |||
| public class RaggedTensorTest : EagerModeTestBase | |||
| { | |||
| [TestMethod] | |||
| public void Test_from_row_lengths() | |||
| { | |||
| var row_lengths = tf.convert_to_tensor(np.array(new int[] { 2, 0, 3, 1, 1 }, TF_DataType.TF_INT64)); | |||
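| // row_lengths [2, 0, 3, 1, 1] describes 5 rows holding 7 values in total. | |||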
| var rp = RowPartition.from_row_lengths(row_lengths, validate: false); | |||
| var rp_row_lengths = rp.row_lengths(); | |||
| var rp_nrows = rp.nrows(); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(row_lengths.ToArray<long>(), rp_row_lengths.ToArray<long>()), "row_lengths mismatch."); | |||
| Assert.IsTrue(rp_nrows.ToArray<long>()[0] == 5, "nrows mismatch."); | |||
| } | |||
| } | |||
| } | |||
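| The row_splits equivalent to a set of row lengths is their running (cumulative) sum, which is also where nrows comes from. A hedged sketch of that relationship in plain C#, independent of the library: | |||
| // row_lengths [2, 0, 3, 1, 1] => row_splits [0, 2, 2, 5, 6, 7]; nrows = 5. | |||
| long[] rowLengths = { 2, 0, 3, 1, 1 }; | |||
| var rowSplits = new long[rowLengths.Length + 1]; | |||
| for (int i = 0; i < rowLengths.Length; i++) | |||
| rowSplits[i + 1] = rowSplits[i] + rowLengths[i]; | |||
| // rowSplits.Length - 1 == nrows (5); rowSplits[^1] == total number of values (7). | |||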
| @@ -0,0 +1,44 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow.NumPy; | |||
| using System; | |||
| using System.Linq; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow; | |||
| namespace TensorFlowNET.UnitTest.NumPy | |||
| { | |||
| [TestClass] | |||
| public class ShapeTest : EagerModeTestBase | |||
| { | |||
| [Ignore] | |||
| [TestMethod] | |||
| public unsafe void ShapeGetLastElements() | |||
| { | |||
| // test code from function _CheckAtLeast3DImage | |||
| // The previous _CheckAtLeast3DImage implementation had a bug; it now passes the test, and the code below is correct. | |||
| // TODO: the shape["-3:"] slice syntax currently has a bug; re-enable this test once it is fixed. Ignored for now. | |||
| var image_shape = new Shape(new[] { 32, 64, 3 }); | |||
| var image_shape_4d = new Shape(new[] { 4, 64, 32, 3 }); | |||
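| // Reference result: take the last three dims by explicit indexing, then compare against the "-3:" indexer. | |||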
| var image_shape_last_three_elements = new Shape(new[] { | |||
| image_shape.dims[image_shape.dims.Length - 3], | |||
| image_shape.dims[image_shape.dims.Length - 2], | |||
| image_shape.dims[image_shape.dims.Length - 1]}); | |||
| var image_shape_last_three_elements2 = image_shape["-3:"]; | |||
| Assert.IsTrue(Equal(image_shape_last_three_elements.dims, image_shape_last_three_elements2.dims), "3dims get fail."); | |||
| var image_shape_last_three_elements_4d = new Shape(new[] { | |||
| image_shape_4d.dims[image_shape_4d.dims.Length - 3], | |||
| image_shape_4d.dims[image_shape_4d.dims.Length - 2], | |||
| image_shape_4d.dims[image_shape_4d.dims.Length - 1]}); | |||
| var image_shape_last_three_elements2_4d = image_shape_4d["-3:"]; | |||
| Assert.IsTrue(Equal(image_shape_last_three_elements_4d.dims, image_shape_last_three_elements2_4d.dims), "4dims get fail."); | |||
| } | |||
| } | |||
| } | |||
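| The string indexer is intended to mirror NumPy-style negative slicing: shape["-3:"] should keep the last three dimensions. A minimal sketch of the intended semantics (hand-rolled with LINQ, not the library implementation): | |||
| using System.Linq; | |||
| // shape["-3:"] on (4, 64, 32, 3) should yield (64, 32, 3). | |||
| long[] dims = { 4, 64, 32, 3 }; | |||
| long[] lastThree = dims.Skip(dims.Length - 3).ToArray(); // { 64, 32, 3 } | |||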
| @@ -19,13 +19,10 @@ | |||
| <PlatformTarget>AnyCPU</PlatformTarget> | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.11.4" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <ProjectReference Include="..\..\src\TensorFlowNET.Recommenders\Tensorflow.Recommenders.csproj" /> | |||
| <ProjectReference Include="..\..\src\TensorFlowNET.Text\Tensorflow.Text.csproj" /> | |||
| <ProjectReference Include="..\Tensorflow.UnitTest.RedistHolder\Tensorflow.UnitTest.RedistHolder.csproj" /> | |||
| </ItemGroup> | |||
| </Project> | |||
| @@ -9,7 +9,6 @@ | |||
| <ItemGroup> | |||
| <PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" /> | |||
| <PackageReference Include="Protobuf.Text" Version="0.7.1" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| @@ -5,7 +5,7 @@ | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.11.4" /> | |||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.16.0" /> | |||
| <PackageReference Include="SciSharp.TensorFlow.Redist-Lite" Version="2.6.0" /> | |||
| </ItemGroup> | |||