diff --git a/docs/RELEASE.md b/docs/RELEASE.md index 98925ddf..62a1be23 100644 --- a/docs/RELEASE.md +++ b/docs/RELEASE.md @@ -4,6 +4,25 @@ This release contains contributions from many people at SciSharp as well as the external contributors. +**Release Date 02/06/2021** + +### TensorFlow.Binding v0.33.0 + +* Improve memory usage +* Fix minor bugs + +### TensorFlow.Keras v0.4.0 + +* Add Subtract layer + +* Add model.load_weights and model.save_weights + +* Fix memory leak issue + +* Support to build YOLOv3 object detection model + + + **Release Date 01/09/2021** ### TensorFlow.Binding v0.32.0 diff --git a/src/TensorFlowNET.Console/MemoryBasicTest.cs b/src/TensorFlowNET.Console/MemoryBasicTest.cs index 199f870c..d61cca69 100644 --- a/src/TensorFlowNET.Console/MemoryBasicTest.cs +++ b/src/TensorFlowNET.Console/MemoryBasicTest.cs @@ -56,15 +56,31 @@ namespace Tensorflow { var nd = np.zeros(1 * 256 * 256 * 3).astype(np.float32).reshape(1, 256, 256, 3); ResourceVariable variable = tf.Variable(nd); - var nd2 = np.arange(1 * 256 * 256 * 3).astype(np.float32).reshape(1, 256, 256, 3); - variable.assign(nd2); - for (int i = 0; i< 100; i++) + for (int i = 0; i< 10; i++) { var v = variable.numpy(); } }; + public Action VariableAssign + => (epoch, iterate) => + { + ResourceVariable variable = tf.Variable(3112f); + AssignVariable(variable); + for (int i = 0; i < 100; i++) + { + var v = variable.numpy(); + if ((float)v != 1984f) + throw new ValueError(""); + } + }; + + void AssignVariable(IVariableV1 v) + { + using var tensor = tf.constant(1984f); + v.assign(tensor); + } public Action MathAdd => (epoch, iterate) => diff --git a/src/TensorFlowNET.Console/Program.cs b/src/TensorFlowNET.Console/Program.cs index d65e7e6b..38b878af 100644 --- a/src/TensorFlowNET.Console/Program.cs +++ b/src/TensorFlowNET.Console/Program.cs @@ -52,6 +52,10 @@ namespace Tensorflow // 100K float variable. mm.Execute(10, batchSize, basic.Variable); + mm.Execute(10, batchSize, basic.VariableRead); + + mm.Execute(10, batchSize, basic.VariableAssign); + // 1 million math. mm.Execute(10, 100 * batchSize, basic.MathAdd); diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index 8452b81a..390942d2 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -215,6 +215,9 @@ namespace Tensorflow public Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) => array_ops.ones_like(tensor, dtype: dtype, name: name, optimize: optimize); + public Tensor ones_like(NDArray nd, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + => array_ops.ones_like(nd, dtype: dtype, name: name, optimize: optimize); + public Tensor one_hot(Tensor indices, int depth, Tensor on_value = null, Tensor off_value = null, @@ -290,6 +293,9 @@ namespace Tensorflow public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) => array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize); + public Tensor zeros_like(NDArray nd, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + => array_ops.zeros_like(nd, dtype: dtype, name: name, optimize: optimize); + /// /// Stops gradient computation. /// diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 2d91be12..e27a5e3c 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -118,6 +118,9 @@ namespace Tensorflow public Tensor cos(Tensor x, string name = null) => gen_math_ops.cos(x, name); + public Tensor cos(float x, string name = null) + => gen_math_ops.cos(x, name); + /// /// Computes hyperbolic cosine of x element-wise. /// diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 62ba0bbd..535bbca4 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -137,6 +137,8 @@ namespace Tensorflow { switch (a) { + case Tensors arr: + return arr.Length; case Array arr: return arr.Length; case IList arr: diff --git a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs index 7db178b3..b076c90f 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs @@ -28,6 +28,7 @@ namespace Tensorflow.Contexts /// public sealed partial class Context { + // [DebuggerStepThrough] public T RunInAutoMode(Func graphAction, Func eagerAction, params object[] args) { if (tf.Context.has_graph_arg(args)) diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index b77c0f70..625d76a1 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -400,9 +400,22 @@ namespace Tensorflow public static Tensor reshape(Tensor tensor, object[] shape, string name = null) => gen_array_ops.reshape(tensor, shape, name: name); + private static Tensor ones_like_impl(T tensor, TF_DataType dtype, string name, bool optimize = true) + { + return tf_with(ops.name_scope(name, "ones_like", new { tensor }), scope => + { + name = scope; + var tensor1 = ops.convert_to_tensor(tensor, name: "tensor"); + var ones_shape = shape_internal(tensor1, optimize: optimize); + if (dtype == TF_DataType.DtInvalid) + dtype = tensor1.dtype; + var ret = ones(ones_shape, dtype: dtype, name: name); + return ret; + }); + } + public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) { - dtype = dtype.as_base_dtype(); return tf_with(ops.name_scope(name, "ones", new { shape }), scope => { name = scope; @@ -585,11 +598,10 @@ namespace Tensorflow if (!tf.Context.executing_eagerly()) { - var input_tensor = ops.convert_to_tensor(input); - var input_shape = input_tensor.TensorShape; - if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) + var input_shape = input.TensorShape; + if (optimize && input.NDims > -1 && input_shape.is_fully_defined()) { - var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype()); + var nd = np.array(input.shape).astype(out_type.as_numpy_dtype()); return constant_op.constant(nd, name: name); } } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 3d64e8b9..5d585e77 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -124,6 +124,9 @@ namespace Tensorflow x, y).FirstOrDefault(), x, y); + public static Tensor mean(Tensor input, int axis, bool keep_dims = false, string name = null) + => mean(input, ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name); + /// /// Computes the mean of elements across dimensions of a tensor. /// Reduces `input` along the dimensions given in `axis`. Unless @@ -137,23 +140,30 @@ namespace Tensorflow /// An optional `bool`. Defaults to `False`. If true, retain reduced dimensions with length 1. /// A name for the operation (optional). /// A `Tensor`. Has the same type as `input`. - public static Tensor mean(T1 input, T2 axis, bool keep_dims = false, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + public static Tensor mean(Tensor input, Tensor axis, bool keep_dims = false, string name = null) + => tf.Context.RunInAutoMode2( + () => tf.OpDefLib._apply_op_helper("Mean", name, new + { + input, + reduction_indices = axis, + keep_dims = keep_dims + }).output, + () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "Mean", name, null, input, axis, - "keep_dims", keep_dims); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims }); - - return _op.output; - } + "keep_dims", keep_dims).FirstOrDefault(), + (op) => + { + var attrs = new object[] + { + "T", op.get_attr("T"), + "Tidx", op.get_attr("Tidx"), + "keep_dims", op.get_attr("keep_dims") + }; + tf.Runner.RecordGradient("Mean", op.inputs, attrs, op.outputs); + }, + new Tensors(input, axis)); public static Tensor mean(Tensor[] inputs, Tensor axis, bool keep_dims = false, string name = null) { @@ -376,8 +386,18 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor cos(Tensor x, string name = null) + public static Tensor cos(T x, string name = null) { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Cos", name, + null, + x); + + return results[0]; + } + var _op = tf.OpDefLib._apply_op_helper("Cos", name, args: new { x }); return _op.outputs[0]; @@ -776,20 +796,21 @@ namespace Tensorflow } public static Tensor sub(Tensor x, Tensor y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + => tf.Context.RunInAutoMode2( + () => tf.OpDefLib._apply_op_helper("Sub", name, new { x, y }).output, + () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "Sub", name, null, - x, y); - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Sub", name, args: new { x, y }); - - return _op.output; - } + x, y).FirstOrDefault(), + (op) => + { + var attrs = new object[] + { + "T", op.get_attr("T") + }; + tf.Runner.RecordGradient("Sub", op.inputs, attrs, op.outputs); + }, + new Tensors(x, y)); public static Tensor sub(Tx x, Ty y, string name = null) { diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 2c051992..391ad9d5 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -327,31 +327,17 @@ namespace Tensorflow public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) { var r = _ReductionDims(input_tensor, axis); - if (axis == null) - { - var m = gen_math_ops.mean(input_tensor, r, keepdims, name); - return _may_reduce_to_scalar(keepdims, axis, m); - } - else - { - var m = gen_math_ops.mean(input_tensor, axis, keepdims, name); - return _may_reduce_to_scalar(keepdims, axis, m); - } + var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis); + var m = gen_math_ops.mean(input_tensor, axis_tensor, keepdims, name); + return _may_reduce_to_scalar(keepdims, axis_tensor, m); } public static Tensor reduce_mean(Tensor[] input_tensors, int? axis = null, bool keepdims = false, string name = null) { - if (axis == null) - { - var r = _ReductionDims(input_tensors, axis); - var m = gen_math_ops.mean(input_tensors, r, keepdims, name); - return _may_reduce_to_scalar(keepdims, axis, m); - } - else - { - var m = gen_math_ops.mean(input_tensors, axis, keepdims, name); - return _may_reduce_to_scalar(keepdims, axis, m); - } + var r = _ReductionDims(input_tensors, axis); + var axis_tensor = axis == null ? r : ops.convert_to_tensor(axis.Value); + var m = gen_math_ops.mean(input_tensors, axis_tensor, keepdims, name); + return _may_reduce_to_scalar(keepdims, axis, m); } /// diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs index a2ad7530..e331dc1a 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs @@ -91,14 +91,16 @@ namespace Tensorflow var buffer = new byte[size][]; var data_start = c_api.TF_TensorData(_handle); - var string_start = data_start + (int)(size * sizeof(ulong)); + data_start += (int)(size * sizeof(ulong)); for (int i = 0; i < buffer.Length; i++) { - var len = *(byte*)string_start; - buffer[i] = new byte[len]; - string_start += 1; - Marshal.Copy(string_start, buffer[i], 0, len); - string_start += len; + IntPtr dst = IntPtr.Zero; + ulong dstLen = 0; + var read = c_api.TF_StringDecode((byte*)data_start, bytesize, (byte**)&dst, ref dstLen, tf.Status.Handle); + tf.Status.Check(true); + buffer[i] = new byte[(int)dstLen]; + Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); + data_start += (int)read; } return buffer; diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index 1c8d939a..3c334ea5 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -69,13 +69,14 @@ namespace Tensorflow => items.Insert(index, tensor); IEnumerator IEnumerable.GetEnumerator() - { - throw new NotImplementedException(); - } + => GetEnumerator(); public static implicit operator Tensors(Tensor tensor) => new Tensors(tensor); + public static implicit operator Tensors((Tensor, Tensor) tuple) + => new Tensors(tuple.Item1, tuple.Item2); + public static implicit operator Tensors(NDArray nd) => new Tensors(nd); diff --git a/src/TensorFlowNET.Keras/Engine/Interfaces/ITensorFlowOpLayer.cs b/src/TensorFlowNET.Keras/Engine/Interfaces/ITensorFlowOpLayer.cs deleted file mode 100644 index a1d3ecbf..00000000 --- a/src/TensorFlowNET.Keras/Engine/Interfaces/ITensorFlowOpLayer.cs +++ /dev/null @@ -1,12 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; -using Tensorflow.Keras.ArgsDefinition; - -namespace Tensorflow.Keras.Engine -{ - public interface ITensorFlowOpLayer - { - Layer GetOpLayer(TensorFlowOpLayerArgs args); - } -} diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 03125e03..a2a29770 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -142,6 +142,7 @@ namespace Tensorflow.Keras.Layers public Dense Dense(int units, Activation activation = null, IInitializer kernel_initializer = null, + bool use_bias = true, IInitializer bias_initializer = null, TensorShape input_shape = null) => new Dense(new DenseArgs @@ -149,7 +150,7 @@ namespace Tensorflow.Keras.Layers Units = units, Activation = activation ?? keras.activations.Linear, KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, - BiasInitializer = bias_initializer ?? tf.zeros_initializer, + BiasInitializer = bias_initializer ?? (use_bias ? tf.zeros_initializer : null), InputShape = input_shape }); @@ -375,6 +376,9 @@ namespace Tensorflow.Keras.Layers public Add Add() => new Add(new MergeArgs { }); + public Subtract Subtract() + => new Subtract(new MergeArgs { }); + public GlobalAveragePooling2D GlobalAveragePooling2D() => new GlobalAveragePooling2D(new Pooling2DArgs { }); diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Subtract.cs b/src/TensorFlowNET.Keras/Layers/Merging/Subtract.cs new file mode 100644 index 00000000..b6a1039e --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Merging/Subtract.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + public class Subtract : Merge + { + public Subtract(MergeArgs args) : base(args) + { + + } + + protected override Tensors _merge_function(Tensors inputs) + { + if (len(inputs) != 2) + throw new ValueError($"A `Subtract` layer should be called on exactly 2 inputs"); + return inputs[0] - inputs[1]; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs new file mode 100644 index 00000000..1c0470fe --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs @@ -0,0 +1,73 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using Tensorflow.Graphs; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + public class TensorFlowOpLayer : Layer + { + TensorFlowOpLayerArgs args; + Dictionary constants => args.Constants; + NodeDef node_def => args.NodeDef; + static string TF_OP_LAYER_NAME_PREFIX = "tf_op_layer_"; + public string OpType => node_def.Op; + + public TensorFlowOpLayer(TensorFlowOpLayerArgs args) + : base(new LayerArgs + { + Name = TF_OP_LAYER_NAME_PREFIX + args.Name, + Trainable = args.Trainable, + DType = args.DType, + Autocast = false + }) + { + this.args = args; + built = true; + } + + protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) + { + if (tf.Context.executing_eagerly()) + return _defun_call(inputs); + return MakOp(inputs); + } + + [AutoGraph] + Tensors _defun_call(Tensors inputs) + => MakOp(inputs); + + Tensors MakOp(Tensors inputs) + { + var graph = inputs.graph; + graph.as_default(); + foreach (var (index, constant) in enumerate(constants)) + { + var value = constant_op.constant(constant, name: node_def.Input[index]); + inputs.Insert(index, value); + } + + var (c_op, _) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]); + var op = graph._create_op_from_tf_operation(c_op); + op._control_flow_post_processing(); + + // Record the gradient because custom-made ops don't go through the + // code-gen'd eager call path + var op_type = op.node_def.Op; + + tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); + + graph.Exit(); + return op.outputs; + } + + public Layer GetOpLayer(TensorFlowOpLayerArgs args) + => new TensorFlowOpLayer(args); + } +} diff --git a/src/TensorFlowNET.Keras/Losses/Huber.cs b/src/TensorFlowNET.Keras/Losses/Huber.cs index 6098dee3..a256786f 100644 --- a/src/TensorFlowNET.Keras/Losses/Huber.cs +++ b/src/TensorFlowNET.Keras/Losses/Huber.cs @@ -27,10 +27,10 @@ namespace Tensorflow.Keras.Losses Tensor error = math_ops.subtract(y_pred_cast, y_true_cast); Tensor abs_error = math_ops.abs(error); Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype); - return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta, - half * math_ops.pow(error, 2), + return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta, + half * math_ops.pow(error, 2), half * math_ops.pow(delta, 2) + delta * (abs_error - delta)), - axis : -1); + axis: -1); } } } diff --git a/src/TensorFlowNET.Keras/Losses/LogCosh.cs b/src/TensorFlowNET.Keras/Losses/LogCosh.cs index 1c894904..8acbbe9d 100644 --- a/src/TensorFlowNET.Keras/Losses/LogCosh.cs +++ b/src/TensorFlowNET.Keras/Losses/LogCosh.cs @@ -19,10 +19,8 @@ namespace Tensorflow.Keras.Losses Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); Tensor x = y_pred_dispatch - y_true_cast; - - return gen_math_ops.mean(x + gen_math_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype),axis: -1); - + return gen_math_ops.mean(x + gen_math_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype), axis: -1); } } } diff --git a/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs b/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs index 74c95b4a..3295b12b 100644 --- a/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs +++ b/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs @@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Losses Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); Tensor diff = math_ops.abs(y_true_cast - y_pred_dispatch) / gen_math_ops.maximum(math_ops.abs(y_true_cast), gen_math_ops.cast(tf.constant(1e-7), y_pred_dispatch.dtype)); - return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) *gen_math_ops.mean(diff, axis: -1); + return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) * gen_math_ops.mean(diff, axis: -1); } } } diff --git a/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs b/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs index 24ef1043..6ae7d86d 100644 --- a/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs +++ b/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs @@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Losses { Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); - return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), axis: -1); + return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), axis: -1); } } } diff --git a/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs b/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs index 7ad370ae..2383c5d1 100644 --- a/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs +++ b/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs @@ -26,6 +26,9 @@ namespace Tensorflow.Keras.Optimizers protected float _initial_decay = 0.0f; protected bool _use_locking = true; + public IVariableV1 lr + => _hyper_variables["learning_rate"]; + Dictionary> _slots; List _slot_names; diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index e705b3d1..3f5ca2b9 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -21,7 +21,9 @@ * Support BatchNormalization layer. * Building keras model in subclass, functional and sequential api * Implemented backward_function. -* Support model.load_weights. +* Support model.load_weights. +* Add Subtract layer +* Support YOLOv3 model. Keras for .NET Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent & simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear & actionable error messages. @@ -64,4 +66,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac + + + + diff --git a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs index 32a1737a..39c14fa8 100644 --- a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs @@ -21,6 +21,7 @@ using System.Linq; using System.Reflection; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -150,12 +151,13 @@ namespace Tensorflow.Keras.Utils // recursively CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); - var op_layer = GetLayer(new TensorFlowOpLayerArgs + var opLayerArgs = new TensorFlowOpLayerArgs { NodeDef = op.node_def, Constants = constants, Name = op.name - }); + }; + var op_layer = new TensorFlowOpLayer(opLayerArgs); created_layers.Add(op_layer); op_layer.SetConnectivityMetadata(layer_inputs, op.outputs); processed_ops.Add(op); @@ -163,20 +165,6 @@ namespace Tensorflow.Keras.Utils } } - static Layer GetLayer(LayerArgs args) - { - Layer layer = default; - var assemble = Assembly.Load("TensorFlow.Keras.Layers"); - foreach (var type in assemble.GetTypes().Where(x => x.GetInterface(typeof(T).Name) != null)) - { - layer = (Layer)Activator.CreateInstance(type, new object[] { args }); - } - - if (layer == null) - throw new NotImplementedException($"Can't find implementation for type {args.GetType().Name}"); - return layer; - } - // recusive static bool uses_keras_history(Tensor op_input) { diff --git a/tensorflowlib/README.md b/tensorflowlib/README.md index 20d30f6f..a08959a7 100644 --- a/tensorflowlib/README.md +++ b/tensorflowlib/README.md @@ -56,7 +56,7 @@ Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\ 1. Build static library -`bazel build --config=opt //tensorflow:tensorflow` +`bazel build --output_base=C:/tmp/tfcompilation build --config=opt //tensorflow:tensorflow` 2. Build pip package diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs index f7e6155c..62d9fa5c 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs @@ -1,6 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using NumSharp; using Tensorflow; +using static Tensorflow.Binding; using static Tensorflow.KerasApi; namespace TensorFlowNET.Keras.UnitTest @@ -39,8 +40,8 @@ namespace TensorFlowNET.Keras.UnitTest /// /// Custom layer test, used in Dueling DQN /// - [TestMethod, Ignore] - public void FunctionalTest() + [TestMethod] + public void TensorFlowOpLayer() { var layers = keras.layers; var inputs = layers.Input(shape: 24); @@ -48,58 +49,15 @@ namespace TensorFlowNET.Keras.UnitTest var value = layers.Dense(24).Apply(x); var adv = layers.Dense(1).Apply(x); - var adv_out = adv - Binding.tf.reduce_mean(adv, axis: 1, keepdims: true); // Here's problem. - var outputs = layers.Add().Apply(new Tensors(adv_out, value)); + var mean = adv - tf.reduce_mean(adv, axis: 1, keepdims: true); + adv = layers.Subtract().Apply((adv, mean)); + var outputs = layers.Add().Apply((value, adv)); var model = keras.Model(inputs, outputs); - model.summary(); model.compile(optimizer: keras.optimizers.RMSprop(0.001f), loss: keras.losses.MeanSquaredError(), metrics: new[] { "acc" }); - // Here we consider the adv_out is one layer, which is a little different from py's version - Assert.AreEqual(model.Layers.Count, 6); - - // py code: - //from tensorflow.keras.layers import Input, Dense, Add, Subtract, Lambda - //from tensorflow.keras.models import Model - //from tensorflow.keras.optimizers import RMSprop - //import tensorflow.keras.backend as K - - //inputs = Input(24) - //x = Dense(128, activation = "relu")(inputs) - //value = Dense(24)(x) - //adv = Dense(1)(x) - //meam = Lambda(lambda x: K.mean(x, axis = 1, keepdims = True))(adv) - //adv = Subtract()([adv, meam]) - //outputs = Add()([value, adv]) - //model = Model(inputs, outputs) - //model.compile(loss = "mse", optimizer = RMSprop(1e-3)) - //model.summary() - - //py output: - //Model: "functional_3" - //__________________________________________________________________________________________________ - //Layer(type) Output Shape Param # Connected to - //================================================================================================== - //input_2 (InputLayer) [(None, 24)] 0 - //__________________________________________________________________________________________________ - //dense_3 (Dense) (None, 128) 3200 input_2[0][0] - //__________________________________________________________________________________________________ - //dense_5 (Dense) (None, 1) 129 dense_3[0][0] - //__________________________________________________________________________________________________ - //lambda_1 (Lambda) (None, 1) 0 dense_5[0][0] - //__________________________________________________________________________________________________ - //dense_4 (Dense) (None, 24) 3096 dense_3[0][0] - //__________________________________________________________________________________________________ - //subtract_1 (Subtract) (None, 1) 0 dense_5[0][0] - // lambda_1[0][0] - //__________________________________________________________________________________________________ - //add_1 (Add) (None, 24) 0 dense_4[0][0] - // subtract_1[0][0] - //================================================================================================== - //Total params: 6,425 - //Trainable params: 6,425 - //Non-trainable params: 0 - //__________________________________________________________________________________________________ + model.summary(); + Assert.AreEqual(model.Layers.Count, 8); } /// diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs index 1a7bc633..c57c98df 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs @@ -136,23 +136,21 @@ namespace TensorFlowNET.UnitTest.ManagedAPI public void TestOnesLike() { #region 2-dimension - var testCase2D = tf.constant(new int[,] + var ones2D = tf.ones_like(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }); - var ones2D = tf.ones_like(testCase2D); Assert.AreEqual(new[] { 1, 1, 1 }, ones2D[0].numpy()); Assert.AreEqual(new[] { 1, 1, 1 }, ones2D[1].numpy()); #endregion #region 1-dimension - var testCase1D = tf.constant(new int[,] + var ones1D = tf.ones_like(new int[,] { { 1, 2, 3 } }); - var ones1D = tf.ones_like(testCase1D); Assert.AreEqual(new[] { 1, 1, 1 }, ones1D[0].numpy()); #endregion @@ -162,23 +160,21 @@ namespace TensorFlowNET.UnitTest.ManagedAPI public void TestZerosLike() { #region 2-dimension - var testCase2D = tf.constant(new int[,] + var zeros2D = tf.zeros_like(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }); - var zeros2D = tf.zeros_like(testCase2D); Assert.AreEqual(new[] { 0, 0, 0 }, zeros2D[0].numpy()); Assert.AreEqual(new[] { 0, 0, 0 }, zeros2D[1].numpy()); #endregion #region 1-dimension - var testCase1D = tf.constant(new int[,] + var zeros1D = tf.zeros_like(new int[,] { { 1, 2, 3 } }); - var zeros1D = tf.zeros_like(testCase1D); Assert.AreEqual(new[] { 0, 0, 0 }, zeros1D[0].numpy()); #endregion diff --git a/test/Tensorflow.Keras.UnitTest/OptimizerTest.cs b/test/Tensorflow.Keras.UnitTest/OptimizerTest.cs deleted file mode 100644 index 6647ca59..00000000 --- a/test/Tensorflow.Keras.UnitTest/OptimizerTest.cs +++ /dev/null @@ -1,11 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using System.Collections.Generic; - -namespace Tensorflow.Keras.UnitTest -{ - [TestClass] - public class OptimizerTest - { - - } -} diff --git a/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj b/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj deleted file mode 100644 index 5f5ab347..00000000 --- a/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj +++ /dev/null @@ -1,25 +0,0 @@ - - - - netcoreapp3.1 - - false - - AnyCPU;x64 - - - - - - - - all - runtime; build; native; contentfiles; analyzers; buildtransitive - - - - - - - -