| @@ -72,10 +72,7 @@ namespace Tensorflow | |||||
| public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad) | public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad) | ||||
| { | { | ||||
| if (x.NDims == 0 && y.NDims == 0 && grad.NDims == 0) return true; | |||||
| return string.Join(",", x.shape).Equals(string.Join(",", y.shape)) && | |||||
| string.Join(",", x.shape).Equals(string.Join(",", grad.shape)); | |||||
| return x.NDims == y.NDims && y.NDims == grad.NDims && x.NDims > -1; | |||||
| } | } | ||||
| public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad) | public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad) | ||||
| @@ -110,14 +107,15 @@ namespace Tensorflow | |||||
| x = math_ops.conj(x); | x = math_ops.conj(x); | ||||
| y = math_ops.conj(y); | y = math_ops.conj(y); | ||||
| var realdiv1 = gen_math_ops.real_div(grad, y); | |||||
| var reduce_sum1 = math_ops.reduce_sum(realdiv1, rx); | |||||
| var realdiv2 = gen_math_ops.real_div(-x, y); | |||||
| var realdiv3 = gen_math_ops.real_div(realdiv2, y); | |||||
| var mul = grad * realdiv3; | |||||
| var reduce_sum2 = math_ops.reduce_sum(mul, ry); | |||||
| var realdiv1 = gen_math_ops.real_div(-x, y); | |||||
| var realdiv2 = gen_math_ops.real_div(realdiv1, y); | |||||
| var reduce_sum1 = math_ops.reduce_sum(grad * realdiv2, ry); | |||||
| var reshape1 = gen_array_ops.reshape(reduce_sum1, sy); | |||||
| var realdiv3 = gen_math_ops.real_div(grad, y); | |||||
| var reduce_sum2 = math_ops.reduce_sum(realdiv3, rx); | |||||
| var reshape2 = gen_array_ops.reshape(reduce_sum2, sx); | |||||
| return (gen_array_ops.reshape(reduce_sum1, sx), gen_array_ops.reshape(reduce_sum2, sy)); | |||||
| return (reshape2, reshape1); | |||||
| } | } | ||||
| public static (Tensor, Tensor) _PowGrad(Operation op, Tensor grad) | public static (Tensor, Tensor) _PowGrad(Operation op, Tensor grad) | ||||
| @@ -135,17 +133,16 @@ namespace Tensorflow | |||||
| var gx = gen_array_ops.reshape(math_ops.reduce_sum(grad * y * gen_math_ops.pow(x, y - 1.0), rx), sx); | var gx = gen_array_ops.reshape(math_ops.reduce_sum(grad * y * gen_math_ops.pow(x, y - 1.0), rx), sx); | ||||
| Tensor log_x = null; | Tensor log_x = null; | ||||
| // Avoid false singularity at x = 0 | // Avoid false singularity at x = 0 | ||||
| Tensor mask = null; | |||||
| if (x.dtype.is_complex()) | if (x.dtype.is_complex()) | ||||
| { | |||||
| throw new NotImplementedException("x.dtype.is_complex()"); | throw new NotImplementedException("x.dtype.is_complex()"); | ||||
| } | |||||
| else | else | ||||
| { | |||||
| var x1 = gen_array_ops.log(x); | |||||
| var y1 = array_ops.zeros_like(x); | |||||
| log_x = array_ops.where(x > 0.0, x1, y1); | |||||
| } | |||||
| mask = x > 0.0f; | |||||
| var ones = array_ops.ones_like(x); | |||||
| var safe_x = array_ops.where(mask, x, ones); | |||||
| var x1 = gen_array_ops.log(safe_x); | |||||
| var y1 = array_ops.zeros_like(x); | |||||
| log_x = array_ops.where(mask, x1, y1); | |||||
| var gy = gen_array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy); | var gy = gen_array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy); | ||||
| return (gx, gy); | return (gx, gy); | ||||
| @@ -357,6 +357,13 @@ namespace Tensorflow | |||||
| return _collections.ContainsKey(name) ? _collections[name] : null; | return _collections.ContainsKey(name) ? _collections[name] : null; | ||||
| } | } | ||||
| public object get_collection_ref(string name) | |||||
| { | |||||
| if (!_collections.ContainsKey(name)) | |||||
| _collections[name] = new List<object>(); | |||||
| return _collections[name]; | |||||
| } | |||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| c_api.TF_DeleteGraph(_handle); | c_api.TF_DeleteGraph(_handle); | ||||
| @@ -55,6 +55,43 @@ namespace Tensorflow | |||||
| return math_ops.rank_internal(input, name, optimize: true); | return math_ops.rank_internal(input, name, optimize: true); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Creates a tensor with all elements set to 1. | |||||
| /// </summary> | |||||
| /// <param name="tensor"></param> | |||||
| /// <param name="dtype"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <param name="optimize"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool optimize = true) | |||||
| => ones_like_impl(tensor, dtype, name, optimize); | |||||
| private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) | |||||
| { | |||||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "ones_like", new { tensor }), scope => | |||||
| { | |||||
| name = scope; | |||||
| var tensor1 = ops.convert_to_tensor(tensor, name: "tensor"); | |||||
| var ones_shape = shape_internal(tensor1, optimize: optimize); | |||||
| if (dtype == TF_DataType.DtInvalid) | |||||
| dtype = tensor1.dtype; | |||||
| var ret = ones(ones_shape, dtype: dtype, name: name); | |||||
| ret.shape = tensor1.shape; | |||||
| return ret; | |||||
| }); | |||||
| } | |||||
| public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") | |||||
| { | |||||
| dtype = dtype.as_base_dtype(); | |||||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "ones", new { shape }), scope => | |||||
| { | |||||
| name = scope; | |||||
| var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name); | |||||
| return output; | |||||
| }); | |||||
| } | |||||
| public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, string name = "") | public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, string name = "") | ||||
| { | { | ||||
| if( x == null && y == null) | if( x == null && y == null) | ||||
| @@ -111,7 +111,7 @@ namespace Tensorflow | |||||
| if (delta == null) | if (delta == null) | ||||
| delta = 1; | delta = 1; | ||||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Range", new object[] { start, limit, delta }), scope => | |||||
| return with<ops.name_scope, Tensor>(new ops.name_scope(name, "Range", new object[] { start, limit, delta }), scope => | |||||
| { | { | ||||
| name = scope; | name = scope; | ||||
| var start1 = ops.convert_to_tensor(start, name: "start"); | var start1 = ops.convert_to_tensor(start, name: "start"); | ||||
| @@ -124,15 +124,15 @@ namespace Tensorflow | |||||
| public static Tensor floordiv(Tensor x, Tensor y, string name = "") | public static Tensor floordiv(Tensor x, Tensor y, string name = "") | ||||
| { | { | ||||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "floordiv", new object[] { }), scope => | |||||
| return with<ops.name_scope, Tensor>(new ops.name_scope("", "floordiv", new { x, y }), scope => | |||||
| { | { | ||||
| return gen_math_ops.floor_div(x, y, name); | |||||
| return gen_math_ops.floor_div(x, y, scope); | |||||
| }); | }); | ||||
| } | } | ||||
| public static Tensor rank_internal(Tensor input, string name = "", bool optimize = true) | public static Tensor rank_internal(Tensor input, string name = "", bool optimize = true) | ||||
| { | { | ||||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Rank", new List<Tensor> { input }), scope => | |||||
| return with<ops.name_scope, Tensor>(new ops.name_scope(name, "Rank", new List<Tensor> { input }), scope => | |||||
| { | { | ||||
| name = scope; | name = scope; | ||||
| var input_tensor = ops.convert_to_tensor(input); | var input_tensor = ops.convert_to_tensor(input); | ||||
| @@ -63,31 +63,6 @@ namespace Tensorflow | |||||
| break; | break; | ||||
| case "Single": | case "Single": | ||||
| Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size); | Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size); | ||||
| /*if (nd.size > 1) | |||||
| { | |||||
| var bb = nd.Data<byte>(); | |||||
| var bytes = Marshal.AllocHGlobal(bb.Length); | |||||
| Marshal.Copy(bb, 0, bytes, bb.Length); | |||||
| ulong bytes_len = c_api.TF_StringEncodedSize((ulong)bb.Length); | |||||
| var dataTypeByte = ToTFDataType(nd.dtype); | |||||
| // shape | |||||
| var dims2 = nd.shape.Select(x => (long)x).ToArray(); | |||||
| var tfHandle2 = c_api.TF_AllocateTensor(dataTypeByte, | |||||
| dims2, | |||||
| nd.ndim, | |||||
| bytes_len + sizeof(Int64)); | |||||
| dotHandle = c_api.TF_TensorData(tfHandle2); | |||||
| Marshal.WriteInt64(dotHandle, 0); | |||||
| c_api.TF_StringEncode(bytes, (ulong)bb.Length, dotHandle + sizeof(Int64), bytes_len, status); | |||||
| return tfHandle2; | |||||
| } | |||||
| else | |||||
| { | |||||
| Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size); | |||||
| }*/ | |||||
| break; | break; | ||||
| case "Double": | case "Double": | ||||
| Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size); | Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size); | ||||
| @@ -27,8 +27,10 @@ namespace Tensorflow | |||||
| public static Tensor operator %(Tensor x, Tensor y) => BinaryOpWrapper("mod", x, y); | public static Tensor operator %(Tensor x, Tensor y) => BinaryOpWrapper("mod", x, y); | ||||
| public static Tensor operator >(Tensor x, int y) => gen_array_ops.greater(x, y); | public static Tensor operator >(Tensor x, int y) => gen_array_ops.greater(x, y); | ||||
| public static Tensor operator >(Tensor x, float y) => gen_array_ops.greater(x, y); | |||||
| public static Tensor operator >(Tensor x, double y) => gen_array_ops.greater(x, y); | public static Tensor operator >(Tensor x, double y) => gen_array_ops.greater(x, y); | ||||
| public static Tensor operator <(Tensor x, int y) => gen_array_ops.less(x, y); | public static Tensor operator <(Tensor x, int y) => gen_array_ops.less(x, y); | ||||
| public static Tensor operator <(Tensor x, float y) => gen_array_ops.less(x, y); | |||||
| public static Tensor operator <(Tensor x, double y) => gen_array_ops.less(x, y); | public static Tensor operator <(Tensor x, double y) => gen_array_ops.less(x, y); | ||||
| private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y) | private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y) | ||||
| @@ -68,7 +68,12 @@ namespace Tensorflow | |||||
| c_api.TF_GraphSetTensorShape(this.Graph, this._as_tf_output(), value, value.Length, status); | c_api.TF_GraphSetTensorShape(this.Graph, this._as_tf_output(), value, value.Length, status); | ||||
| } | } | ||||
| } | } | ||||
| public int[] _shape_tuple() | |||||
| { | |||||
| return null; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// number of dimensions | /// number of dimensions | ||||
| /// 0 Scalar (magnitude only) | /// 0 Scalar (magnitude only) | ||||
| @@ -6,7 +6,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class GradientDescentOptimizer : Optimizer | public class GradientDescentOptimizer : Optimizer | ||||
| { | { | ||||
| public GradientDescentOptimizer(double learning_rate, bool use_locking = false, string name = "GradientDescent") | |||||
| public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent") | |||||
| : base(learning_rate, use_locking, name) | : base(learning_rate, use_locking, name) | ||||
| { | { | ||||
| LearningRate = learning_rate; | LearningRate = learning_rate; | ||||
| @@ -20,14 +20,14 @@ namespace Tensorflow | |||||
| public static int GATE_GRAPH = 2; | public static int GATE_GRAPH = 2; | ||||
| public string Name { get; set; } | public string Name { get; set; } | ||||
| public double LearningRate { get; set; } | |||||
| public float LearningRate { get; set; } | |||||
| public Tensor LearningRateTensor { get; set; } | public Tensor LearningRateTensor { get; set; } | ||||
| public bool _use_locking; | public bool _use_locking; | ||||
| public Dictionary<string, object> _slots; | public Dictionary<string, object> _slots; | ||||
| public Dictionary<string, object> _non_slot_dict; | public Dictionary<string, object> _non_slot_dict; | ||||
| public Dictionary<string, object> _deferred_slot_restorations; | public Dictionary<string, object> _deferred_slot_restorations; | ||||
| public Optimizer(double learning_rate, bool use_locking, string name = "") | |||||
| public Optimizer(float learning_rate, bool use_locking, string name = "") | |||||
| { | { | ||||
| if (String.IsNullOrEmpty(name)) | if (String.IsNullOrEmpty(name)) | ||||
| throw new NotImplementedException("Must specify the optimizer name"); | throw new NotImplementedException("Must specify the optimizer name"); | ||||
| @@ -114,6 +114,13 @@ namespace Tensorflow | |||||
| } | } | ||||
| if (!tf.context.executing_eagerly()) | |||||
| { | |||||
| var train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) as List<object>; | |||||
| if (!train_op.Contains(apply_updates)) | |||||
| train_op.Add(apply_updates); | |||||
| } | |||||
| return apply_updates; | return apply_updates; | ||||
| }); | }); | ||||
| } | } | ||||
| @@ -9,7 +9,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static class train | public static class train | ||||
| { | { | ||||
| public static Optimizer GradientDescentOptimizer(double learning_rate) => new GradientDescentOptimizer(learning_rate); | |||||
| public static Optimizer GradientDescentOptimizer(float learning_rate) => new GradientDescentOptimizer(learning_rate); | |||||
| public static Saver Saver() => new Saver(); | public static Saver Saver() => new Saver(); | ||||
| @@ -33,6 +33,8 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public static string GLOBAL_VARIABLES = "variables"; | public static string GLOBAL_VARIABLES = "variables"; | ||||
| public static string TRAIN_OP = "train_op"; | |||||
| public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables" }; | public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables" }; | ||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | ||||
| @@ -45,6 +45,11 @@ namespace Tensorflow | |||||
| return get_default_graph().get_collection(key, scope); | return get_default_graph().get_collection(key, scope); | ||||
| } | } | ||||
| public static object get_collection_ref(string key) | |||||
| { | |||||
| return get_default_graph().get_collection_ref(key); | |||||
| } | |||||
| private static Graph default_graph; | private static Graph default_graph; | ||||
| public static Graph get_default_graph() | public static Graph get_default_graph() | ||||
| { | { | ||||
| @@ -21,7 +21,7 @@ namespace TensorFlowNET.Examples | |||||
| // Parameters | // Parameters | ||||
| float learning_rate = 0.01f; | float learning_rate = 0.01f; | ||||
| int training_epochs = 1000; | int training_epochs = 1000; | ||||
| int display_step = 1; | |||||
| int display_step = 10; | |||||
| // Training Data | // Training Data | ||||
| var train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f, | var train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f, | ||||
| @@ -29,9 +29,9 @@ namespace TensorFlowNET.Examples | |||||
| var train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f, | var train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f, | ||||
| 2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f); | 2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f); | ||||
| var n_samples = train_X.shape[0]; | var n_samples = train_X.shape[0]; | ||||
| // tf Graph Input | // tf Graph Input | ||||
| var X = tf.placeholder(tf.float32); | |||||
| /*var X = tf.placeholder(tf.float32); | |||||
| var Y = tf.placeholder(tf.float32); | var Y = tf.placeholder(tf.float32); | ||||
| // Set model weights | // Set model weights | ||||
| @@ -55,7 +55,14 @@ namespace TensorFlowNET.Examples | |||||
| // radient descent | // radient descent | ||||
| // Note, minimize() knows to modify W and b because Variable objects are trainable=True by default | // Note, minimize() knows to modify W and b because Variable objects are trainable=True by default | ||||
| var grad = tf.train.GradientDescentOptimizer(learning_rate); | var grad = tf.train.GradientDescentOptimizer(learning_rate); | ||||
| var optimizer = grad.minimize(cost); | |||||
| var optimizer = grad.minimize(cost);*/ | |||||
| var new_saver = tf.train.import_meta_graph("save_model.meta", import_scope: "import"); | |||||
| var X = graph.OperationByName("Placeholder"); | |||||
| var Y = graph.OperationByName("Placeholder_1"); | |||||
| var W = graph.OperationByName("weight"); | |||||
| var optimizer = graph.OperationByName("GradientDescent"); | |||||
| // Initialize the variables (i.e. assign their default value) | // Initialize the variables (i.e. assign their default value) | ||||
| var init = tf.global_variables_initializer(); | var init = tf.global_variables_initializer(); | ||||
| @@ -71,14 +78,15 @@ namespace TensorFlowNET.Examples | |||||
| { | { | ||||
| foreach (var (x, y) in zip<float>(train_X, train_Y)) | foreach (var (x, y) in zip<float>(train_X, train_Y)) | ||||
| { | { | ||||
| var w = sess.run(W); | |||||
| sess.run(optimizer, | sess.run(optimizer, | ||||
| new FeedItem(X, x), | new FeedItem(X, x), | ||||
| new FeedItem(Y, y)); | new FeedItem(Y, y)); | ||||
| var w = sess.run(W); | |||||
| w = sess.run(W); | |||||
| } | } | ||||
| // Display logs per epoch step | // Display logs per epoch step | ||||
| if ((epoch + 1) % display_step == 0) | |||||
| /*if ((epoch + 1) % display_step == 0) | |||||
| { | { | ||||
| var c = sess.run(cost, | var c = sess.run(cost, | ||||
| new FeedItem(X, train_X), | new FeedItem(X, train_X), | ||||
| @@ -86,7 +94,7 @@ namespace TensorFlowNET.Examples | |||||
| var rW = sess.run(W); | var rW = sess.run(W); | ||||
| Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + | Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + | ||||
| $"W={rW} b={sess.run(b)}"); | $"W={rW} b={sess.run(b)}"); | ||||
| } | |||||
| }*/ | |||||
| } | } | ||||
| Console.WriteLine("Optimization Finished!"); | Console.WriteLine("Optimization Finished!"); | ||||