| @@ -72,10 +72,7 @@ namespace Tensorflow | |||
| public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad) | |||
| { | |||
| if (x.NDims == 0 && y.NDims == 0 && grad.NDims == 0) return true; | |||
| return string.Join(",", x.shape).Equals(string.Join(",", y.shape)) && | |||
| string.Join(",", x.shape).Equals(string.Join(",", grad.shape)); | |||
| return x.NDims == y.NDims && y.NDims == grad.NDims && x.NDims > -1; | |||
| } | |||
| public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad) | |||
| @@ -110,14 +107,15 @@ namespace Tensorflow | |||
| x = math_ops.conj(x); | |||
| y = math_ops.conj(y); | |||
| var realdiv1 = gen_math_ops.real_div(grad, y); | |||
| var reduce_sum1 = math_ops.reduce_sum(realdiv1, rx); | |||
| var realdiv2 = gen_math_ops.real_div(-x, y); | |||
| var realdiv3 = gen_math_ops.real_div(realdiv2, y); | |||
| var mul = grad * realdiv3; | |||
| var reduce_sum2 = math_ops.reduce_sum(mul, ry); | |||
| var realdiv1 = gen_math_ops.real_div(-x, y); | |||
| var realdiv2 = gen_math_ops.real_div(realdiv1, y); | |||
| var reduce_sum1 = math_ops.reduce_sum(grad * realdiv2, ry); | |||
| var reshape1 = gen_array_ops.reshape(reduce_sum1, sy); | |||
| var realdiv3 = gen_math_ops.real_div(grad, y); | |||
| var reduce_sum2 = math_ops.reduce_sum(realdiv3, rx); | |||
| var reshape2 = gen_array_ops.reshape(reduce_sum2, sx); | |||
| return (gen_array_ops.reshape(reduce_sum1, sx), gen_array_ops.reshape(reduce_sum2, sy)); | |||
| return (reshape2, reshape1); | |||
| } | |||
| public static (Tensor, Tensor) _PowGrad(Operation op, Tensor grad) | |||
| @@ -135,17 +133,16 @@ namespace Tensorflow | |||
| var gx = gen_array_ops.reshape(math_ops.reduce_sum(grad * y * gen_math_ops.pow(x, y - 1.0), rx), sx); | |||
| Tensor log_x = null; | |||
| // Avoid false singularity at x = 0 | |||
| Tensor mask = null; | |||
| if (x.dtype.is_complex()) | |||
| { | |||
| throw new NotImplementedException("x.dtype.is_complex()"); | |||
| } | |||
| else | |||
| { | |||
| var x1 = gen_array_ops.log(x); | |||
| var y1 = array_ops.zeros_like(x); | |||
| log_x = array_ops.where(x > 0.0, x1, y1); | |||
| } | |||
| mask = x > 0.0f; | |||
| var ones = array_ops.ones_like(x); | |||
| var safe_x = array_ops.where(mask, x, ones); | |||
| var x1 = gen_array_ops.log(safe_x); | |||
| var y1 = array_ops.zeros_like(x); | |||
| log_x = array_ops.where(mask, x1, y1); | |||
| var gy = gen_array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy); | |||
| return (gx, gy); | |||
| @@ -357,6 +357,13 @@ namespace Tensorflow | |||
| return _collections.ContainsKey(name) ? _collections[name] : null; | |||
| } | |||
| public object get_collection_ref(string name) | |||
| { | |||
| if (!_collections.ContainsKey(name)) | |||
| _collections[name] = new List<object>(); | |||
| return _collections[name]; | |||
| } | |||
| public void Dispose() | |||
| { | |||
| c_api.TF_DeleteGraph(_handle); | |||
| @@ -55,6 +55,43 @@ namespace Tensorflow | |||
| return math_ops.rank_internal(input, name, optimize: true); | |||
| } | |||
| /// <summary> | |||
| /// Creates a tensor with all elements set to 1. | |||
| /// </summary> | |||
| /// <param name="tensor"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <param name="name"></param> | |||
| /// <param name="optimize"></param> | |||
| /// <returns></returns> | |||
| public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool optimize = true) | |||
| => ones_like_impl(tensor, dtype, name, optimize); | |||
| private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) | |||
| { | |||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "ones_like", new { tensor }), scope => | |||
| { | |||
| name = scope; | |||
| var tensor1 = ops.convert_to_tensor(tensor, name: "tensor"); | |||
| var ones_shape = shape_internal(tensor1, optimize: optimize); | |||
| if (dtype == TF_DataType.DtInvalid) | |||
| dtype = tensor1.dtype; | |||
| var ret = ones(ones_shape, dtype: dtype, name: name); | |||
| ret.shape = tensor1.shape; | |||
| return ret; | |||
| }); | |||
| } | |||
| public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") | |||
| { | |||
| dtype = dtype.as_base_dtype(); | |||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "ones", new { shape }), scope => | |||
| { | |||
| name = scope; | |||
| var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name); | |||
| return output; | |||
| }); | |||
| } | |||
| public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, string name = "") | |||
| { | |||
| if( x == null && y == null) | |||
| @@ -111,7 +111,7 @@ namespace Tensorflow | |||
| if (delta == null) | |||
| delta = 1; | |||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Range", new object[] { start, limit, delta }), scope => | |||
| return with<ops.name_scope, Tensor>(new ops.name_scope(name, "Range", new object[] { start, limit, delta }), scope => | |||
| { | |||
| name = scope; | |||
| var start1 = ops.convert_to_tensor(start, name: "start"); | |||
| @@ -124,15 +124,15 @@ namespace Tensorflow | |||
| public static Tensor floordiv(Tensor x, Tensor y, string name = "") | |||
| { | |||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "floordiv", new object[] { }), scope => | |||
| return with<ops.name_scope, Tensor>(new ops.name_scope("", "floordiv", new { x, y }), scope => | |||
| { | |||
| return gen_math_ops.floor_div(x, y, name); | |||
| return gen_math_ops.floor_div(x, y, scope); | |||
| }); | |||
| } | |||
| public static Tensor rank_internal(Tensor input, string name = "", bool optimize = true) | |||
| { | |||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Rank", new List<Tensor> { input }), scope => | |||
| return with<ops.name_scope, Tensor>(new ops.name_scope(name, "Rank", new List<Tensor> { input }), scope => | |||
| { | |||
| name = scope; | |||
| var input_tensor = ops.convert_to_tensor(input); | |||
| @@ -63,31 +63,6 @@ namespace Tensorflow | |||
| break; | |||
| case "Single": | |||
| Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size); | |||
| /*if (nd.size > 1) | |||
| { | |||
| var bb = nd.Data<byte>(); | |||
| var bytes = Marshal.AllocHGlobal(bb.Length); | |||
| Marshal.Copy(bb, 0, bytes, bb.Length); | |||
| ulong bytes_len = c_api.TF_StringEncodedSize((ulong)bb.Length); | |||
| var dataTypeByte = ToTFDataType(nd.dtype); | |||
| // shape | |||
| var dims2 = nd.shape.Select(x => (long)x).ToArray(); | |||
| var tfHandle2 = c_api.TF_AllocateTensor(dataTypeByte, | |||
| dims2, | |||
| nd.ndim, | |||
| bytes_len + sizeof(Int64)); | |||
| dotHandle = c_api.TF_TensorData(tfHandle2); | |||
| Marshal.WriteInt64(dotHandle, 0); | |||
| c_api.TF_StringEncode(bytes, (ulong)bb.Length, dotHandle + sizeof(Int64), bytes_len, status); | |||
| return tfHandle2; | |||
| } | |||
| else | |||
| { | |||
| Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size); | |||
| }*/ | |||
| break; | |||
| case "Double": | |||
| Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size); | |||
| @@ -27,8 +27,10 @@ namespace Tensorflow | |||
| public static Tensor operator %(Tensor x, Tensor y) => BinaryOpWrapper("mod", x, y); | |||
| public static Tensor operator >(Tensor x, int y) => gen_array_ops.greater(x, y); | |||
| public static Tensor operator >(Tensor x, float y) => gen_array_ops.greater(x, y); | |||
| public static Tensor operator >(Tensor x, double y) => gen_array_ops.greater(x, y); | |||
| public static Tensor operator <(Tensor x, int y) => gen_array_ops.less(x, y); | |||
| public static Tensor operator <(Tensor x, float y) => gen_array_ops.less(x, y); | |||
| public static Tensor operator <(Tensor x, double y) => gen_array_ops.less(x, y); | |||
| private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y) | |||
| @@ -68,7 +68,12 @@ namespace Tensorflow | |||
| c_api.TF_GraphSetTensorShape(this.Graph, this._as_tf_output(), value, value.Length, status); | |||
| } | |||
| } | |||
| public int[] _shape_tuple() | |||
| { | |||
| return null; | |||
| } | |||
| /// <summary> | |||
| /// number of dimensions | |||
| /// 0 Scalar (magnitude only) | |||
| @@ -6,7 +6,7 @@ namespace Tensorflow | |||
| { | |||
| public class GradientDescentOptimizer : Optimizer | |||
| { | |||
| public GradientDescentOptimizer(double learning_rate, bool use_locking = false, string name = "GradientDescent") | |||
| public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent") | |||
| : base(learning_rate, use_locking, name) | |||
| { | |||
| LearningRate = learning_rate; | |||
| @@ -20,14 +20,14 @@ namespace Tensorflow | |||
| public static int GATE_GRAPH = 2; | |||
| public string Name { get; set; } | |||
| public double LearningRate { get; set; } | |||
| public float LearningRate { get; set; } | |||
| public Tensor LearningRateTensor { get; set; } | |||
| public bool _use_locking; | |||
| public Dictionary<string, object> _slots; | |||
| public Dictionary<string, object> _non_slot_dict; | |||
| public Dictionary<string, object> _deferred_slot_restorations; | |||
| public Optimizer(double learning_rate, bool use_locking, string name = "") | |||
| public Optimizer(float learning_rate, bool use_locking, string name = "") | |||
| { | |||
| if (String.IsNullOrEmpty(name)) | |||
| throw new NotImplementedException("Must specify the optimizer name"); | |||
| @@ -114,6 +114,13 @@ namespace Tensorflow | |||
| } | |||
| if (!tf.context.executing_eagerly()) | |||
| { | |||
| var train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) as List<object>; | |||
| if (!train_op.Contains(apply_updates)) | |||
| train_op.Add(apply_updates); | |||
| } | |||
| return apply_updates; | |||
| }); | |||
| } | |||
| @@ -9,7 +9,7 @@ namespace Tensorflow | |||
| { | |||
| public static class train | |||
| { | |||
| public static Optimizer GradientDescentOptimizer(double learning_rate) => new GradientDescentOptimizer(learning_rate); | |||
| public static Optimizer GradientDescentOptimizer(float learning_rate) => new GradientDescentOptimizer(learning_rate); | |||
| public static Saver Saver() => new Saver(); | |||
| @@ -33,6 +33,8 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public static string GLOBAL_VARIABLES = "variables"; | |||
| public static string TRAIN_OP = "train_op"; | |||
| public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables" }; | |||
| /// <summary> | |||
| /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | |||
| @@ -45,6 +45,11 @@ namespace Tensorflow | |||
| return get_default_graph().get_collection(key, scope); | |||
| } | |||
| public static object get_collection_ref(string key) | |||
| { | |||
| return get_default_graph().get_collection_ref(key); | |||
| } | |||
| private static Graph default_graph; | |||
| public static Graph get_default_graph() | |||
| { | |||
| @@ -21,7 +21,7 @@ namespace TensorFlowNET.Examples | |||
| // Parameters | |||
| float learning_rate = 0.01f; | |||
| int training_epochs = 1000; | |||
| int display_step = 1; | |||
| int display_step = 10; | |||
| // Training Data | |||
| var train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f, | |||
| @@ -29,9 +29,9 @@ namespace TensorFlowNET.Examples | |||
| var train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f, | |||
| 2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f); | |||
| var n_samples = train_X.shape[0]; | |||
| // tf Graph Input | |||
| var X = tf.placeholder(tf.float32); | |||
| /*var X = tf.placeholder(tf.float32); | |||
| var Y = tf.placeholder(tf.float32); | |||
| // Set model weights | |||
| @@ -55,7 +55,14 @@ namespace TensorFlowNET.Examples | |||
| // radient descent | |||
| // Note, minimize() knows to modify W and b because Variable objects are trainable=True by default | |||
| var grad = tf.train.GradientDescentOptimizer(learning_rate); | |||
| var optimizer = grad.minimize(cost); | |||
| var optimizer = grad.minimize(cost);*/ | |||
| var new_saver = tf.train.import_meta_graph("save_model.meta", import_scope: "import"); | |||
| var X = graph.OperationByName("Placeholder"); | |||
| var Y = graph.OperationByName("Placeholder_1"); | |||
| var W = graph.OperationByName("weight"); | |||
| var optimizer = graph.OperationByName("GradientDescent"); | |||
| // Initialize the variables (i.e. assign their default value) | |||
| var init = tf.global_variables_initializer(); | |||
| @@ -71,14 +78,15 @@ namespace TensorFlowNET.Examples | |||
| { | |||
| foreach (var (x, y) in zip<float>(train_X, train_Y)) | |||
| { | |||
| var w = sess.run(W); | |||
| sess.run(optimizer, | |||
| new FeedItem(X, x), | |||
| new FeedItem(Y, y)); | |||
| var w = sess.run(W); | |||
| w = sess.run(W); | |||
| } | |||
| // Display logs per epoch step | |||
| if ((epoch + 1) % display_step == 0) | |||
| /*if ((epoch + 1) % display_step == 0) | |||
| { | |||
| var c = sess.run(cost, | |||
| new FeedItem(X, train_X), | |||
| @@ -86,7 +94,7 @@ namespace TensorFlowNET.Examples | |||
| var rW = sess.run(W); | |||
| Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + | |||
| $"W={rW} b={sess.run(b)}"); | |||
| } | |||
| }*/ | |||
| } | |||
| Console.WriteLine("Optimization Finished!"); | |||