| @@ -337,3 +337,5 @@ test/TensorFlowNET.Examples/mnist | |||||
| # training model resources | # training model resources | ||||
| .resources | .resources | ||||
| /redist | /redist | ||||
| *.xml | |||||
| *.xsd | |||||
| @@ -18,7 +18,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class tensorflow | public partial class tensorflow | ||||
| { | { | ||||
| public Tensor reshape<T>(T tensor, | |||||
| public Tensor reshape(Tensor tensor, | |||||
| TensorShape shape, | TensorShape shape, | ||||
| string name = null) => gen_array_ops.reshape(tensor, shape, name); | string name = null) => gen_array_ops.reshape(tensor, shape, name); | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Linq; | using System.Linq; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -42,8 +43,8 @@ namespace Tensorflow.Framework | |||||
| var values_shape = values.TensorShape.with_rank(1); | var values_shape = values.TensorShape.with_rank(1); | ||||
| var dense_shape_shape = dense_shape.TensorShape.with_rank(1); | var dense_shape_shape = dense_shape.TensorShape.with_rank(1); | ||||
| indices_shape[0].merge_with(values_shape.dims[0]); | |||||
| indices_shape[1].merge_with(dense_shape_shape.dims[0]); | |||||
| indices_shape["0"].merge_with(values_shape[0]); | |||||
| indices_shape["1"].merge_with(dense_shape_shape[0]); | |||||
| _shape = new TensorShape(_dense_shape.Select(x => Convert.ToInt32(x)).ToArray()); | _shape = new TensorShape(_dense_shape.Select(x => Convert.ToInt32(x)).ToArray()); | ||||
| } | } | ||||
| @@ -6,5 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class ModelArgs : LayerArgs | public class ModelArgs : LayerArgs | ||||
| { | { | ||||
| public Tensor[] Inputs { get; set; } | |||||
| public Tensor[] Outputs { get; set; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -12,6 +12,6 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
| public int[] NodeIndices { get; set; } | public int[] NodeIndices { get; set; } | ||||
| public int[] TensorIndices { get; set; } | public int[] TensorIndices { get; set; } | ||||
| public Tensor InputTensors { get; set; } | public Tensor InputTensors { get; set; } | ||||
| public Tensor Outputs { get; set; } | |||||
| public Tensors Outputs { get; set; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine | |||||
| _channels_first = args.DataFormat == "channels_first"; | _channels_first = args.DataFormat == "channels_first"; | ||||
| } | } | ||||
| protected override Tensor call(Tensor inputs, bool is_training = false) | |||||
| protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| if (_channels_first) | if (_channels_first) | ||||
| { | { | ||||
| @@ -0,0 +1,29 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Engine | |||||
| { | |||||
| /// <summary> | |||||
| /// Tracks the Layer call that created a Tensor, for Keras Graph Networks. | |||||
| /// </summary> | |||||
| public class KerasHistory | |||||
| { | |||||
| Layer layer; | |||||
| int node_index; | |||||
| int tensor_index; | |||||
| public KerasHistory(Layer layer, int node_index, int tensor_index) | |||||
| { | |||||
| this.layer = layer; | |||||
| this.node_index = node_index; | |||||
| this.tensor_index = tensor_index; | |||||
| } | |||||
| public static implicit operator Layer(KerasHistory history) | |||||
| => history.layer; | |||||
| public static implicit operator (Layer, int, int)(KerasHistory history) | |||||
| => (history.layer, history.node_index, history.tensor_index); | |||||
| } | |||||
| } | |||||
| @@ -119,11 +119,12 @@ namespace Tensorflow.Keras.Engine | |||||
| /// Wraps `call`, applying pre- and post-processing steps. | /// Wraps `call`, applying pre- and post-processing steps. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="input"></param> | /// <param name="input"></param> | ||||
| /// <param name="state"></param> | |||||
| /// <param name="is_training"></param> | /// <param name="is_training"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public Tensor Apply(Tensor inputs, bool is_training = false) | |||||
| public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| Tensor outputs = null; | |||||
| Tensors outputs = null; | |||||
| callContext = callContext ?? new ThreadLocal<CallContext>() | callContext = callContext ?? new ThreadLocal<CallContext>() | ||||
| { | { | ||||
| @@ -148,7 +149,7 @@ namespace Tensorflow.Keras.Engine | |||||
| if (!built) | if (!built) | ||||
| MaybeBuild(inputs); | MaybeBuild(inputs); | ||||
| outputs = call(inputs, is_training: is_training); | |||||
| outputs = call(inputs, state: state, is_training: is_training); | |||||
| outputs = _set_connectivity_metadata_(inputs, outputs); | outputs = _set_connectivity_metadata_(inputs, outputs); | ||||
| _handle_activity_regularization(inputs, outputs); | _handle_activity_regularization(inputs, outputs); | ||||
| @@ -161,36 +162,7 @@ namespace Tensorflow.Keras.Engine | |||||
| return outputs; | return outputs; | ||||
| } | } | ||||
| public Tensor[] Apply(Tensor[] inputs, Tensor state, bool is_training = false) | |||||
| { | |||||
| Tensor[] outputs = null; | |||||
| callContext = callContext ?? new ThreadLocal<CallContext>() | |||||
| { | |||||
| Value = new CallContext() | |||||
| }; | |||||
| var eager = tf.executing_eagerly(); | |||||
| using var ctxManager = CallContext.enter(); | |||||
| string nameScope = ""; | |||||
| if (eager) | |||||
| nameScope = name; | |||||
| else | |||||
| nameScope = _name_scope(); | |||||
| tf_with(ops.name_scope(nameScope), scope => | |||||
| { | |||||
| if (!built) | |||||
| MaybeBuild(inputs[0]); | |||||
| outputs = call(inputs, is_training: is_training, state: state); | |||||
| }); | |||||
| return outputs; | |||||
| } | |||||
| private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs) | |||||
| private Tensors _set_connectivity_metadata_(Tensors inputs, Tensors outputs) | |||||
| { | { | ||||
| /*var returnOutputs = new List<Tensor>(); | /*var returnOutputs = new List<Tensor>(); | ||||
| foreach(var x in outputs) | foreach(var x in outputs) | ||||
| @@ -211,7 +183,7 @@ namespace Tensorflow.Keras.Engine | |||||
| return outputs; | return outputs; | ||||
| } | } | ||||
| private void _handle_activity_regularization(Tensor inputs, Tensor outputs) | |||||
| private void _handle_activity_regularization(Tensors inputs, Tensors outputs) | |||||
| { | { | ||||
| //if(_activity_regularizer != null) | //if(_activity_regularizer != null) | ||||
| { | { | ||||
| @@ -219,7 +191,7 @@ namespace Tensorflow.Keras.Engine | |||||
| } | } | ||||
| } | } | ||||
| private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask) | |||||
| private void _set_mask_metadata(Tensors inputs, Tensors outputs, Tensors previous_mask) | |||||
| { | { | ||||
| } | } | ||||
| @@ -229,12 +201,7 @@ namespace Tensorflow.Keras.Engine | |||||
| return null; | return null; | ||||
| } | } | ||||
| protected virtual Tensor call(Tensor inputs, bool is_training = false) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| protected virtual Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false) | |||||
| protected virtual Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| @@ -244,7 +211,7 @@ namespace Tensorflow.Keras.Engine | |||||
| return Name; | return Name; | ||||
| } | } | ||||
| protected void MaybeBuild(Tensor inputs) | |||||
| protected void MaybeBuild(Tensors inputs) | |||||
| { | { | ||||
| // Check input assumptions set before layer building, e.g. input rank. | // Check input assumptions set before layer building, e.g. input rank. | ||||
| if (built) | if (built) | ||||
| @@ -252,7 +219,7 @@ namespace Tensorflow.Keras.Engine | |||||
| if (DType == TF_DataType.DtInvalid) | if (DType == TF_DataType.DtInvalid) | ||||
| args.DType = inputs.dtype; | args.DType = inputs.dtype; | ||||
| var input_shapes = inputs.TensorShape; | |||||
| var input_shapes = inputs.shape; | |||||
| build(input_shapes); | build(input_shapes); | ||||
| built = true; | built = true; | ||||
| } | } | ||||
| @@ -27,7 +27,11 @@ namespace Tensorflow.Keras.Engine | |||||
| public Model(ModelArgs args) | public Model(ModelArgs args) | ||||
| : base(args) | : base(args) | ||||
| { | { | ||||
| // Build _output_layers | |||||
| /*foreach(var x in args.Outputs) | |||||
| { | |||||
| var layer = x.KerasHistory; | |||||
| }*/ | |||||
| } | } | ||||
| public void compile(string optimizerName, string lossName) | public void compile(string optimizerName, string lossName) | ||||
| @@ -35,8 +35,8 @@ namespace Tensorflow.Keras.Engine | |||||
| public int[] node_indices; | public int[] node_indices; | ||||
| public int[] tensor_indices; | public int[] tensor_indices; | ||||
| public Tensor input_tensors; | |||||
| public Tensor Outputs => args.Outputs; | |||||
| public Tensors input_tensors; | |||||
| public Tensors Outputs => args.Outputs; | |||||
| public TensorShape[] input_shapes; | public TensorShape[] input_shapes; | ||||
| public TensorShape[] output_shapes; | public TensorShape[] output_shapes; | ||||
| List<Layer> kerasInputs; | List<Layer> kerasInputs; | ||||
| @@ -57,7 +57,8 @@ namespace Tensorflow.Keras.Engine | |||||
| // Set metadata on outputs. | // Set metadata on outputs. | ||||
| var node_index = layer.InboundNodes.Count - 1; | var node_index = layer.InboundNodes.Count - 1; | ||||
| args.Outputs.KerasHistory.Add(layer); | |||||
| foreach (var (i, tensor) in enumerate(Outputs)) | |||||
| tensor.KerasHistory = new KerasHistory(layer, node_index, i); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -60,7 +60,7 @@ namespace Tensorflow.Keras.Engine | |||||
| public void add(Tensor tensor) | public void add(Tensor tensor) | ||||
| { | { | ||||
| var layer = tensor.KerasHistory[0]; | |||||
| Layer layer = tensor.KerasHistory; | |||||
| add(layer); | add(layer); | ||||
| } | } | ||||
| @@ -129,7 +129,7 @@ namespace Tensorflow.Keras.Engine | |||||
| void _map_graph_network(Tensor inputs, Tensor outputs) | void _map_graph_network(Tensor inputs, Tensor outputs) | ||||
| { | { | ||||
| layers.add(outputs.KerasHistory[0]); | |||||
| layers.add(outputs.KerasHistory); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -30,6 +30,19 @@ namespace Tensorflow | |||||
| Name = name | Name = name | ||||
| }); | }); | ||||
| /// <summary> | |||||
| /// `Model` groups layers into an object with training and inference features. | |||||
| /// </summary> | |||||
| /// <param name="input"></param> | |||||
| /// <param name="output"></param> | |||||
| /// <returns></returns> | |||||
| public Model Model(Tensor input, Tensor output) | |||||
| => new Model(new ModelArgs | |||||
| { | |||||
| Inputs = new[] { input }, | |||||
| Outputs = new[] { output } | |||||
| }); | |||||
| /// <summary> | /// <summary> | ||||
| /// Instantiate a Keras tensor. | /// Instantiate a Keras tensor. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -143,7 +143,7 @@ namespace Tensorflow.Keras.Layers | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| protected override Tensor call(Tensor inputs, bool is_training = false) | |||||
| protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| Tensor outputs = null; | Tensor outputs = null; | ||||
| @@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Layers | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| protected override Tensor call(Tensor inputs, bool training = false) | |||||
| protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false) | |||||
| { | { | ||||
| var outputs = _convolution_op.__call__(inputs, kernel); | var outputs = _convolution_op.__call__(inputs, kernel); | ||||
| if (use_bias) | if (use_bias) | ||||
| @@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| protected override Tensor call(Tensor inputs, bool training = false) | |||||
| protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false) | |||||
| { | { | ||||
| Tensor outputs = null; | Tensor outputs = null; | ||||
| var rank = inputs.rank; | var rank = inputs.rank; | ||||
| @@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers | |||||
| this.args = args; | this.args = args; | ||||
| } | } | ||||
| protected override Tensor call(Tensor inputs, bool is_training = false) | |||||
| protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| var output = tf_utils.smart_cond(is_training, | var output = tf_utils.smart_cond(is_training, | ||||
| () => tf.nn.dropout(inputs, | () => tf.nn.dropout(inputs, | ||||
| @@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| protected override Tensor call(Tensor inputs, bool is_training = false) | |||||
| protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| var dtype = inputs.dtype; | var dtype = inputs.dtype; | ||||
| if (dtype != tf.int32 && dtype != tf.int64) | if (dtype != tf.int32 && dtype != tf.int64) | ||||
| @@ -29,9 +29,9 @@ namespace Tensorflow.Keras.Layers | |||||
| .ToArray(); | .ToArray(); | ||||
| } | } | ||||
| protected override Tensor call(Tensor inputs, bool is_training = false) | |||||
| protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| return base.call(inputs, is_training); | |||||
| return base.call(inputs, state: state, is_training: is_training); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers | |||||
| input_spec = new InputSpec(ndim: 4); | input_spec = new InputSpec(ndim: 4); | ||||
| } | } | ||||
| protected override Tensor call(Tensor inputs, bool is_training = false) | |||||
| protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| int[] pool_shape; | int[] pool_shape; | ||||
| int[] strides; | int[] strides; | ||||
| @@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Layers | |||||
| this.args = args; | this.args = args; | ||||
| } | } | ||||
| protected override Tensor call(Tensor inputs, bool is_training = false) | |||||
| protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| scale = math_ops.cast(args.Scale, args.DType); | scale = math_ops.cast(args.Scale, args.DType); | ||||
| offset = math_ops.cast(args.Offset, args.DType); | offset = math_ops.cast(args.Offset, args.DType); | ||||
| @@ -61,44 +61,7 @@ namespace Tensorflow.Layers | |||||
| return (results[0], results[1]); | return (results[0], results[1]); | ||||
| } | } | ||||
| public Tensor __call__(Tensor inputs, | |||||
| Tensor training = null, | |||||
| VariableScope scope = null) | |||||
| { | |||||
| _set_scope(scope); | |||||
| _graph = ops._get_graph_from_inputs(new Tensor[] { inputs }, graph: _graph); | |||||
| variable_scope scope_context_manager = null; | |||||
| if (built) | |||||
| { | |||||
| scope_context_manager = tf.variable_scope(_scope, | |||||
| reuse: true, | |||||
| auxiliary_name_scope: false); | |||||
| } | |||||
| else | |||||
| { | |||||
| scope_context_manager = tf.variable_scope(_scope, | |||||
| reuse: _reuse, | |||||
| auxiliary_name_scope: false); | |||||
| } | |||||
| Tensor outputs = null; | |||||
| tf_with(scope_context_manager, scope2 => | |||||
| { | |||||
| _current_scope = scope2; | |||||
| // Actually call layer | |||||
| outputs = base.Apply(inputs[0], | |||||
| is_training: training == null ? false : false); | |||||
| }); | |||||
| // Update global default collections. | |||||
| _add_elements_to_collection(updates.ToArray(), new string[] { tf.GraphKeys.UPDATE_OPS }); | |||||
| return outputs; | |||||
| } | |||||
| public Tensor[] __call__(Tensor[] inputs, | |||||
| public Tensors __call__(Tensors inputs, | |||||
| Tensor state = null, | Tensor state = null, | ||||
| Tensor training = null, | Tensor training = null, | ||||
| VariableScope scope = null) | VariableScope scope = null) | ||||
| @@ -120,13 +83,13 @@ namespace Tensorflow.Layers | |||||
| auxiliary_name_scope: false); | auxiliary_name_scope: false); | ||||
| } | } | ||||
| Tensor[] outputs = null; | |||||
| Tensors outputs = null; | |||||
| tf_with(scope_context_manager, scope2 => | tf_with(scope_context_manager, scope2 => | ||||
| { | { | ||||
| _current_scope = scope2; | _current_scope = scope2; | ||||
| // Actually call layer | // Actually call layer | ||||
| outputs = base.Apply(inputs, | outputs = base.Apply(inputs, | ||||
| state, | |||||
| state: state, | |||||
| is_training: training == null ? false : false); | is_training: training == null ? false : false); | ||||
| }); | }); | ||||
| @@ -74,7 +74,7 @@ namespace Tensorflow | |||||
| /// <param name="training"></param> | /// <param name="training"></param> | ||||
| /// <param name="state"></param> | /// <param name="state"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| protected override Tensor call(Tensor inputs, bool is_training = false) | |||||
| protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| var one = constant_op.constant(1, dtype: dtypes.int32); | var one = constant_op.constant(1, dtype: dtypes.int32); | ||||
| // Parameters of gates are concatenated into one multiply for efficiency. | // Parameters of gates are concatenated into one multiply for efficiency. | ||||
| @@ -87,7 +87,7 @@ namespace Tensorflow | |||||
| // array_ops.split(value: state, num_or_size_splits: 2, axis: one); | // array_ops.split(value: state, num_or_size_splits: 2, axis: one); | ||||
| throw new NotImplementedException("BasicLstmCell call"); | throw new NotImplementedException("BasicLstmCell call"); | ||||
| } | } | ||||
| var gate_inputs = math_ops.matmul(array_ops.concat(new[] { inputs, h }, 1), _kernel.AsTensor()); | |||||
| var gate_inputs = math_ops.matmul(array_ops.concat(new[] { (Tensor)inputs, h }, 1), _kernel.AsTensor()); | |||||
| gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor()); | gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor()); | ||||
| // i = input_gate, j = new_input, f = forget_gate, o = output_gate | // i = input_gate, j = new_input, f = forget_gate, o = output_gate | ||||
| @@ -67,14 +67,14 @@ namespace Tensorflow | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| protected override Tensor[] call(Tensor[] inputs, Tensor state, bool is_training = false) | |||||
| protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| // Most basic RNN: output = new_state = act(W * input + U * state + B). | // Most basic RNN: output = new_state = act(W * input + U * state + B). | ||||
| var concat = array_ops.concat(new[] { inputs[0], state }, 1); | |||||
| var concat = array_ops.concat(new Tensor[] { inputs, state }, 1); | |||||
| var gate_inputs = math_ops.matmul(concat, _kernel.AsTensor()); | var gate_inputs = math_ops.matmul(concat, _kernel.AsTensor()); | ||||
| gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor()); | gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor()); | ||||
| var output = _activation(gate_inputs, null); | var output = _activation(gate_inputs, null); | ||||
| return new[] { output, output }; | |||||
| return new Tensors(output, output); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -127,7 +127,7 @@ namespace Tensorflow.Operations | |||||
| { | { | ||||
| input_shape = flat_input.TensorShape.with_rank_at_least(2); | input_shape = flat_input.TensorShape.with_rank_at_least(2); | ||||
| batch_size = tensor_shape.dimension_at_index(input_shape, 0); | batch_size = tensor_shape.dimension_at_index(input_shape, 0); | ||||
| var input_size = input_shape[1]; | |||||
| var input_size = input_shape[new Slice(1)]; | |||||
| fixed_batch_size.merge_with(batch_size); | fixed_batch_size.merge_with(batch_size); | ||||
| foreach (var (i, size) in enumerate(input_size.dims)) | foreach (var (i, size) in enumerate(input_size.dims)) | ||||
| { | { | ||||
| @@ -364,7 +364,7 @@ namespace Tensorflow.Operations | |||||
| if (sequence_length != null) | if (sequence_length != null) | ||||
| throw new NotImplementedException("sequence_length != null"); | throw new NotImplementedException("sequence_length != null"); | ||||
| else | else | ||||
| outputs = cell.__call__(new[] { input_t_t }, state: state1); | |||||
| outputs = cell.__call__(input_t_t, state: state1); | |||||
| var (output, new_state) = (outputs[0], outputs[1]); | var (output, new_state) = (outputs[0], outputs[1]); | ||||
| // Keras cells always wrap state as list, even if it's a single tensor. | // Keras cells always wrap state as list, even if it's a single tensor. | ||||
| @@ -157,7 +157,7 @@ namespace Tensorflow | |||||
| leading_size, | leading_size, | ||||
| shape(tensor_tensor)[$"{axis + ndims_mask}:"] | shape(tensor_tensor)[$"{axis + ndims_mask}:"] | ||||
| }, 0); | }, 0); | ||||
| tensor_tensor = reshape(tensor, shape1); | |||||
| tensor_tensor = reshape(tensor_tensor, shape1); | |||||
| var first_dim = shape_tensor.dims.Skip(axis).Take(ndims_mask).First(); | var first_dim = shape_tensor.dims.Skip(axis).Take(ndims_mask).First(); | ||||
| var s1 = tensor_shape.as_shape(shape_tensor.dims.Take(axis).ToArray()); | var s1 = tensor_shape.as_shape(shape_tensor.dims.Take(axis).ToArray()); | ||||
| var s2 = s1.concatenate(new[] { first_dim }).concatenate(shape_tensor.dims.Skip(axis + ndims_mask).ToArray()); | var s2 = s1.concatenate(new[] { first_dim }).concatenate(shape_tensor.dims.Skip(axis + ndims_mask).ToArray()); | ||||
| @@ -353,7 +353,7 @@ namespace Tensorflow | |||||
| public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | ||||
| => ones_like_impl(tensor, dtype, name, optimize); | => ones_like_impl(tensor, dtype, name, optimize); | ||||
| public static Tensor reshape<T1, T2>(T1 tensor, T2 shape, string name = null) | |||||
| public static Tensor reshape<T2>(Tensor tensor, T2 shape, string name = null) | |||||
| => gen_array_ops.reshape(tensor, shape, null); | => gen_array_ops.reshape(tensor, shape, null); | ||||
| private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) | private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) | ||||
| @@ -292,7 +292,7 @@ namespace Tensorflow | |||||
| return _op.output; | return _op.output; | ||||
| } | } | ||||
| public static Tensor reshape<T1, T2>(T1 tensor, T2 shape, string name = null) | |||||
| public static Tensor reshape<T>(Tensor tensor, T shape, string name = null) | |||||
| { | { | ||||
| if (tf.Context.executing_eagerly()) | if (tf.Context.executing_eagerly()) | ||||
| { | { | ||||
| @@ -144,7 +144,7 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// Keras History: (Layer, (node_index, tensor_index)) | /// Keras History: (Layer, (node_index, tensor_index)) | ||||
| /// </summary> | /// </summary> | ||||
| public List<Layer> KerasHistory = new List<Layer>(); | |||||
| public KerasHistory KerasHistory { get; set; } | |||||
| /// <summary> | /// <summary> | ||||
| /// Updates the shape of this tensor. | /// Updates the shape of this tensor. | ||||
| @@ -132,6 +132,8 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| public int this[int index] => dims[index]; | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns True iff `self` is fully defined in every dimension. | /// Returns True iff `self` is fully defined in every dimension. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -0,0 +1,70 @@ | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Gradients; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// Tensors is used to represent a Tensor or a array of Tensor. | |||||
| /// It will simplify the API interface, it converts Tensor | |||||
| /// and Tensor[] to Tensors implicitily. And parse back to Tensor | |||||
| /// and Tensor[] from Tensors implicitily. | |||||
| /// It works for tuple and scalar as well. | |||||
| /// </summary> | |||||
| public class Tensors : IEnumerable<Tensor> | |||||
| { | |||||
| Tensor[] items; | |||||
| public TF_DataType dtype => items.First().dtype; | |||||
| public TensorShape shape => items.First().TensorShape; | |||||
| public int rank => items.First().rank; | |||||
| public bool IsEagerTensor => items.First().IsEagerTensor; | |||||
| public Tensor this[int index] => items[index]; | |||||
| public Tensors(params Tensor[] tensors) | |||||
| { | |||||
| items = tensors; | |||||
| } | |||||
| public Tensors(NDArray nd) | |||||
| { | |||||
| items = new[] { ops.convert_to_tensor(nd) }; | |||||
| } | |||||
| public IEnumerator<Tensor> GetEnumerator() | |||||
| { | |||||
| foreach (var tensor in items) | |||||
| yield return tensor; | |||||
| } | |||||
| IEnumerator IEnumerable.GetEnumerator() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public static implicit operator Tensors(Tensor tensor) | |||||
| => new Tensors(tensor); | |||||
| public static implicit operator Tensors(NDArray nd) | |||||
| => new Tensors(nd); | |||||
| public static implicit operator Tensors(Tensor[] tensors) | |||||
| => new Tensors(tensors); | |||||
| public static implicit operator Tensor(Tensors tensors) | |||||
| => tensors.FirstOrDefault(); | |||||
| public static implicit operator Tensor[](Tensors tensors) | |||||
| => tensors.items; | |||||
| public override string ToString() | |||||
| => items.Length == 1 | |||||
| ? items.First().ToString() | |||||
| : items.Length + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name)); | |||||
| } | |||||
| } | |||||
| @@ -155,6 +155,8 @@ namespace Tensorflow | |||||
| return val; | return val; | ||||
| case NDArray val: | case NDArray val: | ||||
| return new EagerTensor(val, ctx.DeviceName); | return new EagerTensor(val, ctx.DeviceName); | ||||
| //case TensorShape val: | |||||
| //return new EagerTensor(val.dims, ctx.DeviceName); | |||||
| case string val: | case string val: | ||||
| return new EagerTensor(val, ctx.DeviceName); | return new EagerTensor(val, ctx.DeviceName); | ||||
| case string[] val: | case string[] val: | ||||
| @@ -16,6 +16,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | |||||
| using System.Linq; | using System.Linq; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -280,6 +281,7 @@ namespace Tensorflow | |||||
| return scope._scope; | return scope._scope; | ||||
| } | } | ||||
| [DebuggerHidden] | |||||
| public void __exit__() | public void __exit__() | ||||
| { | { | ||||
| _cached_pure_variable_scope.__exit__(); | _cached_pure_variable_scope.__exit__(); | ||||
| @@ -287,6 +289,7 @@ namespace Tensorflow | |||||
| _current_name_scope.__exit__(); | _current_name_scope.__exit__(); | ||||
| } | } | ||||
| [DebuggerHidden] | |||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| if (_current_name_scope != null) | if (_current_name_scope != null) | ||||
| @@ -76,10 +76,10 @@ namespace Tensorflow | |||||
| return get_default_graph().get_collection_ref<T>(key); | return get_default_graph().get_collection_ref<T>(key); | ||||
| } | } | ||||
| public static Graph _get_graph_from_inputs(params Tensor[] op_input_list) | |||||
| public static Graph _get_graph_from_inputs(Tensors op_input_list) | |||||
| => _get_graph_from_inputs(op_input_list: op_input_list, graph: null); | => _get_graph_from_inputs(op_input_list: op_input_list, graph: null); | ||||
| public static Graph _get_graph_from_inputs(Tensor[] op_input_list, Graph graph = null) | |||||
| public static Graph _get_graph_from_inputs(Tensors op_input_list, Graph graph = null) | |||||
| { | { | ||||
| foreach(var op_input in op_input_list) | foreach(var op_input in op_input_list) | ||||
| { | { | ||||
| @@ -0,0 +1,37 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Layers; | |||||
| using NumSharp; | |||||
| using Tensorflow.UnitTest; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest.Keras | |||||
| { | |||||
| /// <summary> | |||||
| /// https://www.tensorflow.org/guide/keras/save_and_serialize | |||||
| /// </summary> | |||||
| [TestClass] | |||||
| public class ModelSaveTest : EagerModeTestBase | |||||
| { | |||||
| [TestMethod] | |||||
| public void SaveAndLoadTest() | |||||
| { | |||||
| var model = GetModel(); | |||||
| } | |||||
| Model GetModel() | |||||
| { | |||||
| var keras = tf.keras; | |||||
| // Create a simple model. | |||||
| var inputs = keras.Input(shape: 32); | |||||
| var outputs = keras.layers.Dense(1).Apply(inputs); | |||||
| var model = keras.Model(inputs, outputs); | |||||
| model.compile("adam", "mean_squared_error"); | |||||
| return model; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -12,6 +12,40 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
| [TestClass] | [TestClass] | ||||
| public class FunctionApiTest : TFNetApiTest | public class FunctionApiTest : TFNetApiTest | ||||
| { | { | ||||
| Tensor Min(Tensor a, Tensor b) | |||||
| { | |||||
| return tf.cond(a < b, () => a, () => b); | |||||
| } | |||||
| [TestMethod] | |||||
| public void MulInAutoGraph() | |||||
| { | |||||
| var a = tf.constant(1); | |||||
| var b = tf.constant(2); | |||||
| // For first time running, tf.net will record the operations in graph mode. | |||||
| // And register to tensorflow op library. | |||||
| var output = Mul(a, b); | |||||
| Assert.AreEqual(2, (int)output); | |||||
| var c = tf.constant(3); | |||||
| // for the following invoke, Mul will be intercepted and run it in eager mode. | |||||
| output = Mul(b, c); | |||||
| Assert.AreEqual(6, (int)output); | |||||
| } | |||||
| /// <summary> | |||||
| /// Method with AutoGraph attribute will be converted to FuncGraph | |||||
| /// when it's invoked for the first time. | |||||
| /// </summary> | |||||
| /// <param name="a"></param> | |||||
| /// <param name="b"></param> | |||||
| /// <returns></returns> | |||||
| [AutoGraph] | |||||
| Tensor Mul(Tensor a, Tensor b) | |||||
| { | |||||
| return a * b; | |||||
| } | |||||
| [TestMethod] | [TestMethod] | ||||
| public void TwoInputs_OneOutput() | public void TwoInputs_OneOutput() | ||||
| { | { | ||||
| @@ -0,0 +1,35 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using System.Linq; | |||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| { | |||||
| [TestClass] | |||||
| public class GraphBuildTest : CApiTest | |||||
| { | |||||
| [TestMethod, Ignore("Waiting to merge https://github.com/tensorflow/tensorflow/pull/43383")] | |||||
| public void UpdateEdge() | |||||
| { | |||||
| using var graph = new Graph().as_default(); | |||||
| var one = tf.constant(1, name: "one"); | |||||
| var two = tf.constant(2, name: "two"); | |||||
| var add = tf.add(one, two, name: "add"); | |||||
| var neg = tf.negative(add, name: "neg"); | |||||
| Assert.AreEqual(1, one.consumers().Length); | |||||
| Assert.AreEqual("add", neg.op.node_def.Input[0]); | |||||
| // update edge | |||||
| neg.op._update_input(0, one); | |||||
| // c_api.TF_UpdateEdge(graph, new TF_Output(c1.op, 0), new TF_Input(neg.op, 0), tf.Status.Handle); | |||||
| Assert.AreEqual(2, one.consumers().Length); | |||||
| Assert.AreEqual("one:0", neg.op.node_def.Input[0]); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,59 +0,0 @@ | |||||
| using System; | |||||
| using FluentAssertions; | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using NumSharp; | |||||
| using Tensorflow; | |||||
| using Tensorflow.UnitTest; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest.layers_test | |||||
| { | |||||
| [TestClass] | |||||
| public class flatten : GraphModeTestBase | |||||
| { | |||||
| [TestMethod] | |||||
| public void Case1() | |||||
| { | |||||
| var sess = tf.Session().as_default(); | |||||
| var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, 3, 1, 2)); | |||||
| sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); | |||||
| } | |||||
| [TestMethod] | |||||
| public void Case2() | |||||
| { | |||||
| var sess = tf.Session().as_default(); | |||||
| var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6)); | |||||
| sess.run(tf.layers.flatten(input), (input, np.arange(6))).Should().BeShaped(6, 1); | |||||
| } | |||||
| [TestMethod] | |||||
| public void Case3() | |||||
| { | |||||
| var sess = tf.Session().as_default(); | |||||
| var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape()); | |||||
| new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6)))).Should().Throw<ValueError>(); | |||||
| } | |||||
| [TestMethod] | |||||
| public void Case4() | |||||
| { | |||||
| var sess = tf.Session().as_default(); | |||||
| var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, Unknown, 1, 2)); | |||||
| sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); | |||||
| } | |||||
| [TestMethod] | |||||
| public void Case5() | |||||
| { | |||||
| var sess = tf.Session().as_default(); | |||||
| var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(Unknown, 4, 3, 1, 2)); | |||||
| sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24); | |||||
| } | |||||
| } | |||||
| } | |||||