| @@ -2,7 +2,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Eager | |||||
| namespace Tensorflow | |||||
| { | { | ||||
| public class Context | public class Context | ||||
| { | { | ||||
| @@ -152,6 +152,11 @@ namespace Tensorflow | |||||
| return false; | return false; | ||||
| } | } | ||||
| public string get_name_scope() | |||||
| { | |||||
| return _name_stack; | |||||
| } | |||||
| public string name_scope(string name) | public string name_scope(string name) | ||||
| { | { | ||||
| string new_stack = ""; | string new_stack = ""; | ||||
| @@ -10,7 +10,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class OpDefLibrary | public class OpDefLibrary | ||||
| { | { | ||||
| public unsafe Operation _apply_op_helper(string op_type_name, string name = "", Dictionary<string, object> keywords = null) | |||||
| public Operation _apply_op_helper(string op_type_name, string name = "", Dictionary<string, object> keywords = null) | |||||
| { | { | ||||
| var g = ops.get_default_graph(); | var g = ops.get_default_graph(); | ||||
| var op_def = g.GetOpDef(op_type_name); | var op_def = g.GetOpDef(op_type_name); | ||||
| @@ -8,22 +8,62 @@ namespace Tensorflow | |||||
| { | { | ||||
| public bool _in_graph_mode = true; | public bool _in_graph_mode = true; | ||||
| public Tensor _initial_value; | public Tensor _initial_value; | ||||
| public string _graph_key; | |||||
| public bool _trainable; | |||||
| public Tensor _variable; | |||||
| public RefVariable(object initial_value, | |||||
| public RefVariable(object initial_value, | |||||
| bool trainable = true, | |||||
| List<string> collections = null, | |||||
| bool validate_shape = true, | |||||
| string caching_device = "", | |||||
| string name = "", | string name = "", | ||||
| TF_DataType trainable = TF_DataType.DtInvalid, | |||||
| bool validate_shape = true) : | |||||
| base(initial_value, name, trainable, validate_shape) | |||||
| TF_DataType dtype = TF_DataType.DtInvalid) : | |||||
| base(initial_value, trainable, collections, validate_shape, caching_device, name, dtype) | |||||
| { | { | ||||
| _init_from_args(initial_value, name, trainable); | |||||
| _init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype); | |||||
| } | } | ||||
| private void _init_from_args(object initial_value, | private void _init_from_args(object initial_value, | ||||
| bool trainable = true, | |||||
| List<string> collections = null, | |||||
| bool validate_shape = true, | |||||
| string caching_device = "", | |||||
| string name = "", | string name = "", | ||||
| TF_DataType trainable = TF_DataType.DtInvalid) | |||||
| TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| { | { | ||||
| name = ops.name_scope("", "Variable", initial_value); | |||||
| _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); | |||||
| if (initial_value is null) | |||||
| throw new ValueError("initial_value must be specified."); | |||||
| var init_from_fn = false; | |||||
| if(collections == null) | |||||
| { | |||||
| collections = new List<string> { ops.GraphKeys.GLOBAL_VARIABLES }; | |||||
| } | |||||
| // Store the graph key so optimizers know how to only retrieve variables from | |||||
| // this graph. | |||||
| _graph_key = ops.get_default_graph()._graph_key; | |||||
| _trainable = trainable; | |||||
| if (!collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES)) | |||||
| collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES); | |||||
| ops.init_scope(); | |||||
| name = new ops.name_scope(name, "Variable", init_from_fn ? new List<object>() : new List<object> { initial_value }); | |||||
| if (init_from_fn) | |||||
| { | |||||
| } | |||||
| else | |||||
| { | |||||
| _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); | |||||
| } | |||||
| var shape = _initial_value.shape; | |||||
| dtype = _initial_value.dtype; | |||||
| _variable = gen_state_ops.variable_v2(shape, dtype, name); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -16,7 +16,13 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public class VariableV1 | public class VariableV1 | ||||
| { | { | ||||
| public VariableV1(object initial_value, string name = "", TF_DataType trainable = TF_DataType.DtInvalid, bool validate_shape = true) | |||||
| public VariableV1(object initial_value, | |||||
| bool trainable = true, | |||||
| List<string> collections = null, | |||||
| bool validate_shape = true, | |||||
| string caching_device = "", | |||||
| string name = "", | |||||
| TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| { | { | ||||
| } | } | ||||
| @@ -0,0 +1,35 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class gen_state_ops | |||||
| { | |||||
| public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||||
| /// <summary> | |||||
| /// Holds state in the form of a tensor that persists across steps. | |||||
| /// Outputs a ref to the tensor state so it may be read or modified. | |||||
| /// </summary> | |||||
| /// <param name="shape">The shape of the variable tensor.</param> | |||||
| /// <param name="dtype">The type of elements in the variable tensor.</param> | |||||
| /// <param name="name"></param> | |||||
| /// <param name="container"></param> | |||||
| /// <param name="shared_name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor variable_v2(long[] shape, TF_DataType dtype, string name = "", string container = "", string shared_name = "") | |||||
| { | |||||
| var keywords = new Dictionary<string, object>(); | |||||
| keywords.Add("dtype", dtype); | |||||
| keywords.Add("shape", shape); | |||||
| var _op = _op_def_lib._apply_op_helper("VariableV2", name: name, keywords: keywords); | |||||
| var _result = _op.outputs; | |||||
| var _inputs_flat = _op.inputs; | |||||
| return new Tensor(_op, 0, dtype); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -26,7 +26,9 @@ namespace Tensorflow | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| return new RefVariable(initial_value); | |||||
| return new RefVariable(initial_value, | |||||
| name: name, | |||||
| dtype: dtype); | |||||
| } | } | ||||
| } | } | ||||
| @@ -12,7 +12,7 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static object trainable_variables() | public static object trainable_variables() | ||||
| { | { | ||||
| return ops.get_collection(ops.GraphKey.TRAINABLE_VARIABLES); | |||||
| return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -15,12 +15,18 @@ namespace Tensorflow | |||||
| /// specified, but it is also possible to pass an explicit list of | /// specified, but it is also possible to pass an explicit list of | ||||
| /// variables. | /// variables. | ||||
| /// </summary> | /// </summary> | ||||
| public static class GraphKey | |||||
| public static class GraphKeys | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// the subset of `Variable` objects that will be trained by an optimizer. | /// the subset of `Variable` objects that will be trained by an optimizer. | ||||
| /// </summary> | /// </summary> | ||||
| public static string TRAINABLE_VARIABLES = "trainable_variables"; | public static string TRAINABLE_VARIABLES = "trainable_variables"; | ||||
| /// <summary> | |||||
| /// Key to collect Variable objects that are global (shared across machines). | |||||
| /// Default collection for all variables, except local ones. | |||||
| /// </summary> | |||||
| public static string GLOBAL_VARIABLES = "variables"; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,44 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class ops | |||||
| { | |||||
| public class name_scope | |||||
| { | |||||
| public string _name; | |||||
| public string _default_name; | |||||
| public object _values; | |||||
| public Context _ctx; | |||||
| public string _name_scope; | |||||
| public name_scope(string name, string default_name, List<object> values) | |||||
| { | |||||
| _name = name; | |||||
| _default_name = default_name; | |||||
| _values = values; | |||||
| _ctx = new Context(); | |||||
| _name_scope = __enter__(); | |||||
| } | |||||
| public string __enter__() | |||||
| { | |||||
| if (String.IsNullOrEmpty(_name)) | |||||
| { | |||||
| _name = _default_name; | |||||
| } | |||||
| var g = get_default_graph(); | |||||
| return g.name_scope(_name); | |||||
| } | |||||
| public static implicit operator string(name_scope ns) | |||||
| { | |||||
| return ns._name_scope; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -97,20 +97,6 @@ namespace Tensorflow | |||||
| return node_def; | return node_def; | ||||
| } | } | ||||
| public static string name_scope(string name, string default_name = "", object values = null) | |||||
| { | |||||
| string _name = ""; | |||||
| if (String.IsNullOrEmpty(name)) | |||||
| { | |||||
| _name = default_name; | |||||
| } | |||||
| var g = get_default_graph(); | |||||
| var _name_scope = g.name_scope(_name); | |||||
| return _name_scope; | |||||
| } | |||||
| public static string _name_from_scope_name(string name) | public static string _name_from_scope_name(string name) | ||||
| { | { | ||||
| if (name.EndsWith("/")) | if (name.EndsWith("/")) | ||||
| @@ -123,6 +109,27 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| /// <summary> | |||||
| /// A context manager that lifts ops out of control-flow scopes and function-building graphs. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public static void init_scope() | |||||
| { | |||||
| // Retrieve the active name scope: entering an `init_scope` preserves | |||||
| // the name scope of the current context. | |||||
| var default_graph = get_default_graph(); | |||||
| var scope = default_graph.get_name_scope(); | |||||
| if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/")) | |||||
| // Names that end with trailing slashes are treated by `name_scope` as | |||||
| // absolute. | |||||
| scope += "/"; | |||||
| // inner_device_stack = default_graph._device_function_stack | |||||
| // var outer_context = default_graph.as_default; | |||||
| var outer_graph = get_default_graph(); | |||||
| // outer_device_stack = None | |||||
| } | |||||
| public static int uid() | public static int uid() | ||||
| { | { | ||||
| return 1; | return 1; | ||||
| @@ -1,10 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | using System.Text; | ||||
| using TF_DataType = Tensorflow.DataType; | |||||
| using attr_value_pb2 = Tensorflow; | |||||
| using Tensorflow.Eager; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||