| @@ -2,7 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Eager | |||
| namespace Tensorflow | |||
| { | |||
| public class Context | |||
| { | |||
| @@ -152,6 +152,11 @@ namespace Tensorflow | |||
| return false; | |||
| } | |||
| public string get_name_scope() | |||
| { | |||
| return _name_stack; | |||
| } | |||
| public string name_scope(string name) | |||
| { | |||
| string new_stack = ""; | |||
| @@ -10,7 +10,7 @@ namespace Tensorflow | |||
| { | |||
| public class OpDefLibrary | |||
| { | |||
| public unsafe Operation _apply_op_helper(string op_type_name, string name = "", Dictionary<string, object> keywords = null) | |||
| public Operation _apply_op_helper(string op_type_name, string name = "", Dictionary<string, object> keywords = null) | |||
| { | |||
| var g = ops.get_default_graph(); | |||
| var op_def = g.GetOpDef(op_type_name); | |||
| @@ -8,22 +8,62 @@ namespace Tensorflow | |||
| { | |||
| public bool _in_graph_mode = true; | |||
| public Tensor _initial_value; | |||
| public string _graph_key; | |||
| public bool _trainable; | |||
| public Tensor _variable; | |||
| public RefVariable(object initial_value, | |||
| public RefVariable(object initial_value, | |||
| bool trainable = true, | |||
| List<string> collections = null, | |||
| bool validate_shape = true, | |||
| string caching_device = "", | |||
| string name = "", | |||
| TF_DataType trainable = TF_DataType.DtInvalid, | |||
| bool validate_shape = true) : | |||
| base(initial_value, name, trainable, validate_shape) | |||
| TF_DataType dtype = TF_DataType.DtInvalid) : | |||
| base(initial_value, trainable, collections, validate_shape, caching_device, name, dtype) | |||
| { | |||
| _init_from_args(initial_value, name, trainable); | |||
| _init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype); | |||
| } | |||
| private void _init_from_args(object initial_value, | |||
| bool trainable = true, | |||
| List<string> collections = null, | |||
| bool validate_shape = true, | |||
| string caching_device = "", | |||
| string name = "", | |||
| TF_DataType trainable = TF_DataType.DtInvalid) | |||
| TF_DataType dtype = TF_DataType.DtInvalid) | |||
| { | |||
| name = ops.name_scope("", "Variable", initial_value); | |||
| _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); | |||
| if (initial_value is null) | |||
| throw new ValueError("initial_value must be specified."); | |||
| var init_from_fn = false; | |||
| if(collections == null) | |||
| { | |||
| collections = new List<string> { ops.GraphKeys.GLOBAL_VARIABLES }; | |||
| } | |||
| // Store the graph key so optimizers know how to only retrieve variables from | |||
| // this graph. | |||
| _graph_key = ops.get_default_graph()._graph_key; | |||
| _trainable = trainable; | |||
| if (!collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES)) | |||
| collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES); | |||
| ops.init_scope(); | |||
| name = new ops.name_scope(name, "Variable", init_from_fn ? new List<object>() : new List<object> { initial_value }); | |||
| if (init_from_fn) | |||
| { | |||
| } | |||
| else | |||
| { | |||
| _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); | |||
| } | |||
| var shape = _initial_value.shape; | |||
| dtype = _initial_value.dtype; | |||
| _variable = gen_state_ops.variable_v2(shape, dtype, name); | |||
| } | |||
| } | |||
| } | |||
| @@ -16,7 +16,13 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public class VariableV1 | |||
| { | |||
| public VariableV1(object initial_value, string name = "", TF_DataType trainable = TF_DataType.DtInvalid, bool validate_shape = true) | |||
| public VariableV1(object initial_value, | |||
| bool trainable = true, | |||
| List<string> collections = null, | |||
| bool validate_shape = true, | |||
| string caching_device = "", | |||
| string name = "", | |||
| TF_DataType dtype = TF_DataType.DtInvalid) | |||
| { | |||
| } | |||
| @@ -0,0 +1,35 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class gen_state_ops | |||
| { | |||
| public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
| /// <summary> | |||
| /// Holds state in the form of a tensor that persists across steps. | |||
| /// Outputs a ref to the tensor state so it may be read or modified. | |||
| /// </summary> | |||
| /// <param name="shape">The shape of the variable tensor.</param> | |||
| /// <param name="dtype">The type of elements in the variable tensor.</param> | |||
| /// <param name="name"></param> | |||
| /// <param name="container"></param> | |||
| /// <param name="shared_name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor variable_v2(long[] shape, TF_DataType dtype, string name = "", string container = "", string shared_name = "") | |||
| { | |||
| var keywords = new Dictionary<string, object>(); | |||
| keywords.Add("dtype", dtype); | |||
| keywords.Add("shape", shape); | |||
| var _op = _op_def_lib._apply_op_helper("VariableV2", name: name, keywords: keywords); | |||
| var _result = _op.outputs; | |||
| var _inputs_flat = _op.inputs; | |||
| return new Tensor(_op, 0, dtype); | |||
| } | |||
| } | |||
| } | |||
| @@ -26,7 +26,9 @@ namespace Tensorflow | |||
| } | |||
| else | |||
| { | |||
| return new RefVariable(initial_value); | |||
| return new RefVariable(initial_value, | |||
| name: name, | |||
| dtype: dtype); | |||
| } | |||
| } | |||
| @@ -12,7 +12,7 @@ namespace Tensorflow | |||
| /// <returns></returns> | |||
| public static object trainable_variables() | |||
| { | |||
| return ops.get_collection(ops.GraphKey.TRAINABLE_VARIABLES); | |||
| return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); | |||
| } | |||
| } | |||
| } | |||
| @@ -15,12 +15,18 @@ namespace Tensorflow | |||
| /// specified, but it is also possible to pass an explicit list of | |||
| /// variables. | |||
| /// </summary> | |||
| public static class GraphKey | |||
| public static class GraphKeys | |||
| { | |||
| /// <summary> | |||
| /// the subset of `Variable` objects that will be trained by an optimizer. | |||
| /// </summary> | |||
| public static string TRAINABLE_VARIABLES = "trainable_variables"; | |||
| /// <summary> | |||
| /// Key to collect Variable objects that are global (shared across machines). | |||
| /// Default collection for all variables, except local ones. | |||
| /// </summary> | |||
| public static string GLOBAL_VARIABLES = "variables"; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,44 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class ops | |||
| { | |||
| public class name_scope | |||
| { | |||
| public string _name; | |||
| public string _default_name; | |||
| public object _values; | |||
| public Context _ctx; | |||
| public string _name_scope; | |||
| public name_scope(string name, string default_name, List<object> values) | |||
| { | |||
| _name = name; | |||
| _default_name = default_name; | |||
| _values = values; | |||
| _ctx = new Context(); | |||
| _name_scope = __enter__(); | |||
| } | |||
| public string __enter__() | |||
| { | |||
| if (String.IsNullOrEmpty(_name)) | |||
| { | |||
| _name = _default_name; | |||
| } | |||
| var g = get_default_graph(); | |||
| return g.name_scope(_name); | |||
| } | |||
| public static implicit operator string(name_scope ns) | |||
| { | |||
| return ns._name_scope; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -97,20 +97,6 @@ namespace Tensorflow | |||
| return node_def; | |||
| } | |||
| public static string name_scope(string name, string default_name = "", object values = null) | |||
| { | |||
| string _name = ""; | |||
| if (String.IsNullOrEmpty(name)) | |||
| { | |||
| _name = default_name; | |||
| } | |||
| var g = get_default_graph(); | |||
| var _name_scope = g.name_scope(_name); | |||
| return _name_scope; | |||
| } | |||
| public static string _name_from_scope_name(string name) | |||
| { | |||
| if (name.EndsWith("/")) | |||
| @@ -123,6 +109,27 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// A context manager that lifts ops out of control-flow scopes and function-building graphs. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public static void init_scope() | |||
| { | |||
| // Retrieve the active name scope: entering an `init_scope` preserves | |||
| // the name scope of the current context. | |||
| var default_graph = get_default_graph(); | |||
| var scope = default_graph.get_name_scope(); | |||
| if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/")) | |||
| // Names that end with trailing slashes are treated by `name_scope` as | |||
| // absolute. | |||
| scope += "/"; | |||
| // inner_device_stack = default_graph._device_function_stack | |||
| // var outer_context = default_graph.as_default; | |||
| var outer_graph = get_default_graph(); | |||
| // outer_device_stack = None | |||
| } | |||
| public static int uid() | |||
| { | |||
| return 1; | |||
| @@ -1,10 +1,6 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| using TF_DataType = Tensorflow.DataType; | |||
| using attr_value_pb2 = Tensorflow; | |||
| using Tensorflow.Eager; | |||
| namespace Tensorflow | |||
| { | |||