| @@ -0,0 +1,23 @@ | |||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Extensions | |||||
| { | |||||
| public static class JObjectExtensions | |||||
| { | |||||
| public static T? TryGetOrReturnNull<T>(this JObject obj, string key) | |||||
| { | |||||
| var res = obj[key]; | |||||
| if(res is null) | |||||
| { | |||||
| return default(T); | |||||
| } | |||||
| else | |||||
| { | |||||
| return res.ToObject<T>(); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -7,7 +7,7 @@ namespace Tensorflow.Framework.Models | |||||
| public TensorSpec(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) : | public TensorSpec(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) : | ||||
| base(shape, dtype, name) | base(shape, dtype, name) | ||||
| { | { | ||||
| } | } | ||||
| public TensorSpec _unbatch() | public TensorSpec _unbatch() | ||||
| @@ -1,7 +1,7 @@ | |||||
| using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
| using System.Reflection; | using System.Reflection; | ||||
| using System.Runtime.Versioning; | using System.Runtime.Versioning; | ||||
| using Tensorflow.Keras.Common; | |||||
| using Tensorflow.Keras.Saving.Common; | |||||
| namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
| { | { | ||||
| @@ -2,6 +2,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| @@ -18,7 +19,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
| [JsonProperty("dtype")] | [JsonProperty("dtype")] | ||||
| public override TF_DataType DType { get => base.DType; set => base.DType = value; } | public override TF_DataType DType { get => base.DType; set => base.DType = value; } | ||||
| [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] | [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] | ||||
| public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } | |||||
| public override KerasShapesWrapper BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } | |||||
| [JsonProperty("trainable")] | [JsonProperty("trainable")] | ||||
| public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } | public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } | ||||
| } | } | ||||
| @@ -1,6 +1,6 @@ | |||||
| using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
| using Newtonsoft.Json.Serialization; | using Newtonsoft.Json.Serialization; | ||||
| using Tensorflow.Keras.Common; | |||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| @@ -17,6 +17,6 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
| [JsonProperty("dtype")] | [JsonProperty("dtype")] | ||||
| public override TF_DataType DType { get => base.DType; set => base.DType = value; } | public override TF_DataType DType { get => base.DType; set => base.DType = value; } | ||||
| [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] | [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] | ||||
| public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } | |||||
| public override KerasShapesWrapper BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -33,7 +33,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
| /// <summary> | /// <summary> | ||||
| /// Only applicable to input layers. | /// Only applicable to input layers. | ||||
| /// </summary> | /// </summary> | ||||
| public virtual Shape BatchInputShape { get; set; } | |||||
| public virtual KerasShapesWrapper BatchInputShape { get; set; } | |||||
| public virtual int BatchSize { get; set; } = -1; | public virtual int BatchSize { get; set; } = -1; | ||||
| @@ -10,7 +10,7 @@ namespace Tensorflow.Keras | |||||
| string Name { get; } | string Name { get; } | ||||
| bool Trainable { get; } | bool Trainable { get; } | ||||
| bool Built { get; } | bool Built { get; } | ||||
| void build(Shape input_shape); | |||||
| void build(KerasShapesWrapper input_shape); | |||||
| List<ILayer> Layers { get; } | List<ILayer> Layers { get; } | ||||
| List<INode> InboundNodes { get; } | List<INode> InboundNodes { get; } | ||||
| List<INode> OutboundNodes { get; } | List<INode> OutboundNodes { get; } | ||||
| @@ -22,8 +22,8 @@ namespace Tensorflow.Keras | |||||
| void set_weights(IEnumerable<NDArray> weights); | void set_weights(IEnumerable<NDArray> weights); | ||||
| List<NDArray> get_weights(); | List<NDArray> get_weights(); | ||||
| Shape OutputShape { get; } | Shape OutputShape { get; } | ||||
| Shape BatchInputShape { get; } | |||||
| TensorShapeConfig BuildInputShape { get; } | |||||
| KerasShapesWrapper BatchInputShape { get; } | |||||
| KerasShapesWrapper BuildInputShape { get; } | |||||
| TF_DataType DType { get; } | TF_DataType DType { get; } | ||||
| int count_params(); | int count_params(); | ||||
| void adapt(Tensor data, int? batch_size = null, int? steps = null); | void adapt(Tensor data, int? batch_size = null, int? steps = null); | ||||
| @@ -6,7 +6,7 @@ using System.Collections.Generic; | |||||
| using System.Text; | using System.Text; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Keras.Common | |||||
| namespace Tensorflow.Keras.Saving.Common | |||||
| { | { | ||||
| public class CustomizedActivationJsonConverter : JsonConverter | public class CustomizedActivationJsonConverter : JsonConverter | ||||
| { | { | ||||
| @@ -4,7 +4,7 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Keras.Common | |||||
| namespace Tensorflow.Keras.Saving.Common | |||||
| { | { | ||||
| public class CustomizedAxisJsonConverter : JsonConverter | public class CustomizedAxisJsonConverter : JsonConverter | ||||
| { | { | ||||
| @@ -38,7 +38,7 @@ namespace Tensorflow.Keras.Common | |||||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | ||||
| { | { | ||||
| int[]? axis; | int[]? axis; | ||||
| if(reader.ValueType == typeof(long)) | |||||
| if (reader.ValueType == typeof(long)) | |||||
| { | { | ||||
| axis = new int[1]; | axis = new int[1]; | ||||
| axis[0] = (int)serializer.Deserialize(reader, typeof(int)); | axis[0] = (int)serializer.Deserialize(reader, typeof(int)); | ||||
| @@ -51,7 +51,7 @@ namespace Tensorflow.Keras.Common | |||||
| { | { | ||||
| throw new ValueError("Cannot deserialize 'null' to `Axis`."); | throw new ValueError("Cannot deserialize 'null' to `Axis`."); | ||||
| } | } | ||||
| return new Axis((int[])(axis!)); | |||||
| return new Axis(axis!); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,7 +1,7 @@ | |||||
| using Newtonsoft.Json.Linq; | using Newtonsoft.Json.Linq; | ||||
| using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
| namespace Tensorflow.Keras.Common | |||||
| namespace Tensorflow.Keras.Saving.Common | |||||
| { | { | ||||
| public class CustomizedDTypeJsonConverter : JsonConverter | public class CustomizedDTypeJsonConverter : JsonConverter | ||||
| { | { | ||||
| @@ -16,7 +16,7 @@ namespace Tensorflow.Keras.Common | |||||
| public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | ||||
| { | { | ||||
| var token = JToken.FromObject(dtypes.as_numpy_name((TF_DataType)value)); | |||||
| var token = JToken.FromObject(((TF_DataType)value).as_numpy_name()); | |||||
| token.WriteTo(writer); | token.WriteTo(writer); | ||||
| } | } | ||||
| @@ -4,9 +4,10 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| using Tensorflow.Operations.Initializers; | using Tensorflow.Operations.Initializers; | ||||
| namespace Tensorflow.Keras.Common | |||||
| namespace Tensorflow.Keras.Saving.Common | |||||
| { | { | ||||
| class InitializerInfo | class InitializerInfo | ||||
| { | { | ||||
| @@ -27,7 +28,7 @@ namespace Tensorflow.Keras.Common | |||||
| public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | ||||
| { | { | ||||
| var initializer = value as IInitializer; | var initializer = value as IInitializer; | ||||
| if(initializer is null) | |||||
| if (initializer is null) | |||||
| { | { | ||||
| JToken.FromObject(null).WriteTo(writer); | JToken.FromObject(null).WriteTo(writer); | ||||
| return; | return; | ||||
| @@ -42,7 +43,7 @@ namespace Tensorflow.Keras.Common | |||||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | ||||
| { | { | ||||
| var info = serializer.Deserialize<InitializerInfo>(reader); | var info = serializer.Deserialize<InitializerInfo>(reader); | ||||
| if(info is null) | |||||
| if (info is null) | |||||
| { | { | ||||
| return null; | return null; | ||||
| } | } | ||||
| @@ -54,8 +55,8 @@ namespace Tensorflow.Keras.Common | |||||
| "Orthogonal" => new Orthogonal(info.config["gain"].ToObject<float>(), info.config["seed"].ToObject<int?>()), | "Orthogonal" => new Orthogonal(info.config["gain"].ToObject<float>(), info.config["seed"].ToObject<int?>()), | ||||
| "RandomNormal" => new RandomNormal(info.config["mean"].ToObject<float>(), info.config["stddev"].ToObject<float>(), | "RandomNormal" => new RandomNormal(info.config["mean"].ToObject<float>(), info.config["stddev"].ToObject<float>(), | ||||
| info.config["seed"].ToObject<int?>()), | info.config["seed"].ToObject<int?>()), | ||||
| "RandomUniform" => new RandomUniform(minval:info.config["minval"].ToObject<float>(), | |||||
| maxval:info.config["maxval"].ToObject<float>(), seed: info.config["seed"].ToObject<int?>()), | |||||
| "RandomUniform" => new RandomUniform(minval: info.config["minval"].ToObject<float>(), | |||||
| maxval: info.config["maxval"].ToObject<float>(), seed: info.config["seed"].ToObject<int?>()), | |||||
| "TruncatedNormal" => new TruncatedNormal(info.config["mean"].ToObject<float>(), info.config["stddev"].ToObject<float>(), | "TruncatedNormal" => new TruncatedNormal(info.config["mean"].ToObject<float>(), info.config["stddev"].ToObject<float>(), | ||||
| info.config["seed"].ToObject<int?>()), | info.config["seed"].ToObject<int?>()), | ||||
| "VarianceScaling" => new VarianceScaling(info.config["scale"].ToObject<float>(), info.config["mode"].ToObject<string>(), | "VarianceScaling" => new VarianceScaling(info.config["scale"].ToObject<float>(), info.config["mode"].ToObject<string>(), | ||||
| @@ -0,0 +1,75 @@ | |||||
| using Newtonsoft.Json.Linq; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Saving.Json | |||||
| { | |||||
| public class CustomizedKerasShapesWrapperJsonConverter : JsonConverter | |||||
| { | |||||
| public override bool CanConvert(Type objectType) | |||||
| { | |||||
| return objectType == typeof(KerasShapesWrapper); | |||||
| } | |||||
| public override bool CanRead => true; | |||||
| public override bool CanWrite => true; | |||||
| public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||||
| { | |||||
| if (value is null) | |||||
| { | |||||
| JToken.FromObject(null).WriteTo(writer); | |||||
| return; | |||||
| } | |||||
| if (value is not KerasShapesWrapper wrapper) | |||||
| { | |||||
| throw new TypeError($"Expected `KerasShapesWrapper` to be serialized, bug got {value.GetType()}"); | |||||
| } | |||||
| if (wrapper.Shapes.Length == 0) | |||||
| { | |||||
| JToken.FromObject(null).WriteTo(writer); | |||||
| } | |||||
| else if (wrapper.Shapes.Length == 1) | |||||
| { | |||||
| JToken.FromObject(wrapper.Shapes[0]).WriteTo(writer); | |||||
| } | |||||
| else | |||||
| { | |||||
| JToken.FromObject(wrapper.Shapes).WriteTo(writer); | |||||
| } | |||||
| } | |||||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||||
| { | |||||
| if (reader.TokenType == JsonToken.StartArray) | |||||
| { | |||||
| TensorShapeConfig[] shapes = serializer.Deserialize<TensorShapeConfig[]>(reader); | |||||
| if (shapes is null) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| return new KerasShapesWrapper(shapes); | |||||
| } | |||||
| else if (reader.TokenType == JsonToken.StartObject) | |||||
| { | |||||
| var shape = serializer.Deserialize<TensorShapeConfig>(reader); | |||||
| if (shape is null) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| return new KerasShapesWrapper(shape); | |||||
| } | |||||
| else if (reader.TokenType == JsonToken.Null) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new ValueError($"Cannot deserialize the token type {reader.TokenType}"); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -7,7 +7,7 @@ using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
| namespace Tensorflow.Keras.Common | |||||
| namespace Tensorflow.Keras.Saving.Common | |||||
| { | { | ||||
| public class CustomizedNodeConfigJsonConverter : JsonConverter | public class CustomizedNodeConfigJsonConverter : JsonConverter | ||||
| { | { | ||||
| @@ -46,10 +46,10 @@ namespace Tensorflow.Keras.Common | |||||
| { | { | ||||
| throw new ValueError("Cannot deserialize 'null' to `Shape`."); | throw new ValueError("Cannot deserialize 'null' to `Shape`."); | ||||
| } | } | ||||
| if(values.Length == 1) | |||||
| if (values.Length == 1) | |||||
| { | { | ||||
| var array = values[0] as JArray; | var array = values[0] as JArray; | ||||
| if(array is null) | |||||
| if (array is null) | |||||
| { | { | ||||
| throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | ||||
| } | } | ||||
| @@ -5,14 +5,14 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Keras.Common | |||||
| namespace Tensorflow.Keras.Saving.Common | |||||
| { | { | ||||
| class ShapeInfoFromPython | class ShapeInfoFromPython | ||||
| { | { | ||||
| public string class_name { get; set; } | public string class_name { get; set; } | ||||
| public long?[] items { get; set; } | public long?[] items { get; set; } | ||||
| } | } | ||||
| public class CustomizedShapeJsonConverter: JsonConverter | |||||
| public class CustomizedShapeJsonConverter : JsonConverter | |||||
| { | { | ||||
| public override bool CanConvert(Type objectType) | public override bool CanConvert(Type objectType) | ||||
| { | { | ||||
| @@ -25,12 +25,12 @@ namespace Tensorflow.Keras.Common | |||||
| public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | ||||
| { | { | ||||
| if(value is null) | |||||
| if (value is null) | |||||
| { | { | ||||
| var token = JToken.FromObject(null); | var token = JToken.FromObject(null); | ||||
| token.WriteTo(writer); | token.WriteTo(writer); | ||||
| } | } | ||||
| else if(value is not Shape) | |||||
| else if (value is not Shape) | |||||
| { | { | ||||
| throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}."); | throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}."); | ||||
| } | } | ||||
| @@ -38,7 +38,7 @@ namespace Tensorflow.Keras.Common | |||||
| { | { | ||||
| var shape = (value as Shape)!; | var shape = (value as Shape)!; | ||||
| long?[] dims = new long?[shape.ndim]; | long?[] dims = new long?[shape.ndim]; | ||||
| for(int i = 0; i < dims.Length; i++) | |||||
| for (int i = 0; i < dims.Length; i++) | |||||
| { | { | ||||
| if (shape.dims[i] == -1) | if (shape.dims[i] == -1) | ||||
| { | { | ||||
| @@ -61,7 +61,7 @@ namespace Tensorflow.Keras.Common | |||||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | ||||
| { | { | ||||
| long?[] dims; | long?[] dims; | ||||
| try | |||||
| if (reader.TokenType == JsonToken.StartObject) | |||||
| { | { | ||||
| var shape_info_from_python = serializer.Deserialize<ShapeInfoFromPython>(reader); | var shape_info_from_python = serializer.Deserialize<ShapeInfoFromPython>(reader); | ||||
| if (shape_info_from_python is null) | if (shape_info_from_python is null) | ||||
| @@ -70,14 +70,22 @@ namespace Tensorflow.Keras.Common | |||||
| } | } | ||||
| dims = shape_info_from_python.items; | dims = shape_info_from_python.items; | ||||
| } | } | ||||
| catch(JsonSerializationException) | |||||
| else if (reader.TokenType == JsonToken.StartArray) | |||||
| { | { | ||||
| dims = serializer.Deserialize<long?[]>(reader); | dims = serializer.Deserialize<long?[]>(reader); | ||||
| } | } | ||||
| else if (reader.TokenType == JsonToken.Null) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new ValueError($"Cannot deserialize the token {reader} as Shape."); | |||||
| } | |||||
| long[] convertedDims = new long[dims.Length]; | long[] convertedDims = new long[dims.Length]; | ||||
| for(int i = 0; i < dims.Length; i++) | |||||
| for (int i = 0; i < dims.Length; i++) | |||||
| { | { | ||||
| convertedDims[i] = dims[i] ?? (-1); | |||||
| convertedDims[i] = dims[i] ?? -1; | |||||
| } | } | ||||
| return new Shape(convertedDims); | return new Shape(convertedDims); | ||||
| } | } | ||||
| @@ -0,0 +1,60 @@ | |||||
| using Newtonsoft.Json.Linq; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using System.Diagnostics; | |||||
| using OneOf.Types; | |||||
| using Tensorflow.Keras.Saving.Json; | |||||
| namespace Tensorflow.Keras.Saving | |||||
| { | |||||
| [JsonConverter(typeof(CustomizedKerasShapesWrapperJsonConverter))] | |||||
| public class KerasShapesWrapper | |||||
| { | |||||
| public TensorShapeConfig[] Shapes { get; set; } | |||||
| public KerasShapesWrapper(Shape shape) | |||||
| { | |||||
| Shapes = new TensorShapeConfig[] { shape }; | |||||
| } | |||||
| public KerasShapesWrapper(TensorShapeConfig shape) | |||||
| { | |||||
| Shapes = new TensorShapeConfig[] { shape }; | |||||
| } | |||||
| public KerasShapesWrapper(TensorShapeConfig[] shapes) | |||||
| { | |||||
| Shapes = shapes; | |||||
| } | |||||
| public KerasShapesWrapper(IEnumerable<Shape> shape) | |||||
| { | |||||
| Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray(); | |||||
| } | |||||
| public Shape ToSingleShape() | |||||
| { | |||||
| Debug.Assert(Shapes.Length == 1); | |||||
| var shape_config = Shapes[0]; | |||||
| Debug.Assert(shape_config is not null); | |||||
| return new Shape(shape_config.Items.Select(x => x is null ? -1 : x.Value).ToArray()); | |||||
| } | |||||
| public Shape[] ToShapeArray() | |||||
| { | |||||
| return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray(); | |||||
| } | |||||
| public static implicit operator KerasShapesWrapper(Shape shape) | |||||
| { | |||||
| return new KerasShapesWrapper(shape); | |||||
| } | |||||
| public static implicit operator KerasShapesWrapper(TensorShapeConfig shape) | |||||
| { | |||||
| return new KerasShapesWrapper(shape); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -9,7 +9,7 @@ using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | |||||
| namespace Tensorflow.Keras.Saving | namespace Tensorflow.Keras.Saving | ||||
| { | { | ||||
| public class ModelConfig : IKerasConfig | |||||
| public class FunctionalConfig : IKerasConfig | |||||
| { | { | ||||
| [JsonProperty("name")] | [JsonProperty("name")] | ||||
| public string Name { get; set; } | public string Name { get; set; } | ||||
| @@ -2,7 +2,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.Common; | |||||
| using Tensorflow.Keras.Saving.Common; | |||||
| namespace Tensorflow.Keras.Saving | namespace Tensorflow.Keras.Saving | ||||
| { | { | ||||
| @@ -19,7 +19,7 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.Common; | |||||
| using Tensorflow.Keras.Saving.Common; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -19,7 +19,7 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.Common; | |||||
| using Tensorflow.Keras.Saving.Common; | |||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -16,7 +16,7 @@ | |||||
| using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.Keras.Common; | |||||
| using Tensorflow.Keras.Saving.Common; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -80,9 +80,9 @@ namespace Tensorflow | |||||
| public Shape OutputShape => throw new NotImplementedException(); | public Shape OutputShape => throw new NotImplementedException(); | ||||
| public Shape BatchInputShape => throw new NotImplementedException(); | |||||
| public KerasShapesWrapper BatchInputShape => throw new NotImplementedException(); | |||||
| public TensorShapeConfig BuildInputShape => throw new NotImplementedException(); | |||||
| public KerasShapesWrapper BuildInputShape => throw new NotImplementedException(); | |||||
| public TF_DataType DType => throw new NotImplementedException(); | public TF_DataType DType => throw new NotImplementedException(); | ||||
| protected bool built = false; | protected bool built = false; | ||||
| @@ -162,6 +162,11 @@ namespace Tensorflow | |||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| public void build(KerasShapesWrapper input_shape) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public Trackable GetTrackable() { throw new NotImplementedException(); } | public Trackable GetTrackable() { throw new NotImplementedException(); } | ||||
| public void adapt(Tensor data, int? batch_size = null, int? steps = null) | public void adapt(Tensor data, int? batch_size = null, int? steps = null) | ||||
| @@ -1,5 +1,5 @@ | |||||
| using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
| using Tensorflow.Keras.Common; | |||||
| using Tensorflow.Keras.Saving.Common; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| public partial class Functional | public partial class Functional | ||||
| { | { | ||||
| public static Functional from_config(ModelConfig config) | |||||
| public static Functional from_config(FunctionalConfig config) | |||||
| { | { | ||||
| var (input_tensors, output_tensors, created_layers) = reconstruct_from_config(config); | var (input_tensors, output_tensors, created_layers) = reconstruct_from_config(config); | ||||
| var model = new Functional(input_tensors, output_tensors, name: config.Name); | var model = new Functional(input_tensors, output_tensors, name: config.Name); | ||||
| @@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Engine | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="config"></param> | /// <param name="config"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config, Dictionary<string, ILayer>? created_layers = null) | |||||
| public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(FunctionalConfig config, Dictionary<string, ILayer>? created_layers = null) | |||||
| { | { | ||||
| // Layer instances created during the graph reconstruction process. | // Layer instances created during the graph reconstruction process. | ||||
| created_layers = created_layers ?? new Dictionary<string, ILayer>(); | created_layers = created_layers ?? new Dictionary<string, ILayer>(); | ||||
| @@ -19,9 +19,9 @@ namespace Tensorflow.Keras.Engine | |||||
| /// <summary> | /// <summary> | ||||
| /// Builds the config, which consists of the node graph and serialized layers. | /// Builds the config, which consists of the node graph and serialized layers. | ||||
| /// </summary> | /// </summary> | ||||
| ModelConfig get_network_config() | |||||
| FunctionalConfig get_network_config() | |||||
| { | { | ||||
| var config = new ModelConfig | |||||
| var config = new FunctionalConfig | |||||
| { | { | ||||
| Name = name | Name = name | ||||
| }; | }; | ||||
| @@ -211,9 +211,9 @@ namespace Tensorflow.Keras.Engine | |||||
| protected bool computePreviousMask; | protected bool computePreviousMask; | ||||
| protected List<Operation> updates; | protected List<Operation> updates; | ||||
| public Shape BatchInputShape => args.BatchInputShape; | |||||
| protected TensorShapeConfig _buildInputShape = null; | |||||
| public TensorShapeConfig BuildInputShape => _buildInputShape; | |||||
| public KerasShapesWrapper BatchInputShape => args.BatchInputShape; | |||||
| protected KerasShapesWrapper _buildInputShape = null; | |||||
| public KerasShapesWrapper BuildInputShape => _buildInputShape; | |||||
| List<INode> inboundNodes; | List<INode> inboundNodes; | ||||
| public List<INode> InboundNodes => inboundNodes; | public List<INode> InboundNodes => inboundNodes; | ||||
| @@ -284,7 +284,7 @@ namespace Tensorflow.Keras.Engine | |||||
| // Manage input shape information if passed. | // Manage input shape information if passed. | ||||
| if (args.BatchInputShape == null && args.InputShape != null) | if (args.BatchInputShape == null && args.InputShape != null) | ||||
| { | { | ||||
| args.BatchInputShape = new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray(); | |||||
| args.BatchInputShape = new KerasShapesWrapper(new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray()); | |||||
| } | } | ||||
| } | } | ||||
| @@ -363,7 +363,7 @@ namespace Tensorflow.Keras.Engine | |||||
| tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); | tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); | ||||
| } | } | ||||
| build(inputs.shape); | |||||
| build(new KerasShapesWrapper(inputs.shape)); | |||||
| if (need_restore_mode) | if (need_restore_mode) | ||||
| tf.Context.restore_mode(); | tf.Context.restore_mode(); | ||||
| @@ -371,7 +371,7 @@ namespace Tensorflow.Keras.Engine | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| public virtual void build(Shape input_shape) | |||||
| public virtual void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| _buildInputShape = input_shape; | _buildInputShape = input_shape; | ||||
| built = true; | built = true; | ||||
| @@ -1,6 +1,8 @@ | |||||
| using System; | using System; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using Tensorflow.Keras.Utils; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
| @@ -8,22 +10,40 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| public partial class Model | public partial class Model | ||||
| { | { | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| if (this is Functional || this is Sequential) | |||||
| if (_is_graph_network || this is Functional || this is Sequential) | |||||
| { | { | ||||
| base.build(input_shape); | base.build(input_shape); | ||||
| return; | return; | ||||
| } | } | ||||
| var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph(); | |||||
| graph.as_default(); | |||||
| var x = tf.placeholder(DType, input_shape); | |||||
| Call(x, training: false); | |||||
| graph.Exit(); | |||||
| if(input_shape is not null && this.inputs is null) | |||||
| { | |||||
| var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph(); | |||||
| graph.as_default(); | |||||
| var shapes = input_shape.ToShapeArray(); | |||||
| var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x))); | |||||
| try | |||||
| { | |||||
| Call(x, training: false); | |||||
| } | |||||
| catch (InvalidArgumentError) | |||||
| { | |||||
| throw new ValueError("You cannot build your model by calling `build` " + | |||||
| "if your layers do not support float type inputs. " + | |||||
| "Instead, in order to instantiate and build your " + | |||||
| "model, `call` your model on real tensor data (of the correct dtype)."); | |||||
| } | |||||
| catch (TypeError) | |||||
| { | |||||
| throw new ValueError("You cannot build your model by calling `build` " + | |||||
| "if your layers do not support float type inputs. " + | |||||
| "Instead, in order to instantiate and build your " + | |||||
| "model, `call` your model on real tensor data (of the correct dtype)."); | |||||
| } | |||||
| graph.Exit(); | |||||
| } | |||||
| base.build(input_shape); | base.build(input_shape); | ||||
| } | } | ||||
| @@ -92,7 +92,7 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| // Instantiate an input layer. | // Instantiate an input layer. | ||||
| var x = keras.Input( | var x = keras.Input( | ||||
| batch_input_shape: layer.BatchInputShape, | |||||
| batch_input_shape: layer.BatchInputShape.ToSingleShape(), | |||||
| dtype: layer.DType, | dtype: layer.DType, | ||||
| name: layer.Name + "_input"); | name: layer.Name + "_input"); | ||||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Keras.Layers { | namespace Tensorflow.Keras.Layers { | ||||
| @@ -19,7 +20,7 @@ namespace Tensorflow.Keras.Layers { | |||||
| this.args = args; | this.args = args; | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| if (alpha < 0f) | if (alpha < 0f) | ||||
| { | { | ||||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Keras.Layers { | namespace Tensorflow.Keras.Layers { | ||||
| @@ -12,7 +13,7 @@ namespace Tensorflow.Keras.Layers { | |||||
| { | { | ||||
| // Exponential has no args | // Exponential has no args | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| base.build(input_shape); | base.build(input_shape); | ||||
| } | } | ||||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Keras.Layers { | namespace Tensorflow.Keras.Layers { | ||||
| @@ -15,7 +16,7 @@ namespace Tensorflow.Keras.Layers { | |||||
| public SELU ( LayerArgs args ) : base(args) { | public SELU ( LayerArgs args ) : base(args) { | ||||
| // SELU has no arguments | // SELU has no arguments | ||||
| } | } | ||||
| public override void build(Shape input_shape) { | |||||
| public override void build(KerasShapesWrapper input_shape) { | |||||
| if ( alpha < 0f ) { | if ( alpha < 0f ) { | ||||
| throw new ValueError("Alpha must be a number greater than 0."); | throw new ValueError("Alpha must be a number greater than 0."); | ||||
| } | } | ||||
| @@ -93,7 +93,7 @@ namespace Tensorflow.Keras.Layers | |||||
| } | } | ||||
| // Creates variable when `use_scale` is True or `score_mode` is `concat`. | // Creates variable when `use_scale` is True or `score_mode` is `concat`. | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| if (this.use_scale) | if (this.use_scale) | ||||
| this.scale = this.add_weight(name: "scale", | this.scale = this.add_weight(name: "scale", | ||||
| @@ -19,6 +19,7 @@ using static Tensorflow.Binding; | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
| using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
| { | { | ||||
| @@ -58,13 +59,14 @@ namespace Tensorflow.Keras.Layers | |||||
| return args; | return args; | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| var single_shape = input_shape.ToSingleShape(); | |||||
| if (len(input_shape) != 4) | if (len(input_shape) != 4) | ||||
| throw new ValueError($"Inputs should have rank 4. Received input shape: {input_shape}"); | throw new ValueError($"Inputs should have rank 4. Received input shape: {input_shape}"); | ||||
| var channel_axis = _get_channel_axis(); | var channel_axis = _get_channel_axis(); | ||||
| var input_dim = input_shape[-1]; | |||||
| var input_dim = single_shape[-1]; | |||||
| var kernel_shape = new Shape(kernel_size[0], kernel_size[1], filters, input_dim); | var kernel_shape = new Shape(kernel_size[0], kernel_size[1], filters, input_dim); | ||||
| kernel = add_weight(name: "kernel", | kernel = add_weight(name: "kernel", | ||||
| @@ -19,6 +19,7 @@ using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -57,12 +58,13 @@ namespace Tensorflow.Keras.Layers | |||||
| _tf_data_format = conv_utils.convert_data_format(data_format, rank + 2); | _tf_data_format = conv_utils.convert_data_format(data_format, rank + 2); | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| int channel_axis = data_format == "channels_first" ? 1 : -1; | int channel_axis = data_format == "channels_first" ? 1 : -1; | ||||
| var single_shape = input_shape.ToSingleShape(); | |||||
| var input_channel = channel_axis < 0 ? | var input_channel = channel_axis < 0 ? | ||||
| input_shape.dims[input_shape.ndim + channel_axis] : | |||||
| input_shape.dims[channel_axis]; | |||||
| single_shape.dims[single_shape.ndim + channel_axis] : | |||||
| single_shape.dims[channel_axis]; | |||||
| Shape kernel_shape = kernel_size.dims.concat(new long[] { input_channel / args.Groups, filters }); | Shape kernel_shape = kernel_size.dims.concat(new long[] { input_channel / args.Groups, filters }); | ||||
| kernel = add_weight(name: "kernel", | kernel = add_weight(name: "kernel", | ||||
| shape: kernel_shape, | shape: kernel_shape, | ||||
| @@ -16,9 +16,11 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
| @@ -41,10 +43,12 @@ namespace Tensorflow.Keras.Layers | |||||
| this.inputSpec = new InputSpec(min_ndim: 2); | this.inputSpec = new InputSpec(min_ndim: 2); | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| _buildInputShape = input_shape; | _buildInputShape = input_shape; | ||||
| var last_dim = input_shape.dims.Last(); | |||||
| Debug.Assert(input_shape.Shapes.Length <= 1); | |||||
| var single_shape = input_shape.ToSingleShape(); | |||||
| var last_dim = single_shape.dims.Last(); | |||||
| var axes = new Dictionary<int, int>(); | var axes = new Dictionary<int, int>(); | ||||
| axes[-1] = (int)last_dim; | axes[-1] = (int)last_dim; | ||||
| inputSpec = new InputSpec(min_ndim: 2, axes: axes); | inputSpec = new InputSpec(min_ndim: 2, axes: axes); | ||||
| @@ -6,6 +6,7 @@ using System.Linq; | |||||
| using System.Text.RegularExpressions; | using System.Text.RegularExpressions; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.ArgsDefinition.Core; | using Tensorflow.Keras.ArgsDefinition.Core; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
| { | { | ||||
| @@ -119,9 +120,10 @@ namespace Tensorflow.Keras.Layers | |||||
| this.bias_constraint = args.BiasConstraint; | this.bias_constraint = args.BiasConstraint; | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| var shape_data = _analyze_einsum_string(this.equation, this.bias_axes, input_shape, this.partial_output_shape); | |||||
| var shape_data = _analyze_einsum_string(this.equation, this.bias_axes, | |||||
| input_shape.ToSingleShape(), this.partial_output_shape); | |||||
| var kernel_shape = shape_data.Item1; | var kernel_shape = shape_data.Item1; | ||||
| var bias_shape = shape_data.Item2; | var bias_shape = shape_data.Item2; | ||||
| this.full_output_shape = shape_data.Item3; | this.full_output_shape = shape_data.Item3; | ||||
| @@ -17,6 +17,7 @@ | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
| @@ -48,13 +49,13 @@ namespace Tensorflow.Keras.Layers | |||||
| args.InputShape = args.InputLength; | args.InputShape = args.InputLength; | ||||
| if (args.BatchInputShape == null) | if (args.BatchInputShape == null) | ||||
| args.BatchInputShape = new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray(); | |||||
| args.BatchInputShape = new KerasShapesWrapper(new long[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray()); | |||||
| embeddings_initializer = args.EmbeddingsInitializer ?? tf.random_uniform_initializer; | embeddings_initializer = args.EmbeddingsInitializer ?? tf.random_uniform_initializer; | ||||
| SupportsMasking = mask_zero; | SupportsMasking = mask_zero; | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| tf.Context.eager_mode(); | tf.Context.eager_mode(); | ||||
| embeddings = add_weight(shape: (input_dim, output_dim), | embeddings = add_weight(shape: (input_dim, output_dim), | ||||
| @@ -40,10 +40,10 @@ namespace Tensorflow.Keras.Layers | |||||
| built = true; | built = true; | ||||
| SupportsMasking = true; | SupportsMasking = true; | ||||
| if (BatchInputShape != null) | |||||
| if (BatchInputShape is not null) | |||||
| { | { | ||||
| args.BatchSize = (int)BatchInputShape.dims[0]; | |||||
| args.InputShape = BatchInputShape.dims.Skip(1).ToArray(); | |||||
| args.BatchSize = (int)(BatchInputShape.ToSingleShape().dims[0]); | |||||
| args.InputShape = BatchInputShape.ToSingleShape().dims.Skip(1).ToArray(); | |||||
| } | } | ||||
| // moved to base class | // moved to base class | ||||
| @@ -63,9 +63,8 @@ namespace Tensorflow.Keras.Layers | |||||
| { | { | ||||
| if (args.InputShape != null) | if (args.InputShape != null) | ||||
| { | { | ||||
| args.BatchInputShape = new long[] { args.BatchSize } | |||||
| .Concat(args.InputShape.dims) | |||||
| .ToArray(); | |||||
| args.BatchInputShape = new Saving.KerasShapesWrapper(new long[] { args.BatchSize } | |||||
| .Concat(args.InputShape.dims).ToArray()); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -76,7 +75,7 @@ namespace Tensorflow.Keras.Layers | |||||
| graph.as_default(); | graph.as_default(); | ||||
| args.InputTensor = keras.backend.placeholder( | args.InputTensor = keras.backend.placeholder( | ||||
| shape: BatchInputShape, | |||||
| shape: BatchInputShape.ToSingleShape(), | |||||
| dtype: DType, | dtype: DType, | ||||
| name: Name, | name: Name, | ||||
| sparse: args.Sparse, | sparse: args.Sparse, | ||||
| @@ -4,6 +4,7 @@ using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
| @@ -23,7 +24,7 @@ namespace Tensorflow.Keras.Layers | |||||
| this.args = args; | this.args = args; | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| /*var shape_set = new HashSet<Shape>(); | /*var shape_set = new HashSet<Shape>(); | ||||
| var reduced_inputs_shapes = inputs.Select(x => x.shape).ToArray(); | var reduced_inputs_shapes = inputs.Select(x => x.shape).ToArray(); | ||||
| @@ -4,6 +4,7 @@ using System.Text; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
| { | { | ||||
| @@ -14,7 +15,7 @@ namespace Tensorflow.Keras.Layers | |||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| // output_shape = input_shape.dims[1^]; | // output_shape = input_shape.dims[1^]; | ||||
| _buildInputShape = input_shape; | _buildInputShape = input_shape; | ||||
| @@ -19,6 +19,7 @@ using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -53,9 +54,10 @@ namespace Tensorflow.Keras.Layers | |||||
| axis = args.Axis.dims.Select(x => (int)x).ToArray(); | axis = args.Axis.dims.Select(x => (int)x).ToArray(); | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| var ndims = input_shape.ndim; | |||||
| var single_shape = input_shape.ToSingleShape(); | |||||
| var ndims = single_shape.ndim; | |||||
| foreach (var (idx, x) in enumerate(axis)) | foreach (var (idx, x) in enumerate(axis)) | ||||
| if (x < 0) | if (x < 0) | ||||
| args.Axis.dims[idx] = axis[idx] = ndims + x; | args.Axis.dims[idx] = axis[idx] = ndims + x; | ||||
| @@ -74,7 +76,7 @@ namespace Tensorflow.Keras.Layers | |||||
| var axis_to_dim = new Dictionary<int, int>(); | var axis_to_dim = new Dictionary<int, int>(); | ||||
| foreach (var x in axis) | foreach (var x in axis) | ||||
| axis_to_dim[x] = (int)input_shape[x]; | |||||
| axis_to_dim[x] = (int)single_shape[x]; | |||||
| inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim); | inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim); | ||||
| var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; | var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; | ||||
| @@ -19,6 +19,7 @@ using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -49,16 +50,17 @@ namespace Tensorflow.Keras.Layers | |||||
| axis = args.Axis.axis; | axis = args.Axis.axis; | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| var ndims = input_shape.ndim; | |||||
| var single_shape = input_shape.ToSingleShape(); | |||||
| var ndims = single_shape.ndim; | |||||
| foreach (var (idx, x) in enumerate(axis)) | foreach (var (idx, x) in enumerate(axis)) | ||||
| if (x < 0) | if (x < 0) | ||||
| axis[idx] = ndims + x; | axis[idx] = ndims + x; | ||||
| var axis_to_dim = new Dictionary<int, int>(); | var axis_to_dim = new Dictionary<int, int>(); | ||||
| foreach (var x in axis) | foreach (var x in axis) | ||||
| axis_to_dim[x] = (int)input_shape[x]; | |||||
| axis_to_dim[x] = (int)single_shape[x]; | |||||
| inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim); | inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim); | ||||
| var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; | var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; | ||||
| @@ -15,6 +15,7 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
| { | { | ||||
| @@ -45,10 +46,11 @@ namespace Tensorflow.Keras.Layers | |||||
| input_variance = args.Variance; | input_variance = args.Variance; | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| base.build(input_shape); | base.build(input_shape); | ||||
| var ndim = input_shape.ndim; | |||||
| var single_shape = input_shape.ToSingleShape(); | |||||
| var ndim = single_shape.ndim; | |||||
| foreach (var (idx, x) in enumerate(axis)) | foreach (var (idx, x) in enumerate(axis)) | ||||
| if (x < 0) | if (x < 0) | ||||
| axis[idx] = ndim + x; | axis[idx] = ndim + x; | ||||
| @@ -57,8 +59,8 @@ namespace Tensorflow.Keras.Layers | |||||
| _reduce_axis = range(ndim).Where(d => !_keep_axis.Contains(d)).ToArray(); | _reduce_axis = range(ndim).Where(d => !_keep_axis.Contains(d)).ToArray(); | ||||
| var _reduce_axis_mask = range(ndim).Select(d => _keep_axis.Contains(d) ? 0 : 1).ToArray(); | var _reduce_axis_mask = range(ndim).Select(d => _keep_axis.Contains(d) ? 0 : 1).ToArray(); | ||||
| // Broadcast any reduced axes. | // Broadcast any reduced axes. | ||||
| _broadcast_shape = new Shape(range(ndim).Select(d => _keep_axis.Contains(d) ? input_shape.dims[d] : 1).ToArray()); | |||||
| var mean_and_var_shape = _keep_axis.Select(d => input_shape.dims[d]).ToArray(); | |||||
| _broadcast_shape = new Shape(range(ndim).Select(d => _keep_axis.Contains(d) ? single_shape.dims[d] : 1).ToArray()); | |||||
| var mean_and_var_shape = _keep_axis.Select(d => single_shape.dims[d]).ToArray(); | |||||
| var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; | var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; | ||||
| var param_shape = input_shape; | var param_shape = input_shape; | ||||
| @@ -77,8 +77,8 @@ namespace Tensorflow.Keras.Layers | |||||
| { | { | ||||
| var data_shape = data.shape; | var data_shape = data.shape; | ||||
| var data_shape_nones = Enumerable.Range(0, data.ndim).Select(x => -1).ToArray(); | var data_shape_nones = Enumerable.Range(0, data.ndim).Select(x => -1).ToArray(); | ||||
| _args.BatchInputShape = BatchInputShape ?? new Shape(data_shape_nones); | |||||
| build(data_shape); | |||||
| _args.BatchInputShape = BatchInputShape ?? new Saving.KerasShapesWrapper(new Shape(data_shape_nones)); | |||||
| build(new Saving.KerasShapesWrapper(data_shape)); | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| } | } | ||||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
| @@ -35,12 +36,12 @@ namespace Tensorflow.Keras.Layers | |||||
| var shape = data.output_shapes[0]; | var shape = data.output_shapes[0]; | ||||
| if (shape.ndim == 1) | if (shape.ndim == 1) | ||||
| data = data.map(tensor => array_ops.expand_dims(tensor, -1)); | data = data.map(tensor => array_ops.expand_dims(tensor, -1)); | ||||
| build(data.variant_tensor.shape); | |||||
| build(new KerasShapesWrapper(data.variant_tensor.shape)); | |||||
| var preprocessed_inputs = data.map(_preprocess); | var preprocessed_inputs = data.map(_preprocess); | ||||
| _index_lookup_layer.adapt(preprocessed_inputs); | _index_lookup_layer.adapt(preprocessed_inputs); | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| base.build(input_shape); | base.build(input_shape); | ||||
| } | } | ||||
| @@ -1,5 +1,6 @@ | |||||
| using Tensorflow.Keras.ArgsDefinition.Reshaping; | using Tensorflow.Keras.ArgsDefinition.Reshaping; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.Layers.Reshaping | namespace Tensorflow.Keras.Layers.Reshaping | ||||
| { | { | ||||
| @@ -11,7 +12,7 @@ namespace Tensorflow.Keras.Layers.Reshaping | |||||
| this.args = args; | this.args = args; | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| if (args.cropping.rank != 1) | if (args.cropping.rank != 1) | ||||
| { | { | ||||
| @@ -1,5 +1,6 @@ | |||||
| using Tensorflow.Keras.ArgsDefinition.Reshaping; | using Tensorflow.Keras.ArgsDefinition.Reshaping; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.Layers.Reshaping | namespace Tensorflow.Keras.Layers.Reshaping | ||||
| { | { | ||||
| @@ -15,7 +16,7 @@ namespace Tensorflow.Keras.Layers.Reshaping | |||||
| { | { | ||||
| this.args = args; | this.args = args; | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| built = true; | built = true; | ||||
| _buildInputShape = input_shape; | _buildInputShape = input_shape; | ||||
| @@ -1,5 +1,6 @@ | |||||
| using Tensorflow.Keras.ArgsDefinition.Reshaping; | using Tensorflow.Keras.ArgsDefinition.Reshaping; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.Layers.Reshaping | namespace Tensorflow.Keras.Layers.Reshaping | ||||
| { | { | ||||
| @@ -14,7 +15,7 @@ namespace Tensorflow.Keras.Layers.Reshaping | |||||
| this.args = args; | this.args = args; | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| built = true; | built = true; | ||||
| _buildInputShape = input_shape; | _buildInputShape = input_shape; | ||||
| @@ -5,6 +5,7 @@ using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.Layers { | namespace Tensorflow.Keras.Layers { | ||||
| public class Permute : Layer | public class Permute : Layer | ||||
| @@ -14,14 +15,15 @@ namespace Tensorflow.Keras.Layers { | |||||
| { | { | ||||
| this.dims = args.dims; | this.dims = args.dims; | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| var rank = input_shape.rank; | |||||
| var single_shape = input_shape.ToSingleShape(); | |||||
| var rank = single_shape.rank; | |||||
| if (dims.Length != rank - 1) | if (dims.Length != rank - 1) | ||||
| { | { | ||||
| throw new ValueError("Dimensions must match."); | throw new ValueError("Dimensions must match."); | ||||
| } | } | ||||
| permute = new int[input_shape.rank]; | |||||
| permute = new int[single_shape.rank]; | |||||
| dims.CopyTo(permute, 1); | dims.CopyTo(permute, 1); | ||||
| built = true; | built = true; | ||||
| _buildInputShape = input_shape; | _buildInputShape = input_shape; | ||||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | using Tensorflow.Keras.ArgsDefinition.Rnn; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| // from tensorflow.python.distribute import distribution_strategy_context as ds_context; | // from tensorflow.python.distribute import distribution_strategy_context as ds_context; | ||||
| namespace Tensorflow.Keras.Layers.Rnn | namespace Tensorflow.Keras.Layers.Rnn | ||||
| @@ -36,7 +37,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| //} | //} | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| if (!cell.Built) | if (!cell.Built) | ||||
| { | { | ||||
| @@ -1,5 +1,6 @@ | |||||
| using System.Data; | using System.Data; | ||||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | using Tensorflow.Keras.ArgsDefinition.Rnn; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using Tensorflow.Operations.Activation; | using Tensorflow.Operations.Activation; | ||||
| using static HDF.PInvoke.H5Z; | using static HDF.PInvoke.H5Z; | ||||
| using static Tensorflow.ApiDef.Types; | using static Tensorflow.ApiDef.Types; | ||||
| @@ -14,12 +15,13 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| this.args = args; | this.args = args; | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| var input_dim = input_shape[-1]; | |||||
| var single_shape = input_shape.ToSingleShape(); | |||||
| var input_dim = single_shape[-1]; | |||||
| _buildInputShape = input_shape; | _buildInputShape = input_shape; | ||||
| kernel = add_weight("kernel", (input_shape[-1], args.Units), | |||||
| kernel = add_weight("kernel", (single_shape[-1], args.Units), | |||||
| initializer: args.KernelInitializer | initializer: args.KernelInitializer | ||||
| //regularizer = self.kernel_regularizer, | //regularizer = self.kernel_regularizer, | ||||
| //constraint = self.kernel_constraint, | //constraint = self.kernel_constraint, | ||||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | using Tensorflow.Keras.ArgsDefinition.Rnn; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.Layers.Rnn | namespace Tensorflow.Keras.Layers.Rnn | ||||
| { | { | ||||
| @@ -18,11 +19,12 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| this.args = args; | this.args = args; | ||||
| } | } | ||||
| public override void build(Shape input_shape) | |||||
| public override void build(KerasShapesWrapper input_shape) | |||||
| { | { | ||||
| var input_dim = input_shape[-1]; | |||||
| var single_shape = input_shape.ToSingleShape(); | |||||
| var input_dim = single_shape[-1]; | |||||
| kernel = add_weight("kernel", (input_shape[-1], args.Units), | |||||
| kernel = add_weight("kernel", (single_shape[-1], args.Units), | |||||
| initializer: args.KernelInitializer | initializer: args.KernelInitializer | ||||
| ); | ); | ||||
| @@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Models | |||||
| { | { | ||||
| public class ModelsApi: IModelsApi | public class ModelsApi: IModelsApi | ||||
| { | { | ||||
| public Functional from_config(ModelConfig config) | |||||
| public Functional from_config(FunctionalConfig config) | |||||
| => Functional.from_config(config); | => Functional.from_config(config); | ||||
| public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null) | public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null) | ||||
| @@ -22,16 +22,19 @@ namespace Tensorflow.Keras.Saving | |||||
| public int SharedObjectId { get; set; } | public int SharedObjectId { get; set; } | ||||
| [JsonProperty("must_restore_from_config")] | [JsonProperty("must_restore_from_config")] | ||||
| public bool MustRestoreFromConfig { get; set; } | public bool MustRestoreFromConfig { get; set; } | ||||
| [JsonProperty("config")] | |||||
| public JObject Config { get; set; } | public JObject Config { get; set; } | ||||
| [JsonProperty("build_input_shape")] | [JsonProperty("build_input_shape")] | ||||
| public TensorShapeConfig BuildInputShape { get; set; } | |||||
| public KerasShapesWrapper BuildInputShape { get; set; } | |||||
| [JsonProperty("batch_input_shape")] | [JsonProperty("batch_input_shape")] | ||||
| public TensorShapeConfig BatchInputShape { get; set; } | |||||
| public KerasShapesWrapper BatchInputShape { get; set; } | |||||
| [JsonProperty("activity_regularizer")] | [JsonProperty("activity_regularizer")] | ||||
| public IRegularizer ActivityRegularizer { get; set; } | public IRegularizer ActivityRegularizer { get; set; } | ||||
| [JsonProperty("input_spec")] | [JsonProperty("input_spec")] | ||||
| public JToken InputSpec { get; set; } | public JToken InputSpec { get; set; } | ||||
| [JsonProperty("stateful")] | [JsonProperty("stateful")] | ||||
| public bool? Stateful { get; set; } | public bool? Stateful { get; set; } | ||||
| [JsonProperty("model_config")] | |||||
| public KerasModelConfig? ModelConfig { get; set; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,16 @@ | |||||
| using Newtonsoft.Json; | |||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Saving | |||||
| { | |||||
| public class KerasModelConfig | |||||
| { | |||||
| [JsonProperty("class_name")] | |||||
| public string ClassName { get; set; } | |||||
| [JsonProperty("config")] | |||||
| public JObject Config { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -8,6 +8,7 @@ using System.Diagnostics; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Reflection; | using System.Reflection; | ||||
| using System.Text.RegularExpressions; | using System.Text.RegularExpressions; | ||||
| using Tensorflow.Extensions; | |||||
| using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| @@ -356,7 +357,7 @@ namespace Tensorflow.Keras.Saving | |||||
| var (obj, setter) = _revive_from_config(identifier, metadata, node_id); | var (obj, setter) = _revive_from_config(identifier, metadata, node_id); | ||||
| if (obj is null) | if (obj is null) | ||||
| { | { | ||||
| (obj, setter) = _revive_custom_object(identifier, metadata); | |||||
| (obj, setter) = revive_custom_object(identifier, metadata); | |||||
| } | } | ||||
| if(obj is null) | if(obj is null) | ||||
| { | { | ||||
| @@ -398,7 +399,7 @@ namespace Tensorflow.Keras.Saving | |||||
| return (obj, setter); | return (obj, setter); | ||||
| } | } | ||||
| private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata) | |||||
| private (Trackable, Action<object, object, object>) revive_custom_object(string identifier, KerasMetaData metadata) | |||||
| { | { | ||||
| if(identifier == SavedModel.Constants.LAYER_IDENTIFIER) | if(identifier == SavedModel.Constants.LAYER_IDENTIFIER) | ||||
| { | { | ||||
| @@ -437,7 +438,7 @@ namespace Tensorflow.Keras.Saving | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| model = new Functional(new Tensors(), new Tensors(), config["name"].ToObject<string>()); | |||||
| model = new Functional(new Tensors(), new Tensors(), config.TryGetOrReturnNull<string>("name")); | |||||
| } | } | ||||
| // Record this model and its layers. This will later be used to reconstruct | // Record this model and its layers. This will later be used to reconstruct | ||||
| @@ -619,7 +620,7 @@ namespace Tensorflow.Keras.Saving | |||||
| } | } | ||||
| } | } | ||||
| private bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape) | |||||
| private bool _try_build_layer(Layer obj, int node_id, KerasShapesWrapper build_input_shape) | |||||
| { | { | ||||
| if (obj.Built) | if (obj.Built) | ||||
| return true; | return true; | ||||
| @@ -679,10 +680,10 @@ namespace Tensorflow.Keras.Saving | |||||
| return inputs; | return inputs; | ||||
| } | } | ||||
| private Shape _infer_input_shapes(int layer_node_id) | |||||
| private KerasShapesWrapper _infer_input_shapes(int layer_node_id) | |||||
| { | { | ||||
| var inputs = _infer_inputs(layer_node_id); | var inputs = _infer_inputs(layer_node_id); | ||||
| return nest.map_structure(x => x.shape, inputs); | |||||
| return new KerasShapesWrapper(nest.map_structure(x => x.shape, inputs)); | |||||
| } | } | ||||
| private int? _search_for_child_node(int parent_id, IEnumerable<string> path_to_child) | private int? _search_for_child_node(int parent_id, IEnumerable<string> path_to_child) | ||||
| @@ -173,6 +173,11 @@ namespace Tensorflow.Keras.Utils | |||||
| obj is not Type; | obj is not Type; | ||||
| } | } | ||||
| public static Tensor generate_placeholders_from_shape(Shape shape) | |||||
| { | |||||
| return array_ops.placeholder(keras.backend.floatx(), shape); | |||||
| } | |||||
| // recusive | // recusive | ||||
| static bool uses_keras_history(Tensor op_input) | static bool uses_keras_history(Tensor op_input) | ||||
| { | { | ||||
| @@ -102,9 +102,9 @@ namespace Tensorflow.Keras.Utils | |||||
| return args as LayerArgs; | return args as LayerArgs; | ||||
| } | } | ||||
| public static ModelConfig deserialize_model_config(JToken json) | |||||
| public static FunctionalConfig deserialize_model_config(JToken json) | |||||
| { | { | ||||
| ModelConfig config = new ModelConfig(); | |||||
| FunctionalConfig config = new FunctionalConfig(); | |||||
| config.Name = json["name"].ToObject<string>(); | config.Name = json["name"].ToObject<string>(); | ||||
| config.Layers = new List<LayerConfig>(); | config.Layers = new List<LayerConfig>(); | ||||
| var layersToken = json["layers"]; | var layersToken = json["layers"]; | ||||