Add set_weights and get_weights APIs (tags/v0.100.5-BERT-load)
| @@ -1,5 +1,6 @@ | |||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
| using Tensorflow.NumPy; | |||||
| using Tensorflow.Training; | using Tensorflow.Training; | ||||
| namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
| @@ -18,6 +19,8 @@ namespace Tensorflow.Keras | |||||
| List<IVariableV1> TrainableWeights { get; } | List<IVariableV1> TrainableWeights { get; } | ||||
| List<IVariableV1> NonTrainableWeights { get; } | List<IVariableV1> NonTrainableWeights { get; } | ||||
| List<IVariableV1> Weights { get; set; } | List<IVariableV1> Weights { get; set; } | ||||
| void set_weights(IEnumerable<NDArray> weights); | |||||
| List<NDArray> get_weights(); | |||||
| Shape OutputShape { get; } | Shape OutputShape { get; } | ||||
| Shape BatchInputShape { get; } | Shape BatchInputShape { get; } | ||||
| TensorShapeConfig BuildInputShape { get; } | TensorShapeConfig BuildInputShape { get; } | ||||
| @@ -21,6 +21,7 @@ using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | using Tensorflow.Keras.ArgsDefinition.Rnn; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
| using Tensorflow.NumPy; | |||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| @@ -71,7 +72,10 @@ namespace Tensorflow | |||||
| public List<IVariableV1> TrainableVariables => throw new NotImplementedException(); | public List<IVariableV1> TrainableVariables => throw new NotImplementedException(); | ||||
| public List<IVariableV1> TrainableWeights => throw new NotImplementedException(); | public List<IVariableV1> TrainableWeights => throw new NotImplementedException(); | ||||
| public List<IVariableV1> Weights => throw new NotImplementedException(); | |||||
| public List<IVariableV1> Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } | |||||
| public List<NDArray> get_weights() => throw new NotImplementedException(); | |||||
| public void set_weights(IEnumerable<NDArray> weights) => throw new NotImplementedException(); | |||||
| public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException(); | public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException(); | ||||
| public Shape OutputShape => throw new NotImplementedException(); | public Shape OutputShape => throw new NotImplementedException(); | ||||
| @@ -84,8 +88,6 @@ namespace Tensorflow | |||||
| protected bool built = false; | protected bool built = false; | ||||
| public bool Built => built; | public bool Built => built; | ||||
| List<IVariableV1> ILayer.Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } | |||||
| public RnnCell(bool trainable = true, | public RnnCell(bool trainable = true, | ||||
| string name = null, | string name = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| @@ -30,6 +30,9 @@ using Tensorflow.Training; | |||||
| using Tensorflow.Training.Saving.SavedModel; | using Tensorflow.Training.Saving.SavedModel; | ||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Framework; | |||||
| using Tensorflow.Sessions; | |||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| { | { | ||||
| @@ -134,6 +137,62 @@ namespace Tensorflow.Keras.Engine | |||||
| } | } | ||||
| } | } | ||||
/// <summary>
/// Overwrites the values of this layer's variables with the supplied arrays.
/// </summary>
/// <param name="weights">
/// New values, matched by position against <see cref="Weights"/>. There must be
/// exactly one array per layer weight, each shape-compatible with its target.
/// </param>
/// <exception cref="ValueError">
/// Thrown when the number of arrays differs from the number of layer weights,
/// or when any array is not shape-compatible with the weight at the same index.
/// </exception>
/// <remarks>
/// Values are only assigned when executing eagerly. In graph mode the arguments
/// are validated but nothing is assigned yet (see the TODO in the body).
/// </remarks>
public virtual void set_weights(IEnumerable<NDArray> weights)
{
    // Materialize once: the input is counted, validated and then assigned, and a
    // lazy IEnumerable would otherwise be re-enumerated on every pass.
    var new_values = weights.ToList();
    // Read the property a single time so all checks and assignments see the same list.
    var layer_weights = Weights;

    if (layer_weights.Count != new_values.Count)
    {
        throw new ValueError(
            $"You called `set_weights` on layer \"{this.name}\" " +
            $"with a weight list of length {new_values.Count}, but the layer was " +
            $"expecting {layer_weights.Count} weights.");
    }

    // Check every shape before assigning anything, so a mismatch on a later
    // weight cannot leave the layer partially updated.
    for (var i = 0; i < new_values.Count; i++)
    {
        if (!layer_weights[i].AsTensor().is_compatible_with(new_values[i]))
        {
            throw new ValueError(
                $"Layer weight shape {layer_weights[i].shape} not compatible " +
                $"with provided weight shape {new_values[i].shape}");
        }
    }

    if (tf.executing_eagerly())
    {
        for (var i = 0; i < new_values.Count; i++)
        {
            layer_weights[i].assign(new_values[i], read_value: true);
        }
    }
    else
    {
        // TODO(Wanglongzhi2001): there seems to be a bug in graph mode when defining a model,
        // so uncomment the following once it is fixed. Until then graph mode assigns nothing.
        //Tensors assign_ops = new Tensors();
        //var feed_dict = new FeedDict();
        //Graph g = tf.Graph().as_default();
        //foreach (var (this_w, v_w) in zip(Weights, weights))
        //{
        //    var tf_dtype = this_w.dtype;
        //    var placeholder_shape = v_w.shape;
        //    var assign_placeholder = tf.placeholder(tf_dtype, placeholder_shape);
        //    var assign_op = this_w.assign(assign_placeholder);
        //    assign_ops.Add(assign_op);
        //    feed_dict.Add(assign_placeholder, v_w);
        //}
        //var sess = tf.Session().as_default();
        //sess.run(assign_ops, feed_dict);
        //g.Exit();
    }
}
/// <summary>
/// Returns the current values of this layer's variables as arrays.
/// </summary>
/// <returns>
/// A newly created list containing the result of <c>numpy()</c> for each entry
/// of <see cref="Weights"/>, in the same order.
/// </returns>
public List<NDArray> get_weights()
{
    var values = new List<NDArray>();
    foreach (var variable in Weights)
    {
        values.Add(variable.numpy());
    }
    return values;
}
| protected int id; | protected int id; | ||||
| public int Id => id; | public int Id => id; | ||||
| protected string name; | protected string name; | ||||