| @@ -91,9 +91,18 @@ namespace Tensorflow | |||||
| throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); | throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); | ||||
| } | } | ||||
| public void add_to_collection(string name, object value) | |||||
| public void add_to_collection<T>(string name, T value) | |||||
| { | { | ||||
| _collections[name] = value; | |||||
| if (_collections.ContainsKey(name)) | |||||
| (_collections[name] as List<T>).Add(value); | |||||
| else | |||||
| _collections[name] = new List<T> { value }; | |||||
| } | |||||
| public void add_to_collections<T>(List<string> names, T value) | |||||
| { | |||||
| foreach (string name in names) | |||||
| add_to_collection(name, value); | |||||
| } | } | ||||
| public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes, | public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes, | ||||
| @@ -236,9 +245,9 @@ namespace Tensorflow | |||||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | return _nodes_by_name.Values.Select(x => x).ToArray(); | ||||
| } | } | ||||
| public Dictionary<string, object> get_collection(string name) | |||||
| public object get_collection(string name) | |||||
| { | { | ||||
| return _collections; | |||||
| return _collections.ContainsKey(name) ? _collections[name] : null; | |||||
| } | } | ||||
| public void Dispose() | public void Dispose() | ||||
| @@ -20,7 +20,7 @@ namespace Tensorflow | |||||
| name = op_type_name; | name = op_type_name; | ||||
| } | } | ||||
| string scope = g.unique_name(name) + "/"; | |||||
| string scope = new ops.name_scope(name); | |||||
| var default_type_attr_map = new Dictionary<string, object>(); | var default_type_attr_map = new Dictionary<string, object>(); | ||||
| foreach (var attr_def in op_def.Attr) | foreach (var attr_def in op_def.Attr) | ||||
| @@ -88,6 +88,9 @@ namespace Tensorflow | |||||
| switch (attr_def.Type) | switch (attr_def.Type) | ||||
| { | { | ||||
| case "string": | |||||
| attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value); | |||||
| break; | |||||
| case "type": | case "type": | ||||
| attr_value.Type = _MakeType((TF_DataType)value, attr_def); | attr_value.Type = _MakeType((TF_DataType)value, attr_def); | ||||
| break; | break; | ||||
| @@ -95,8 +98,12 @@ namespace Tensorflow | |||||
| attr_value.B = (bool)value; | attr_value.B = (bool)value; | ||||
| break; | break; | ||||
| case "shape": | case "shape": | ||||
| attr_value.Shape = new TensorShapeProto(); | |||||
| attr_value.Shape = value == null ? | |||||
| attr_def.DefaultValue.Shape : | |||||
| tensor_util.as_shape((long[])value); | |||||
| break; | break; | ||||
| default: | |||||
| throw new InvalidDataException($"attr_def.Type {attr_def.Type}"); | |||||
| } | } | ||||
| attr_protos[key] = attr_value; | attr_protos[key] = attr_value; | ||||
| @@ -73,6 +73,22 @@ namespace Tensorflow | |||||
| return nd; | return nd; | ||||
| } | } | ||||
| public static TensorShapeProto as_shape(long[] dims) | |||||
| { | |||||
| TensorShapeProto shape = new TensorShapeProto(); | |||||
| for (int i = 0; i < dims.Length; i++) | |||||
| { | |||||
| var dim = new TensorShapeProto.Types.Dim(); | |||||
| dim.Size = dims[i]; | |||||
| dim.Name = $"dim_{i}"; | |||||
| shape.Dim.Add(dim); | |||||
| } | |||||
| return shape; | |||||
| } | |||||
| public static TensorShape as_shape(this IShape shape, int[] dims) | public static TensorShape as_shape(this IShape shape, int[] dims) | ||||
| { | { | ||||
| return new TensorShape(dims); | return new TensorShape(dims); | ||||
| @@ -30,9 +30,14 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="loss"></param> | /// <param name="loss"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public Optimizer minimize(Tensor loss, GateGradientType gate_gradients = GateGradientType.GATE_OP) | |||||
| public Optimizer minimize(Tensor loss, | |||||
| GateGradientType gate_gradients = GateGradientType.GATE_OP, | |||||
| bool colocate_gradients_with_ops = false) | |||||
| { | { | ||||
| compute_gradients(loss, gate_gradients); | |||||
| compute_gradients(loss, | |||||
| gate_gradients: gate_gradients, | |||||
| colocate_gradients_with_ops: colocate_gradients_with_ops); | |||||
| return this; | return this; | ||||
| } | } | ||||
| @@ -41,7 +46,10 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="loss"></param> | /// <param name="loss"></param> | ||||
| /// <param name="gate_gradients"></param> | /// <param name="gate_gradients"></param> | ||||
| public List<KeyValuePair<object, object>> compute_gradients(Tensor loss, GateGradientType gate_gradients = GateGradientType.GATE_OP) | |||||
| public List<KeyValuePair<object, object>> compute_gradients(Tensor loss, | |||||
| List<RefVariable> var_list = null, | |||||
| GateGradientType gate_gradients = GateGradientType.GATE_OP, | |||||
| bool colocate_gradients_with_ops = false) | |||||
| { | { | ||||
| int num_towers = 1; | int num_towers = 1; | ||||
| if(distribute_lib.get_loss_reduction() == VariableAggregationType.MEAN) | if(distribute_lib.get_loss_reduction() == VariableAggregationType.MEAN) | ||||
| @@ -49,7 +57,19 @@ namespace Tensorflow | |||||
| } | } | ||||
| var var_list = variables.trainable_variables(); | |||||
| var tmp = variables.trainable_variables(); | |||||
| switch (tmp) | |||||
| { | |||||
| case List<RefVariable> values: | |||||
| var_list = values; | |||||
| break; | |||||
| } | |||||
| foreach(var v in var_list) | |||||
| { | |||||
| } | |||||
| return null; | return null; | ||||
| } | } | ||||
| } | } | ||||
| @@ -64,6 +64,28 @@ namespace Tensorflow | |||||
| var shape = _initial_value.shape; | var shape = _initial_value.shape; | ||||
| dtype = _initial_value.dtype; | dtype = _initial_value.dtype; | ||||
| _variable = gen_state_ops.variable_v2(shape, dtype, name); | _variable = gen_state_ops.variable_v2(shape, dtype, name); | ||||
| // Manually overrides the variable's shape with the initial value's. | |||||
| if (validate_shape) | |||||
| { | |||||
| var initial_value_shape = _initial_value.shape; | |||||
| } | |||||
| // If 'initial_value' makes use of other variables, make sure we don't | |||||
| // have an issue if these other variables aren't initialized first by | |||||
| // using their initialized_value() method. | |||||
| ops.add_to_collections(collections, this); | |||||
| } | |||||
| public static implicit operator _VariableScopeStore(RefVariable variable) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| public static implicit operator RefVariable(_VariableScopeStore store) | |||||
| { | |||||
| return null; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -23,6 +23,8 @@ namespace Tensorflow | |||||
| var keywords = new Dictionary<string, object>(); | var keywords = new Dictionary<string, object>(); | ||||
| keywords.Add("dtype", dtype); | keywords.Add("dtype", dtype); | ||||
| keywords.Add("shape", shape); | keywords.Add("shape", shape); | ||||
| keywords.Add("container", container); | |||||
| keywords.Add("shared_name", shared_name); | |||||
| var _op = _op_def_lib._apply_op_helper("VariableV2", name: name, keywords: keywords); | var _op = _op_def_lib._apply_op_helper("VariableV2", name: name, keywords: keywords); | ||||
| @@ -39,18 +39,30 @@ namespace Tensorflow | |||||
| public static _VariableScopeStore get_variable_scope_store() | public static _VariableScopeStore get_variable_scope_store() | ||||
| { | { | ||||
| _VariableScopeStore ret = null; | |||||
| var scope_store = ops.get_collection(_VARSCOPESTORE_KEY); | var scope_store = ops.get_collection(_VARSCOPESTORE_KEY); | ||||
| if (scope_store == null) | if (scope_store == null) | ||||
| { | { | ||||
| scope_store = new _VariableScopeStore(); | |||||
| ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store); | |||||
| ret = new _VariableScopeStore(); | |||||
| ops.add_to_collection(_VARSCOPESTORE_KEY, ret); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| // scope_store = scope_store[0]; | |||||
| switch (scope_store) | |||||
| { | |||||
| case List<RefVariable> values: | |||||
| ret = values[0]; | |||||
| break; | |||||
| case List<_VariableScopeStore> values: | |||||
| ret = values[0]; | |||||
| break; | |||||
| default: | |||||
| throw new InvalidOperationException("get_variable_scope_store"); | |||||
| } | |||||
| } | } | ||||
| return scope_store; | |||||
| return ret; | |||||
| } | } | ||||
| public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null) | public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null) | ||||
| @@ -14,7 +14,7 @@ namespace Tensorflow | |||||
| public Context _ctx; | public Context _ctx; | ||||
| public string _name_scope; | public string _name_scope; | ||||
| public name_scope(string name, string default_name, List<object> values) | |||||
| public name_scope(string name, string default_name = "", List<object> values = null) | |||||
| { | { | ||||
| _name = name; | _name = name; | ||||
| _default_name = default_name; | _default_name = default_name; | ||||
| @@ -12,15 +12,21 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class ops | public partial class ops | ||||
| { | { | ||||
| public static void add_to_collection(string name, object value) | |||||
| public static void add_to_collection<T>(string name, T value) | |||||
| { | { | ||||
| var graph = tf.get_default_graph(); | var graph = tf.get_default_graph(); | ||||
| graph.add_to_collection(name, value); | graph.add_to_collection(name, value); | ||||
| } | } | ||||
| public static _VariableScopeStore get_collection(string key) | |||||
| public static void add_to_collections<T>(List<string> names, T value) | |||||
| { | { | ||||
| return null;// get_default_graph().get_collection(key); | |||||
| var graph = tf.get_default_graph(); | |||||
| graph.add_to_collections(names, value); | |||||
| } | |||||
| public static object get_collection(string key) | |||||
| { | |||||
| return get_default_graph().get_collection(key); | |||||
| } | } | ||||
| public static Graph get_default_graph() | public static Graph get_default_graph() | ||||
| @@ -27,12 +27,12 @@ namespace TensorFlowNET.Examples | |||||
| var train_Y = np.array(1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221, | var train_Y = np.array(1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221, | ||||
| 2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3); | 2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3); | ||||
| var n_samples = train_X.shape[0]; | var n_samples = train_X.shape[0]; | ||||
| // tf Graph Input | // tf Graph Input | ||||
| var X = tf.placeholder(tf.float64); | var X = tf.placeholder(tf.float64); | ||||
| var Y = tf.placeholder(tf.float64); | var Y = tf.placeholder(tf.float64); | ||||
| // Set model weights | |||||
| // Set model weights | |||||
| var W = tf.Variable(rng.randn<double>(), name: "weight"); | var W = tf.Variable(rng.randn<double>(), name: "weight"); | ||||
| var b = tf.Variable(rng.randn<double>(), name: "bias"); | var b = tf.Variable(rng.randn<double>(), name: "bias"); | ||||