diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index e50bb267..7470442b 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -11,6 +11,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Visualization", "src\TensorFlowNET.Visualization\TensorFlowNET.Visualization.csproj", "{0254BFF9-453C-4FE0-9609-3644559A79CE}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{3EEAFB06-BEF0-4261-BAAB-630EABD25290}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -33,6 +35,10 @@ Global {0254BFF9-453C-4FE0-9609-3644559A79CE}.Debug|Any CPU.Build.0 = Debug|Any CPU {0254BFF9-453C-4FE0-9609-3644559A79CE}.Release|Any CPU.ActiveCfg = Release|Any CPU {0254BFF9-453C-4FE0-9609-3644559A79CE}.Release|Any CPU.Build.0 = Release|Any CPU + {3EEAFB06-BEF0-4261-BAAB-630EABD25290}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3EEAFB06-BEF0-4261-BAAB-630EABD25290}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3EEAFB06-BEF0-4261-BAAB-630EABD25290}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3EEAFB06-BEF0-4261-BAAB-630EABD25290}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index cc86014a..8a5d929f 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -21,6 +21,7 @@ namespace Tensorflow public int _version; private int _next_id_counter; private List _unfetchable_ops = new List(); + private List _unfeedable_tensors = new List(); public string _name_stack = ""; public string _graph_key; @@ -366,6 +367,11 @@ namespace Tensorflow return _collections[name]; } + public void prevent_feeding(Tensor tensor) + { + _unfeedable_tensors.Add(tensor); + } + public void Dispose() { c_api.TF_DeleteGraph(_handle); diff --git a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs index 13f74c87..bfd96d6a 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs @@ -10,11 +10,16 @@ namespace Tensorflow.Keras.Engine public class InputSpec { public int ndim; + Dictionary axes; public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, - int? ndim = null) + int? ndim = null, + Dictionary axes = null) { this.ndim = ndim.Value; + if (axes == null) + axes = new Dictionary(); + this.axes = axes; } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index e35343d1..b6603c28 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -66,7 +66,7 @@ namespace Tensorflow.Keras.Engine } - protected virtual void add_weight(string name, + protected virtual RefVariable add_weight(string name, int[] shape, TF_DataType dtype = TF_DataType.DtInvalid, IInitializer initializer = null, @@ -82,6 +82,8 @@ namespace Tensorflow.Keras.Engine trainable: trainable.Value); backend.track_variable(variable); _trainable_weights.Add(variable); + + return variable; } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs index fdd4329a..e3c824e6 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs @@ -2,6 +2,8 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using Tensorflow.Operations; using Tensorflow.Operations.Activation; namespace Tensorflow.Keras.Layers @@ -19,6 +21,9 @@ namespace Tensorflow.Keras.Layers protected bool use_bias; protected IInitializer kernel_initializer; protected IInitializer bias_initializer; + protected RefVariable kernel; + protected RefVariable bias; + protected Convolution _convolution_op; public Conv(int rank, int filters, @@ -53,11 +58,37 @@ namespace Tensorflow.Keras.Layers int channel_axis = data_format == "channels_first" ? 1 : -1; int input_dim = input_shape.Dimensions[input_shape.NDim - 1]; var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters }; - add_weight(name: "kernel", + kernel = add_weight(name: "kernel", shape: kernel_shape, initializer: kernel_initializer, trainable: true, dtype: _dtype); + if (use_bias) + bias = add_weight(name: "bias", + shape: new int[] { filters }, + initializer: bias_initializer, + trainable: true, + dtype: _dtype); + + var axes = new Dictionary(); + axes.Add(-1, input_dim); + input_spec = new InputSpec(ndim: rank + 2, axes: axes); + + string op_padding; + if (padding == "causal") + op_padding = "valid"; + else + op_padding = padding; + + var df = conv_utils.convert_data_format(data_format, rank + 2); + _convolution_op = nn_ops.Convolution(input_shape, + kernel.shape, + op_padding.ToUpper(), + strides, + dilation_rate, + data_format: df); + + built = true; } } } diff --git a/src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs new file mode 100644 index 00000000..ef348d1b --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Utils +{ + public class conv_utils + { + public static string convert_data_format(string data_format, int ndim) + { + if (data_format == "channels_last") + if (ndim == 3) + return "NWC"; + else if (ndim == 4) + return "NHWC"; + else if (ndim == 5) + return "NDHWC"; + else + throw new ValueError($"Input rank not supported: {ndim}"); + else if (data_format == "channels_first") + if (ndim == 3) + return "NCW"; + else if (ndim == 4) + return "NCHW"; + else if (ndim == 5) + return "NCDHW"; + else + throw new ValueError($"Input rank not supported: {ndim}"); + else + throw new ValueError($"Invalid data_format: {data_format}"); + } + } +} diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index afe17dbb..510b505b 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -68,7 +68,7 @@ namespace Tensorflow.Layers throw new NotImplementedException(""); } - protected virtual void add_weight(string name, + protected virtual RefVariable add_weight(string name, int[] shape, TF_DataType dtype = TF_DataType.DtInvalid, IInitializer initializer = null, @@ -93,14 +93,14 @@ namespace Tensorflow.Layers _set_scope(); var reuse = built || (_reuse != null && _reuse.Value); - Python.with(tf.variable_scope(_scope, + return Python.with(tf.variable_scope(_scope, reuse: reuse, auxiliary_name_scope: false), scope => { _current_scope = scope; - Python.with(ops.name_scope(_name_scope()), delegate + return Python.with(ops.name_scope(_name_scope()), delegate { - base.add_weight(name, + var variable = base.add_weight(name, shape, dtype: dtype, initializer: initializer, @@ -113,6 +113,12 @@ namespace Tensorflow.Layers initializer: initializer1, trainable: trainable1); }); + + if(init_graph != null) + { + var trainable_variables = variables.trainable_variables(); + } + return variable; }); }); } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/Convolution.cs b/src/TensorFlowNET.Core/Operations/NnOps/Convolution.cs new file mode 100644 index 00000000..a0d7887b --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/Convolution.cs @@ -0,0 +1,66 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow.Operations +{ + public class Convolution + { + public TensorShape input_shape; + public TensorShape filter_shape; + public string data_format; + public int[] strides; + public string name; + public _WithSpaceToBatch conv_op; + + public Convolution(TensorShape input_shape, + TensorShape filter_shape, + string padding, + int[] strides, + int[] dilation_rate, + string name = null, + string data_format = null) + { + var num_total_dims = filter_shape.NDim; + var num_spatial_dims = num_total_dims - 2; + int input_channels_dim; + int[] spatial_dims; + if (string.IsNullOrEmpty(data_format) || !data_format.StartsWith("NC")) + { + input_channels_dim = input_shape.Dimensions[num_spatial_dims + 1]; + spatial_dims = Enumerable.Range(1, num_spatial_dims).ToArray(); + } + else + { + input_channels_dim = input_shape.Dimensions[1]; + spatial_dims = Enumerable.Range(2, num_spatial_dims).ToArray(); + } + + this.input_shape = input_shape; + this.filter_shape = filter_shape; + this.data_format = data_format; + this.strides = strides; + this.name = name; + + conv_op = new _WithSpaceToBatch( + input_shape, + dilation_rate: dilation_rate, + padding: padding, + build_op: _build_op, + filter_shape: filter_shape, + spatial_dims: spatial_dims, + data_format: data_format); + } + + public _NonAtrousConvolution _build_op(int _, string padding) + { + return new _NonAtrousConvolution(input_shape, + filter_shape: filter_shape, + padding: padding, + data_format: data_format, + strides: strides, + name: name); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs b/src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs new file mode 100644 index 00000000..38e77a4f --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs @@ -0,0 +1,56 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow.Operations +{ + public class _NonAtrousConvolution + { + public string padding; + public string name; + public int[] strides; + public string data_format; + private Func conv_op; + + public _NonAtrousConvolution(TensorShape input_shape, + TensorShape filter_shape, + string padding, + string data_format, + int[] strides, + string name) + { + this.padding = padding; + this.name = name; + var conv_dims = input_shape.NDim - 2; + if (conv_dims == 1) + { + throw new NotImplementedException("_NonAtrousConvolution conv_dims 1"); + } + else if (conv_dims == 2) + { + var list = strides.ToList(); + + if (string.IsNullOrEmpty(data_format) || data_format == "NHWC") + { + data_format = "NHWC"; + list.Insert(0, 1); + list.Add(1); + } + else if (data_format == "NCHW") + list.InsertRange(0, new int[] { 1, 1 }); + else + throw new ValueError("data_format must be \"NHWC\" or \"NCHW\"."); + + strides = list.ToArray(); + this.strides = strides; + this.data_format = data_format; + conv_op = gen_nn_ops.conv2d; + } + else if (conv_dims == 3) + { + throw new NotImplementedException("_NonAtrousConvolution conv_dims 3"); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs b/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs new file mode 100644 index 00000000..b144e95c --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs @@ -0,0 +1,55 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow.Operations +{ + public class _WithSpaceToBatch + { + private _NonAtrousConvolution call; + + public _WithSpaceToBatch(TensorShape input_shape, + int[] dilation_rate, + string padding, + Func build_op, + TensorShape filter_shape = null, + int[] spatial_dims = null, + string data_format = null) + { + var dilation_rate_tensor = ops.convert_to_tensor(dilation_rate, TF_DataType.TF_INT32, name: "dilation_rate"); + var rate_shape = dilation_rate_tensor.getShape(); + var num_spatial_dims = rate_shape.Dimensions[0]; + int starting_spatial_dim = -1; + if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC")) + starting_spatial_dim = 2; + else + starting_spatial_dim = 1; + + if (spatial_dims == null) + throw new NotImplementedException("_WithSpaceToBatch spatial_dims"); + + var orig_spatial_dims = spatial_dims; + spatial_dims = spatial_dims.OrderBy(x => x).ToArray(); + if (!Enumerable.SequenceEqual(spatial_dims, orig_spatial_dims) || spatial_dims.Any(x => x < 1)) + throw new ValueError("spatial_dims must be a montonically increasing sequence of positive integers"); + + int expected_input_rank = -1; + if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC")) + expected_input_rank = spatial_dims.Last(); + else + expected_input_rank = spatial_dims.Last() + 1; + + var const_rate = tensor_util.constant_value(dilation_rate_tensor); + var rate_or_const_rate = dilation_rate; + if(!(const_rate is null)) + { + if (const_rate.Data().Count(x => x == 1) == const_rate.size) + { + call = build_op(num_spatial_dims, padding); + return; + } + } + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs new file mode 100644 index 00000000..54d9242c --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations +{ + public class gen_nn_ops + { + public static Tensor conv2d(object parameters) + { + throw new NotImplementedException("gen_nn_op.conv2d"); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index a18b109e..a3535838 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -12,7 +12,7 @@ namespace Tensorflow { public class OpDefLibrary : Python { - public Operation _apply_op_helper(string op_type_name, string name = null, dynamic args = null) + public Operation _apply_op_helper(string op_type_name, string name = null, object args = null) { Dictionary keywords = ConvertToDict(args); var g = ops.get_default_graph(); @@ -358,25 +358,5 @@ namespace Tensorflow return false; } } - - private Dictionary ConvertToDict(dynamic dyn) - { - var dictionary = new Dictionary(); - foreach (PropertyDescriptor propertyDescriptor in TypeDescriptor.GetProperties(dyn)) - { - object obj = propertyDescriptor.GetValue(dyn); - string name = propertyDescriptor.Name; - // avoid .net keyword - switch (name) - { - case "_ref_": - name = "ref"; - break; - } - - dictionary.Add(name, obj); - } - return dictionary; - } } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 5f5d9b1c..d6fb63c1 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -164,7 +164,7 @@ namespace Tensorflow return grouped_inputs.ToArray(); } - public object get_attr(string name) + public object get_attr(string name) { AttrValue x = null; @@ -175,24 +175,17 @@ namespace Tensorflow x = AttrValue.Parser.ParseFrom(buf); } - switch (name) - { - case "T": - case "dtype": - return x.Type; - case "shape": - return x.Shape; - default: - switch (typeof(T).Name) - { - case "Boolean": - return x.B; - case "String": - return x.S; - default: - throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); - } - } + string oneof_value = x.ValueCase.ToString(); + if (string.IsNullOrEmpty(oneof_value)) + return null; + + if(oneof_value == "list") + throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); + + if (oneof_value == "type") + return x.Type; + + return x.GetType().GetProperty(oneof_value).GetValue(x); } public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 6d41e13f..0ae4c3a2 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -47,8 +47,8 @@ namespace Tensorflow var _inputs_flat = _op.inputs; var _attrs = new Dictionary(); - _attrs["dtype"] = _op.get_attr("dtype"); - _attrs["shape"] = _op.get_attr("shape"); + _attrs["dtype"] = _op.get_attr("dtype"); + _attrs["shape"] = _op.get_attr("shape"); _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs new file mode 100644 index 00000000..1c36819a --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations; + +namespace Tensorflow +{ + public class nn_ops + { + public static Convolution Convolution(TensorShape input_shape, + TensorShape filter_shape, + string padding, + int[] strides, + int[] dilation_rate, + string name = null, + string data_format = null) => new Convolution(input_shape, + filter_shape, + padding, + strides, + dilation_rate, + name: name, + data_format: data_format); + } +} diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs index b11f2889..1a7beb72 100644 --- a/src/TensorFlowNET.Core/Python.cs +++ b/src/TensorFlowNET.Core/Python.cs @@ -1,6 +1,7 @@ using NumSharp.Core; using System; using System.Collections.Generic; +using System.ComponentModel; using System.Text; namespace Tensorflow @@ -109,6 +110,26 @@ namespace Tensorflow for (int i = 0; i < values.Count; i++) yield return (i, values[i]); } + + public static Dictionary ConvertToDict(object dyn) + { + var dictionary = new Dictionary(); + foreach (PropertyDescriptor propertyDescriptor in TypeDescriptor.GetProperties(dyn)) + { + object obj = propertyDescriptor.GetValue(dyn); + string name = propertyDescriptor.Name; + // avoid .net keyword + switch (name) + { + case "_ref_": + name = "ref"; + break; + } + + dictionary.Add(name, obj); + } + return dictionary; + } } public interface IPython : IDisposable diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 9dfe8c31..23d5b0a6 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -51,4 +51,8 @@ Fixed import name scope issue. + + + + diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 5f830070..005b5df9 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -79,6 +79,11 @@ namespace Tensorflow return (int)type; } + public static Type as_numpy_dtype(this DataType type) + { + return type.as_tf_dtype().as_numpy_datatype(); + } + public static DataType as_base_dtype(this DataType type) { return (int)type > 100 ? diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index d6aaea07..ede6d495 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -17,6 +17,42 @@ namespace Tensorflow TF_DataType.TF_QUINT16, TF_DataType.TF_QINT32, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 }; + /// + /// Returns the constant value of the given tensor, if efficiently calculable. + /// + /// + /// + /// + public static NDArray constant_value(Tensor tensor, bool partial = false) + { + NDArray ret = _ConstantValue(tensor, partial); + if (!(ret is null)) + tensor.graph.prevent_feeding(tensor); + + return ret; + } + + private static NDArray _ConstantValue(Tensor tensor, bool partial) + { + if (tensor.op.type == "Const") + { + return MakeNdarray(tensor.op.get_attr("value") as TensorProto); + } + throw new NotImplementedException("_ConstantValue"); + } + + public static NDArray MakeNdarray(TensorProto tensor) + { + var shape = tensor.TensorShape.Dim.Select(x => (int)x.Size).ToArray(); + long num_elements = np.prod(shape); + var tensor_dtype = tensor.Dtype.as_numpy_dtype(); + + if (tensor.TensorContent.Length > 0) + return np.frombuffer(tensor.TensorContent.ToByteArray(), tensor_dtype) + .reshape(shape); + throw new NotImplementedException("MakeNdarray"); + } + /// /// Create a TensorProto. /// diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 05982d95..e7ee3043 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -113,7 +113,7 @@ namespace Tensorflow _graph_key = ops.get_default_graph()._graph_key; _trainable = trainable; - if (!collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES)) + if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES)) collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES); ops.init_scope(); diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index 6317733d..737d95b1 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -29,10 +29,10 @@ namespace Tensorflow var _inputs_flat = _op.inputs; var _attrs = new Dictionary(); - _attrs["dtype"] = _op.get_attr("dtype"); - _attrs["shape"] = _op.get_attr("shape"); - _attrs["container"] = _op.get_attr("container"); - _attrs["shared_name"] = _op.get_attr("shared_name"); + _attrs["dtype"] = _op.get_attr("dtype"); + _attrs["shape"] = _op.get_attr("shape"); + _attrs["container"] = _op.get_attr("container"); + _attrs["shared_name"] = _op.get_attr("shared_name"); _execute.record_gradient("VariableV2", _inputs_flat, _attrs, _result, name); @@ -58,9 +58,9 @@ namespace Tensorflow var _inputs_flat = _op.inputs; var _attrs = new Dictionary(); - _attrs["T"] = _op.get_attr("T"); - _attrs["validate_shape"] = _op.get_attr("validate_shape"); - _attrs["use_locking"] = _op.get_attr("use_locking"); + _attrs["T"] = _op.get_attr("T"); + _attrs["validate_shape"] = _op.get_attr("validate_shape"); + _attrs["use_locking"] = _op.get_attr("use_locking"); _execute.record_gradient("Assign", _inputs_flat, _attrs, _result, name); diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index 3f92edd0..1a15a6ac 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -155,6 +155,7 @@ namespace Tensorflow else { return new RefVariable(initial_value, + trainable: trainable.Value, name: name, dtype: dtype); } diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index b0dd5a65..9e83fdae 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -12,6 +12,7 @@ + diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index bc298eb7..c9b02eab 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -24,6 +24,7 @@ +