| @@ -12,5 +12,14 @@ namespace Tensorflow.Keras | |||||
| [JsonProperty("config")] | [JsonProperty("config")] | ||||
| IDictionary<string, object> Config { get; } | IDictionary<string, object> Config { get; } | ||||
| Tensor Apply(RegularizerArgs args); | Tensor Apply(RegularizerArgs args); | ||||
| } | |||||
| } | |||||
| public interface IRegularizerApi | |||||
| { | |||||
| IRegularizer GetRegularizerFromName(string name); | |||||
| IRegularizer L1 { get; } | |||||
| IRegularizer L2 { get; } | |||||
| IRegularizer L1L2 { get; } | |||||
| } | |||||
| } | } | ||||
| @@ -9,7 +9,7 @@ namespace Tensorflow.Operations.Regularizers | |||||
| float _l1; | float _l1; | ||||
| private readonly Dictionary<string, object> _config; | private readonly Dictionary<string, object> _config; | ||||
| public string ClassName => "L2"; | |||||
| public string ClassName => "L1"; | |||||
| public virtual IDictionary<string, object> Config => _config; | public virtual IDictionary<string, object> Config => _config; | ||||
| public L1(float l1 = 0.01f) | public L1(float l1 = 0.01f) | ||||
| @@ -1,17 +1,51 @@ | |||||
| namespace Tensorflow.Keras | |||||
| using Tensorflow.Operations.Regularizers; | |||||
| namespace Tensorflow.Keras | |||||
| { | { | ||||
| public class Regularizers | |||||
| public class Regularizers: IRegularizerApi | |||||
| { | { | ||||
| private static Dictionary<string, IRegularizer> _nameActivationMap; | |||||
| public IRegularizer l1(float l1 = 0.01f) | public IRegularizer l1(float l1 = 0.01f) | ||||
| => new Tensorflow.Operations.Regularizers.L1(l1); | |||||
| => new L1(l1); | |||||
| public IRegularizer l2(float l2 = 0.01f) | public IRegularizer l2(float l2 = 0.01f) | ||||
| => new Tensorflow.Operations.Regularizers.L2(l2); | |||||
| => new L2(l2); | |||||
| //From TF source | //From TF source | ||||
| //# The default value for l1 and l2 are different from the value in l1_l2 | //# The default value for l1 and l2 are different from the value in l1_l2 | ||||
| //# for backward compatibility reason. Eg, L1L2(l2=0.1) will only have l2 | //# for backward compatibility reason. Eg, L1L2(l2=0.1) will only have l2 | ||||
| //# and no l1 penalty. | //# and no l1 penalty. | ||||
| public IRegularizer l1l2(float l1 = 0.00f, float l2 = 0.00f) | public IRegularizer l1l2(float l1 = 0.00f, float l2 = 0.00f) | ||||
| => new Tensorflow.Operations.Regularizers.L1L2(l1, l2); | |||||
| => new L1L2(l1, l2); | |||||
| static Regularizers() | |||||
| { | |||||
| _nameActivationMap = new Dictionary<string, IRegularizer>(); | |||||
| _nameActivationMap["L1"] = new L1(); | |||||
| _nameActivationMap["L1"] = new L2(); | |||||
| _nameActivationMap["L1"] = new L1L2(); | |||||
| } | |||||
| public IRegularizer L1 => l1(); | |||||
| public IRegularizer L2 => l2(); | |||||
| public IRegularizer L1L2 => l1l2(); | |||||
| public IRegularizer GetRegularizerFromName(string name) | |||||
| { | |||||
| if (name == null) | |||||
| { | |||||
| throw new Exception($"Regularizer name cannot be null"); | |||||
| } | |||||
| if (!_nameActivationMap.TryGetValue(name, out var res)) | |||||
| { | |||||
| throw new Exception($"Regularizer {name} not found"); | |||||
| } | |||||
| else | |||||
| { | |||||
| return res; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,6 +1,7 @@ | |||||
| using Microsoft.VisualStudio.TestPlatform.Utilities; | using Microsoft.VisualStudio.TestPlatform.Utilities; | ||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Newtonsoft.Json.Linq; | using Newtonsoft.Json.Linq; | ||||
| using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Xml.Linq; | using System.Xml.Linq; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| @@ -129,6 +130,53 @@ public class ModelLoadTest | |||||
| } | } | ||||
| [TestMethod] | |||||
| public void BiasRegularizerSaveAndLoad() | |||||
| { | |||||
| var savemodel = keras.Sequential(new List<ILayer>() | |||||
| { | |||||
| tf.keras.layers.InputLayer((227, 227, 3)), | |||||
| tf.keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"), | |||||
| tf.keras.layers.BatchNormalization(), | |||||
| tf.keras.layers.MaxPooling2D((3, 3), strides:(2, 2)), | |||||
| tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L1L2), | |||||
| tf.keras.layers.BatchNormalization(), | |||||
| tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L2), | |||||
| tf.keras.layers.BatchNormalization(), | |||||
| tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L1), | |||||
| tf.keras.layers.BatchNormalization(), | |||||
| tf.keras.layers.MaxPooling2D((3, 3), (2, 2)), | |||||
| tf.keras.layers.Flatten(), | |||||
| tf.keras.layers.Dense(1000, activation: "linear"), | |||||
| tf.keras.layers.Softmax(1) | |||||
| }); | |||||
| savemodel.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" }); | |||||
| var num_epochs = 1; | |||||
| var batch_size = 8; | |||||
| var trainDataset = new RandomDataSet(new Shape(227, 227, 3), 16); | |||||
| savemodel.fit(trainDataset.Data, trainDataset.Labels, batch_size, num_epochs); | |||||
| savemodel.save(@"./bias_regularizer_save_and_load", save_format: "tf"); | |||||
| var loadModel = tf.keras.models.load_model(@"./bias_regularizer_save_and_load"); | |||||
| loadModel.summary(); | |||||
| loadModel.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" }); | |||||
| var fitDataset = new RandomDataSet(new Shape(227, 227, 3), 16); | |||||
| loadModel.fit(fitDataset.Data, fitDataset.Labels, batch_size, num_epochs); | |||||
| } | |||||
| [TestMethod] | [TestMethod] | ||||
| public void CreateConcatenateModelSaveAndLoad() | public void CreateConcatenateModelSaveAndLoad() | ||||