| @@ -17,9 +17,11 @@ namespace Tensorflow | |||||
| { | { | ||||
| var func = new ConcreteFunction($"{map_func.Method.Name}_{Guid.NewGuid()}"); | var func = new ConcreteFunction($"{map_func.Method.Name}_{Guid.NewGuid()}"); | ||||
| func.Enter(); | func.Enter(); | ||||
| var input = tf.placeholder(input_dataset.element_spec[0].dtype); | |||||
| var output = map_func(input); | |||||
| func.ToGraph(input, output); | |||||
| var inputs = new Tensors(); | |||||
| foreach (var input in input_dataset.element_spec) | |||||
| inputs.Add(tf.placeholder(input.dtype, shape: input.shape)); | |||||
| var outputs = map_func(inputs); | |||||
| func.ToGraph(inputs, outputs); | |||||
| func.Exit(); | func.Exit(); | ||||
| structure = func.OutputStructure; | structure = func.OutputStructure; | ||||
| @@ -86,7 +86,7 @@ tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.</PackageReleaseNotes | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" /> | <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" /> | ||||
| <PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="5.0.1" /> | <PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="5.0.1" /> | ||||
| <PackageReference Include="NumSharp.Lite" Version="0.1.12" /> | |||||
| <PackageReference Include="NumSharp" Version="0.30.0" /> | |||||
| <PackageReference Include="Protobuf.Text" Version="0.5.0" /> | <PackageReference Include="Protobuf.Text" Version="0.5.0" /> | ||||
| <PackageReference Include="Serilog.Sinks.Console" Version="3.1.1" /> | <PackageReference Include="Serilog.Sinks.Console" Version="3.1.1" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -13,8 +13,40 @@ namespace Tensorflow.Keras.Layers | |||||
| public TextVectorization(TextVectorizationArgs args) | public TextVectorization(TextVectorizationArgs args) | ||||
| : base(args) | : base(args) | ||||
| { | { | ||||
| this.args = args; | |||||
| args.DType = TF_DataType.TF_STRING; | args.DType = TF_DataType.TF_STRING; | ||||
| // string standardize = "lower_and_strip_punctuation", | // string standardize = "lower_and_strip_punctuation", | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Fits the state of the preprocessing layer to the dataset. | |||||
| /// </summary> | |||||
| /// <param name="data"></param> | |||||
| /// <param name="reset_state"></param> | |||||
| public void adapt(IDatasetV2 data, bool reset_state = true) | |||||
| { | |||||
| var shape = data.output_shapes[0]; | |||||
| if (shape.rank == 1) | |||||
| data = data.map(tensor => array_ops.expand_dims(tensor, -1)); | |||||
| build(data.variant_tensor); | |||||
| var preprocessed_inputs = data.map(_preprocess); | |||||
| } | |||||
| protected override void build(Tensors inputs) | |||||
| { | |||||
| base.build(inputs); | |||||
| } | |||||
| Tensors _preprocess(Tensors inputs) | |||||
| { | |||||
| if (args.Standardize != null) | |||||
| inputs = args.Standardize(inputs); | |||||
| if (!string.IsNullOrEmpty(args.Split)) | |||||
| { | |||||
| if (inputs.shape.ndim > 1) | |||||
| inputs = array_ops.squeeze(inputs, axis: new[] { -1 }); | |||||
| } | |||||
| return inputs; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -11,7 +11,7 @@ namespace Tensorflow.Keras | |||||
| public DatasetUtils dataset_utils => new DatasetUtils(); | public DatasetUtils dataset_utils => new DatasetUtils(); | ||||
| public TextVectorization TextVectorization(Func<Tensor, Tensor> standardize = null, | public TextVectorization TextVectorization(Func<Tensor, Tensor> standardize = null, | ||||
| string split = "standardize", | |||||
| string split = "whitespace", | |||||
| int max_tokens = -1, | int max_tokens = -1, | ||||
| string output_mode = "int", | string output_mode = "int", | ||||
| int output_sequence_length = -1) => new TextVectorization(new TextVectorizationArgs | int output_sequence_length = -1) => new TextVectorization(new TextVectorizationArgs | ||||