| @@ -45,5 +45,7 @@ | |||||
| public IRegularizer ActivityRegularizer { get; set; } | public IRegularizer ActivityRegularizer { get; set; } | ||||
| public bool Autocast { get; set; } | public bool Autocast { get; set; } | ||||
| public bool IsFromConfig { get; set; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,6 +1,6 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| public class ResizingArgs : LayerArgs | |||||
| public class ResizingArgs : PreprocessingLayerArgs | |||||
| { | { | ||||
| public int Height { get; set; } | public int Height { get; set; } | ||||
| public int Width { get; set; } | public int Width { get; set; } | ||||
| @@ -0,0 +1,16 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Engine; | |||||
| namespace Tensorflow.Keras.Layers | |||||
| { | |||||
| public class PreprocessingLayer : Layer | |||||
| { | |||||
| public PreprocessingLayer(PreprocessingLayerArgs args) : base(args) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,7 +1,9 @@ | |||||
| using System; | |||||
| using Newtonsoft.Json; | |||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
| { | { | ||||
| @@ -9,7 +11,7 @@ namespace Tensorflow.Keras.Layers | |||||
| /// Resize the batched image input to target height and width. | /// Resize the batched image input to target height and width. | ||||
| /// The input should be a 4-D tensor in the format of NHWC. | /// The input should be a 4-D tensor in the format of NHWC. | ||||
| /// </summary> | /// </summary> | ||||
| public class Resizing : Layer | |||||
| public class Resizing : PreprocessingLayer | |||||
| { | { | ||||
| ResizingArgs args; | ResizingArgs args; | ||||
| public Resizing(ResizingArgs args) : base(args) | public Resizing(ResizingArgs args) : base(args) | ||||
| @@ -26,5 +28,12 @@ namespace Tensorflow.Keras.Layers | |||||
| { | { | ||||
| return new TensorShape(input_shape.dims[0], args.Height, args.Width, input_shape.dims[3]); | return new TensorShape(input_shape.dims[0], args.Height, args.Width, input_shape.dims[3]); | ||||
| } | } | ||||
| public static Resizing from_config(JObject config) | |||||
| { | |||||
| var args = JsonConvert.DeserializeObject<ResizingArgs>(config.ToString()); | |||||
| args.IsFromConfig = true; | |||||
| return new Resizing(args); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| @@ -16,7 +17,7 @@ namespace Tensorflow.Keras.Saving | |||||
| public int SharedObjectId { get; set; } | public int SharedObjectId { get; set; } | ||||
| [JsonProperty("must_restore_from_config")] | [JsonProperty("must_restore_from_config")] | ||||
| public bool MustRestoreFromConfig { get; set; } | public bool MustRestoreFromConfig { get; set; } | ||||
| public ModelConfig Config { get; set; } | |||||
| public JObject Config { get; set; } | |||||
| [JsonProperty("build_input_shape")] | [JsonProperty("build_input_shape")] | ||||
| public TensorShapeConfig BuildInputShape { get; set; } | public TensorShapeConfig BuildInputShape { get; set; } | ||||
| } | } | ||||
| @@ -5,8 +5,10 @@ using System.Linq; | |||||
| using System.Text.RegularExpressions; | using System.Text.RegularExpressions; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Layers; | |||||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | using ThirdParty.Tensorflow.Python.Keras.Protobuf; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using static Tensorflow.KerasApi; | |||||
| namespace Tensorflow.Keras.Saving | namespace Tensorflow.Keras.Saving | ||||
| { | { | ||||
| @@ -73,7 +75,7 @@ namespace Tensorflow.Keras.Saving | |||||
| { | { | ||||
| model = new Sequential(new SequentialArgs | model = new Sequential(new SequentialArgs | ||||
| { | { | ||||
| Name = config.Name | |||||
| Name = config.GetValue("name").ToString() | |||||
| }); | }); | ||||
| } | } | ||||
| else if (class_name == "Functional") | else if (class_name == "Functional") | ||||
| @@ -97,7 +99,12 @@ namespace Tensorflow.Keras.Saving | |||||
| var class_name = metadata.ClassName; | var class_name = metadata.ClassName; | ||||
| var shared_object_id = metadata.SharedObjectId; | var shared_object_id = metadata.SharedObjectId; | ||||
| var must_restore_from_config = metadata.MustRestoreFromConfig; | var must_restore_from_config = metadata.MustRestoreFromConfig; | ||||
| var obj = class_name switch | |||||
| { | |||||
| "Resizing" => Resizing.from_config(config), | |||||
| _ => throw new NotImplementedException("") | |||||
| }; | |||||
| var built = _try_build_layer(obj, node_id, metadata.BuildInputShape); | |||||
| return null; | return null; | ||||
| } | } | ||||
| @@ -157,5 +164,13 @@ namespace Tensorflow.Keras.Saving | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool _try_build_layer(Layer obj, int node_id, TensorShape build_input_shape) | |||||
| { | |||||
| if (obj.Built) | |||||
| return true; | |||||
| return false; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -10,6 +10,6 @@ namespace Tensorflow.Keras.Saving | |||||
| public int?[] Items { get; set; } | public int?[] Items { get; set; } | ||||
| public static implicit operator TensorShape(TensorShapeConfig shape) | public static implicit operator TensorShape(TensorShapeConfig shape) | ||||
| => new TensorShape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); | |||||
| => shape == null ? null : new TensorShape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); | |||||
| } | } | ||||
| } | } | ||||