| @@ -91,9 +91,18 @@ namespace Tensorflow | |||
| throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); | |||
| } | |||
| public void add_to_collection(string name, object value) | |||
| public void add_to_collection<T>(string name, T value) | |||
| { | |||
| _collections[name] = value; | |||
| if (_collections.ContainsKey(name)) | |||
| (_collections[name] as List<T>).Add(value); | |||
| else | |||
| _collections[name] = new List<T> { value }; | |||
| } | |||
| public void add_to_collections<T>(List<string> names, T value) | |||
| { | |||
| foreach (string name in names) | |||
| add_to_collection(name, value); | |||
| } | |||
| public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes, | |||
| @@ -236,9 +245,9 @@ namespace Tensorflow | |||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
| } | |||
| public Dictionary<string, object> get_collection(string name) | |||
| public object get_collection(string name) | |||
| { | |||
| return _collections; | |||
| return _collections.ContainsKey(name) ? _collections[name] : null; | |||
| } | |||
| public void Dispose() | |||
| @@ -20,7 +20,7 @@ namespace Tensorflow | |||
| name = op_type_name; | |||
| } | |||
| string scope = g.unique_name(name) + "/"; | |||
| string scope = new ops.name_scope(name); | |||
| var default_type_attr_map = new Dictionary<string, object>(); | |||
| foreach (var attr_def in op_def.Attr) | |||
| @@ -88,6 +88,9 @@ namespace Tensorflow | |||
| switch (attr_def.Type) | |||
| { | |||
| case "string": | |||
| attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value); | |||
| break; | |||
| case "type": | |||
| attr_value.Type = _MakeType((TF_DataType)value, attr_def); | |||
| break; | |||
| @@ -95,8 +98,12 @@ namespace Tensorflow | |||
| attr_value.B = (bool)value; | |||
| break; | |||
| case "shape": | |||
| attr_value.Shape = new TensorShapeProto(); | |||
| attr_value.Shape = value == null ? | |||
| attr_def.DefaultValue.Shape : | |||
| tensor_util.as_shape((long[])value); | |||
| break; | |||
| default: | |||
| throw new InvalidDataException($"attr_def.Type {attr_def.Type}"); | |||
| } | |||
| attr_protos[key] = attr_value; | |||
| @@ -73,6 +73,22 @@ namespace Tensorflow | |||
| return nd; | |||
| } | |||
| public static TensorShapeProto as_shape(long[] dims) | |||
| { | |||
| TensorShapeProto shape = new TensorShapeProto(); | |||
| for (int i = 0; i < dims.Length; i++) | |||
| { | |||
| var dim = new TensorShapeProto.Types.Dim(); | |||
| dim.Size = dims[i]; | |||
| dim.Name = $"dim_{i}"; | |||
| shape.Dim.Add(dim); | |||
| } | |||
| return shape; | |||
| } | |||
| public static TensorShape as_shape(this IShape shape, int[] dims) | |||
| { | |||
| return new TensorShape(dims); | |||
| @@ -30,9 +30,14 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <param name="loss"></param> | |||
| /// <returns></returns> | |||
| public Optimizer minimize(Tensor loss, GateGradientType gate_gradients = GateGradientType.GATE_OP) | |||
| public Optimizer minimize(Tensor loss, | |||
| GateGradientType gate_gradients = GateGradientType.GATE_OP, | |||
| bool colocate_gradients_with_ops = false) | |||
| { | |||
| compute_gradients(loss, gate_gradients); | |||
| compute_gradients(loss, | |||
| gate_gradients: gate_gradients, | |||
| colocate_gradients_with_ops: colocate_gradients_with_ops); | |||
| return this; | |||
| } | |||
| @@ -41,7 +46,10 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <param name="loss"></param> | |||
| /// <param name="gate_gradients"></param> | |||
| public List<KeyValuePair<object, object>> compute_gradients(Tensor loss, GateGradientType gate_gradients = GateGradientType.GATE_OP) | |||
| public List<KeyValuePair<object, object>> compute_gradients(Tensor loss, | |||
| List<RefVariable> var_list = null, | |||
| GateGradientType gate_gradients = GateGradientType.GATE_OP, | |||
| bool colocate_gradients_with_ops = false) | |||
| { | |||
| int num_towers = 1; | |||
| if(distribute_lib.get_loss_reduction() == VariableAggregationType.MEAN) | |||
| @@ -49,7 +57,19 @@ namespace Tensorflow | |||
| } | |||
| var var_list = variables.trainable_variables(); | |||
| var tmp = variables.trainable_variables(); | |||
| switch (tmp) | |||
| { | |||
| case List<RefVariable> values: | |||
| var_list = values; | |||
| break; | |||
| } | |||
| foreach(var v in var_list) | |||
| { | |||
| } | |||
| return null; | |||
| } | |||
| } | |||
| @@ -64,6 +64,28 @@ namespace Tensorflow | |||
| var shape = _initial_value.shape; | |||
| dtype = _initial_value.dtype; | |||
| _variable = gen_state_ops.variable_v2(shape, dtype, name); | |||
| // Manually overrides the variable's shape with the initial value's. | |||
| if (validate_shape) | |||
| { | |||
| var initial_value_shape = _initial_value.shape; | |||
| } | |||
| // If 'initial_value' makes use of other variables, make sure we don't | |||
| // have an issue if these other variables aren't initialized first by | |||
| // using their initialized_value() method. | |||
| ops.add_to_collections(collections, this); | |||
| } | |||
| public static implicit operator _VariableScopeStore(RefVariable variable) | |||
| { | |||
| return null; | |||
| } | |||
| public static implicit operator RefVariable(_VariableScopeStore store) | |||
| { | |||
| return null; | |||
| } | |||
| } | |||
| } | |||
| @@ -23,6 +23,8 @@ namespace Tensorflow | |||
| var keywords = new Dictionary<string, object>(); | |||
| keywords.Add("dtype", dtype); | |||
| keywords.Add("shape", shape); | |||
| keywords.Add("container", container); | |||
| keywords.Add("shared_name", shared_name); | |||
| var _op = _op_def_lib._apply_op_helper("VariableV2", name: name, keywords: keywords); | |||
| @@ -39,18 +39,30 @@ namespace Tensorflow | |||
| public static _VariableScopeStore get_variable_scope_store() | |||
| { | |||
| _VariableScopeStore ret = null; | |||
| var scope_store = ops.get_collection(_VARSCOPESTORE_KEY); | |||
| if (scope_store == null) | |||
| { | |||
| scope_store = new _VariableScopeStore(); | |||
| ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store); | |||
| ret = new _VariableScopeStore(); | |||
| ops.add_to_collection(_VARSCOPESTORE_KEY, ret); | |||
| } | |||
| else | |||
| { | |||
| // scope_store = scope_store[0]; | |||
| switch (scope_store) | |||
| { | |||
| case List<RefVariable> values: | |||
| ret = values[0]; | |||
| break; | |||
| case List<_VariableScopeStore> values: | |||
| ret = values[0]; | |||
| break; | |||
| default: | |||
| throw new InvalidOperationException("get_variable_scope_store"); | |||
| } | |||
| } | |||
| return scope_store; | |||
| return ret; | |||
| } | |||
| public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null) | |||
| @@ -14,7 +14,7 @@ namespace Tensorflow | |||
| public Context _ctx; | |||
| public string _name_scope; | |||
| public name_scope(string name, string default_name, List<object> values) | |||
| public name_scope(string name, string default_name = "", List<object> values = null) | |||
| { | |||
| _name = name; | |||
| _default_name = default_name; | |||
| @@ -12,15 +12,21 @@ namespace Tensorflow | |||
| { | |||
| public partial class ops | |||
| { | |||
| public static void add_to_collection(string name, object value) | |||
| public static void add_to_collection<T>(string name, T value) | |||
| { | |||
| var graph = tf.get_default_graph(); | |||
| graph.add_to_collection(name, value); | |||
| } | |||
| public static _VariableScopeStore get_collection(string key) | |||
| public static void add_to_collections<T>(List<string> names, T value) | |||
| { | |||
| return null;// get_default_graph().get_collection(key); | |||
| var graph = tf.get_default_graph(); | |||
| graph.add_to_collections(names, value); | |||
| } | |||
| public static object get_collection(string key) | |||
| { | |||
| return get_default_graph().get_collection(key); | |||
| } | |||
| public static Graph get_default_graph() | |||
| @@ -27,12 +27,12 @@ namespace TensorFlowNET.Examples | |||
| var train_Y = np.array(1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221, | |||
| 2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3); | |||
| var n_samples = train_X.shape[0]; | |||
| // tf Graph Input | |||
| var X = tf.placeholder(tf.float64); | |||
| var Y = tf.placeholder(tf.float64); | |||
| // Set model weights | |||
| // Set model weights | |||
| var W = tf.Variable(rng.randn<double>(), name: "weight"); | |||
| var b = tf.Variable(rng.randn<double>(), name: "bias"); | |||