| @@ -24,6 +24,8 @@ namespace Tensorflow | |||||
| private int _next_id_counter; | private int _next_id_counter; | ||||
| private List<String> _unfetchable_ops = new List<string>(); | private List<String> _unfetchable_ops = new List<string>(); | ||||
| private string _name_stack; | |||||
| public Graph(IntPtr graph) | public Graph(IntPtr graph) | ||||
| { | { | ||||
| this._c_graph = graph; | this._c_graph = graph; | ||||
| @@ -126,8 +128,31 @@ namespace Tensorflow | |||||
| return false; | return false; | ||||
| } | } | ||||
| public string name_scope(string name) | |||||
| { | |||||
| string new_stack = ""; | |||||
| if (name.EndsWith("/")) | |||||
| { | |||||
| new_stack = ops._name_from_scope_name(name); | |||||
| } | |||||
| else | |||||
| { | |||||
| new_stack = unique_name(name); | |||||
| } | |||||
| _name_stack = new_stack; | |||||
| return String.IsNullOrEmpty(new_stack) ? "" : new_stack + "/"; | |||||
| } | |||||
| public string unique_name(string name) | public string unique_name(string name) | ||||
| { | { | ||||
| if (!String.IsNullOrEmpty(_name_stack)) | |||||
| { | |||||
| name = _name_stack + "/" + name; | |||||
| } | |||||
| var name_key = name.ToLower(); | var name_key = name.ToLower(); | ||||
| if (_names_in_use.ContainsKey(name_key)) | if (_names_in_use.ContainsKey(name_key)) | ||||
| { | { | ||||
| @@ -138,7 +163,6 @@ namespace Tensorflow | |||||
| _names_in_use[name_key] = 1; | _names_in_use[name_key] = 1; | ||||
| return name; | return name; | ||||
| } | } | ||||
| return $"{name}_{_names_in_use[name_key]}"; | return $"{name}_{_names_in_use[name_key]}"; | ||||
| } | } | ||||
| @@ -38,7 +38,7 @@ namespace Tensorflow | |||||
| private static OpDefLibrary _InitOpDefLibrary() | private static OpDefLibrary _InitOpDefLibrary() | ||||
| { | { | ||||
| // c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); | // c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); | ||||
| var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_array.bin"); | |||||
| var bytes = File.ReadAllBytes("Operations/op_list_proto_array.bin"); | |||||
| var op_list = OpList.Parser.ParseFrom(bytes); | var op_list = OpList.Parser.ParseFrom(bytes); | ||||
| var op_def_lib = new OpDefLibrary(); | var op_def_lib = new OpDefLibrary(); | ||||
| op_def_lib.add_op_list(op_list); | op_def_lib.add_op_list(op_list); | ||||
| @@ -71,6 +71,20 @@ namespace Tensorflow | |||||
| return node_def; | return node_def; | ||||
| } | } | ||||
| public static string name_scope(string name, string default_name = "", object values = null) | |||||
| { | |||||
| string _name = ""; | |||||
| if (String.IsNullOrEmpty(name)) | |||||
| { | |||||
| _name = default_name; | |||||
| } | |||||
| var g = get_default_graph(); | |||||
| var _name_scope = g.name_scope(_name); | |||||
| return _name_scope; | |||||
| } | |||||
| public static string _name_from_scope_name(string name) | public static string _name_from_scope_name(string name) | ||||
| { | { | ||||
| if (name.EndsWith("/")) | if (name.EndsWith("/")) | ||||
| @@ -14,12 +14,13 @@ namespace Tensorflow | |||||
| bool validate_shape = true) : | bool validate_shape = true) : | ||||
| base(initial_value, trainable, validate_shape) | base(initial_value, trainable, validate_shape) | ||||
| { | { | ||||
| _init_from_args(initial_value, trainable); | |||||
| } | } | ||||
| private void _init_from_args(object initial_value, | private void _init_from_args(object initial_value, | ||||
| TF_DataType trainable) | TF_DataType trainable) | ||||
| { | { | ||||
| var name = ops.name_scope("", "Variable", initial_value); | |||||
| _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); | _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); | ||||
| } | } | ||||
| } | } | ||||
| @@ -33,7 +33,13 @@ namespace Tensorflow | |||||
| var attrs = new Dictionary<string, AttrValue>(); | var attrs = new Dictionary<string, AttrValue>(); | ||||
| attrs["dtype"] = dtype_value; | attrs["dtype"] = dtype_value; | ||||
| attrs["value"] = tensor_value; | attrs["value"] = tensor_value; | ||||
| var const_tensor = g.create_op("Const", null, new TF_DataType[] { (TF_DataType)dtype_value.Type }, attrs: attrs).outputs[0]; | |||||
| var const_tensor = g.create_op("Const", | |||||
| null, | |||||
| new TF_DataType[] { (TF_DataType)dtype_value.Type }, | |||||
| attrs: attrs, | |||||
| name: name).outputs[0]; | |||||
| const_tensor.value = nd.Data(); | const_tensor.value = nd.Data(); | ||||
| return const_tensor; | return const_tensor; | ||||
| @@ -17,9 +17,9 @@ namespace Tensorflow | |||||
| public static Graph g = new Graph(c_api.TF_NewGraph()); | public static Graph g = new Graph(c_api.TF_NewGraph()); | ||||
| public static object Variable<T>(T data, TF_DataType dtype) | |||||
| public static object Variable<T>(T data, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| { | { | ||||
| return new Variable(null, TF_DataType.DtInvalid); | |||||
| return new RefVariable(data, dtype); | |||||
| } | } | ||||
| public static unsafe Tensor add(Tensor a, Tensor b) | public static unsafe Tensor add(Tensor a, Tensor b) | ||||
| @@ -10,9 +10,17 @@ namespace TensorFlowNET.UnitTest | |||||
| public class VariableTest | public class VariableTest | ||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| public void Creating() | |||||
| public void StringVar() | |||||
| { | { | ||||
| var mammal = tf.Variable("Elephant", tf.chars); | |||||
| var mammal1 = tf.Variable("Elephant", tf.chars); | |||||
| var mammal2 = tf.Variable("Tiger"); | |||||
| } | |||||
| [TestMethod] | |||||
| public void ScalarVar() | |||||
| { | |||||
| var x = tf.Variable(3); | |||||
| var y = tf.Variable(6f); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||