diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index f16d54d1..1e473d75 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -19,7 +19,7 @@ namespace Tensorflow.Keras List TrainableWeights { get; } List NonTrainableWeights { get; } List Weights { get; set; } - void set_weights(List weights); + void set_weights(IEnumerable weights); List get_weights(); Shape OutputShape { get; } Shape BatchInputShape { get; } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 93e0edf0..5847e31a 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -75,7 +75,7 @@ namespace Tensorflow public List Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } public List get_weights() => throw new NotImplementedException(); - public void set_weights(List weights) => throw new NotImplementedException(); + public void set_weights(IEnumerable weights) => throw new NotImplementedException(); public List NonTrainableWeights => throw new NotImplementedException(); public Shape OutputShape => throw new NotImplementedException(); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 31ac74dc..11a0584c 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -30,6 +30,9 @@ using Tensorflow.Training; using Tensorflow.Training.Saving.SavedModel; using Tensorflow.Util; using static Tensorflow.Binding; +using Tensorflow.Framework; +using Tensorflow.Sessions; + namespace Tensorflow.Keras.Engine { @@ -134,21 +137,53 @@ namespace Tensorflow.Keras.Engine } } - public virtual void set_weights(List weights) + public virtual void set_weights(IEnumerable weights) { if (Weights.Count() != weights.Count()) throw new
ValueError( $"You called `set_weights` on layer \"{this.name}\"" + $"with a weight list of length {len(weights)}, but the layer was " + $"expecting {len(Weights)} weights."); - for (int i = 0; i < weights.Count(); i++) + + + + // Check that the shapes are compatible before assigning anything. + var weight_index = 0; + foreach(var w in weights) { - if (weights[i].shape != Weights[i].shape) + if (!Weights[weight_index].AsTensor().is_compatible_with(w)) { - throw new ValueError($"Layer weight shape {weights[i].shape} not compatible with provided weight shape {Weights[i].shape}"); + throw new ValueError($"Layer weight shape {Weights[weight_index].shape} not compatible with provided weight shape {w.shape}"); } + weight_index++; + } + + if (tf.executing_eagerly()) + { + foreach (var (this_w, v_w) in zip(Weights, weights)) + this_w.assign(v_w, read_value: true); + } + else + { + // TODO(Wanglongzhi2001): there seems to be a bug in graph mode when defining a model, so uncomment the following once it is fixed. Until then this branch assigns nothing. + + //Tensors assign_ops = new Tensors(); + //var feed_dict = new FeedDict(); + + //Graph g = tf.Graph().as_default(); + //foreach (var (this_w, v_w) in zip(Weights, weights)) + //{ + // var tf_dtype = this_w.dtype; + // var placeholder_shape = v_w.shape; + // var assign_placeholder = tf.placeholder(tf_dtype, placeholder_shape); + // var assign_op = this_w.assign(assign_placeholder); + // assign_ops.Add(assign_op); + // feed_dict.Add(assign_placeholder, v_w); + //} + //var sess = tf.Session().as_default(); + //sess.run(assign_ops, feed_dict); + + //g.Exit(); } - foreach (var (this_w, v_w) in zip(Weights, weights)) - this_w.assign(v_w, read_value: true); } public List get_weights()