| @@ -107,19 +107,19 @@ namespace Tensorflow.Gradients | |||
/// <summary>
/// Computes the gradient of <paramref name="target"/> with respect to a single variable.
/// </summary>
/// <param name="target">Tensor to differentiate.</param>
/// <param name="source">Variable to differentiate with respect to.</param>
/// <returns>The first entry of the array produced by the list-based overload.</returns>
public Tensor gradient(Tensor target, ResourceVariable source)
{
    // The collection overload is declared on List<IVariableV1>, so the variable is
    // wrapped in a list of the interface type rather than a ResourceVariable array.
    var results = gradient(target, new List<IVariableV1> { source });
    return results[0];
}
/// <summary>
/// Computes the gradients of <paramref name="target"/> with respect to a pair of variables.
/// </summary>
/// <param name="target">Tensor to differentiate.</param>
/// <param name="sources">The two variables to differentiate with respect to.</param>
/// <returns>The gradients for <c>sources.Item1</c> and <c>sources.Item2</c>, in that order.</returns>
public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources)
{
    // Wrapped as List<IVariableV1> to match the collection overload's parameter type.
    var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 });
    return (results[0], results[1]);
}
| public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources) | |||
| public Tensor[] gradient(Tensor target, List<IVariableV1> sources) | |||
| { | |||
| if (_recording) | |||
| { | |||
| @@ -54,7 +54,16 @@ namespace Tensorflow.Gradients | |||
| var id = trace.output_tensor_info[i].GetID(); | |||
| if (!gradients.find(id, out var grad_it)) | |||
| { | |||
| throw new NotImplementedException("FunctionsAcceptingNoneForIndicesMap"); | |||
| if (FunctionsAcceptingNoneForIndicesMap().find(trace.op_type, out var func_name_it) && | |||
| func_name_it.find(i)) | |||
| { | |||
| out_gradients.Add(null); | |||
| } | |||
| else | |||
| { | |||
| out_gradients.Add(null); | |||
| zero_indices.Add(i); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| @@ -184,6 +193,15 @@ namespace Tensorflow.Gradients | |||
| return result.ToArray(); | |||
| } | |||
// Built once and shared: the table is constant, and the gradient path queries it
// for every output that has no incoming gradient.
static readonly UnorderedMap<string, UnorderedSet<int>> s_functionsAcceptingNoneForIndices =
    BuildFunctionsAcceptingNoneForIndicesMap();

/// <summary>
/// Maps an op type name to the set of output indices for which a missing (null)
/// incoming gradient is passed through as null.
/// </summary>
/// <returns>The shared lookup table; callers are expected to treat it as read-only.</returns>
UnorderedMap<string, UnorderedSet<int>> FunctionsAcceptingNoneForIndicesMap()
    => s_functionsAcceptingNoneForIndices;

/// <summary>
/// Creates the op-type to output-index table.
/// </summary>
static UnorderedMap<string, UnorderedSet<int>> BuildFunctionsAcceptingNoneForIndicesMap()
{
    var m = new UnorderedMap<string, UnorderedSet<int>>();
    m.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
    m.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
    m.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 }));
    return m;
}
| UnorderedMapEnumerable<long, List<Tensor>> InitialGradients(long[] target_tensor_ids, | |||
| UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||
| Tensor[] output_gradients, | |||
| @@ -5,13 +5,11 @@ using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras
{
    /// <summary>
    /// Built-in activation functions. Declared partial so that each activation
    /// can live in its own file.
    /// </summary>
    public partial class Activations
    {
        /// <summary>
        /// Linear activation function (pass-through): returns <c>features</c> unchanged
        /// and ignores <c>name</c>.
        /// </summary>
        public Activation Linear = (features, name) => features;
    }
}
| @@ -0,0 +1,26 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Operations; | |||
| using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras
{
    public partial class Activations
    {
        /// <summary>
        /// Relu activation. Runs the "Relu" op through the eager fast path and
        /// returns its first result.
        /// </summary>
        /// <exception cref="NotImplementedException">
        /// Thrown when eager execution is not active; no other path is implemented here.
        /// </exception>
        public Activation Relu = (features, name) =>
        {
            if (!tf.executing_eagerly())
            {
                throw new NotImplementedException("Activations.Relu is only implemented for eager execution.");
            }

            var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
                "Relu", name,
                null,
                features);
            return results[0];
        };
    }
}
| @@ -0,0 +1,9 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras
{
    /// <summary>
    /// Signature shared by all activation functions.
    /// </summary>
    /// <param name="features">Tensor the activation is applied to.</param>
    /// <param name="name">Optional name; the visible implementations either forward it to the op or ignore it.</param>
    /// <returns>The activated tensor.</returns>
    public delegate Tensor Activation(Tensor features, string name = null);
}
| @@ -11,7 +11,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||
| public Layer[] InboundLayers { get; set; } | |||
| public int[] NodeIndices { get; set; } | |||
| public int[] TensorIndices { get; set; } | |||
| public Tensor[] InputTensors { get; set; } | |||
| public Tensor[] Outputs { get; set; } | |||
| public Tensor InputTensors { get; set; } | |||
| public Tensor Outputs { get; set; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,29 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Layers; | |||
| using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine
{
    public partial class Layer
    {
        /// <summary>
        /// Layers created through the factory helpers of this class, in creation order.
        /// </summary>
        protected List<Layer> _layers = new List<Layer>();

        /// <summary>
        /// Creates a <see cref="Layers.Dense"/> layer and records it in <see cref="_layers"/>.
        /// </summary>
        /// <param name="units">Value assigned to <c>DenseArgs.Units</c>.</param>
        /// <param name="activation">Activation to use; falls back to the linear activation when null.</param>
        /// <param name="input_shape">Value assigned to <c>DenseArgs.InputShape</c>; may be null.</param>
        /// <returns>The newly created layer.</returns>
        protected Layer Dense(int units,
            Activation activation = null,
            TensorShape input_shape = null)
        {
            var denseArgs = new DenseArgs
            {
                Units = units,
                Activation = activation ?? tf.keras.activations.Linear,
                InputShape = input_shape
            };

            var layer = new Dense(denseArgs);
            _layers.Add(layer);
            return layer;
        }
    }
}
| @@ -18,11 +18,9 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Threading; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Layers; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Operations.Activation; | |||
| using Tensorflow.Train; | |||
| using static Tensorflow.Binding; | |||
| @@ -34,7 +32,7 @@ namespace Tensorflow.Keras.Engine | |||
| /// as convolution, batch norm, etc. These operations require managing weights, | |||
| /// losses, updates, and inter-layer connectivity. | |||
| /// </summary> | |||
| public abstract class Layer : AutoTrackable | |||
| public abstract partial class Layer : AutoTrackable | |||
| { | |||
| /// <summary> | |||
| /// Arguments initialize layer. | |||
| @@ -60,8 +58,19 @@ namespace Tensorflow.Keras.Engine | |||
| protected InputSpec inputSpec; | |||
| public bool SupportsMasking { get; set; } | |||
| protected List<IVariableV1> trainableWeights; | |||
| public List<IVariableV1> TrainableVariables => trainableWeights; | |||
/// <summary>
/// Trainable variables of this layer.
/// </summary>
/// <remarks>
/// When this layer's own list is empty, the getter appends the <c>trainableWeights</c> of each
/// layer in <c>_layers</c> to that same list and then returns it, so a read can mutate the layer.
/// Only the direct entries of <c>_layers</c> are visited.
/// </remarks>
public List<IVariableV1> trainable_variables
{
    get
    {
        if (trainableWeights.Count == 0)
        {
            _layers.ForEach(x => trainableWeights.AddRange(x.trainableWeights));
        }

        return trainableWeights;
    }
}
| protected List<IVariableV1> nonTrainableWeights; | |||
| public List<IVariableV1> non_trainable_variables => nonTrainableWeights; | |||
| string name; | |||
| public string Name => name; | |||
| @@ -112,20 +121,20 @@ namespace Tensorflow.Keras.Engine | |||
| /// <param name="input"></param> | |||
| /// <param name="is_training"></param> | |||
| /// <returns></returns> | |||
| public Tensor[] Apply(Tensor[] inputs, bool is_training = false) | |||
| public Tensor Apply(Tensor inputs, bool is_training = false) | |||
| { | |||
| var input = inputs[0]; | |||
| Tensor[] outputs = null; | |||
| Tensor outputs = null; | |||
| callContext = callContext ?? new ThreadLocal<CallContext>() | |||
| { | |||
| Value = new CallContext() | |||
| }; | |||
| var eager = tf.executing_eagerly(); | |||
| using var ctxManager = CallContext.enter(); | |||
| string nameScope = ""; | |||
| if (tf.executing_eagerly()) | |||
| if (eager) | |||
| { | |||
| nameScope = name; | |||
| } | |||
| @@ -134,7 +143,7 @@ namespace Tensorflow.Keras.Engine | |||
| throw new NotImplementedException(""); | |||
| } | |||
| using var graph = tf.keras.backend.get_graph().as_default(); | |||
| // using var graph = tf.keras.backend.get_graph().as_default(); | |||
| tf_with(ops.name_scope(nameScope), scope => | |||
| { | |||
| @@ -143,74 +152,36 @@ namespace Tensorflow.Keras.Engine | |||
| outputs = call(inputs, is_training: is_training); | |||
| (input, outputs) = _set_connectivity_metadata_(input, outputs); | |||
| _handle_activity_regularization(inputs[0], outputs); | |||
| _set_mask_metadata(inputs[0], outputs, null); | |||
| outputs = _set_connectivity_metadata_(inputs, outputs); | |||
| _handle_activity_regularization(inputs, outputs); | |||
| _set_mask_metadata(inputs, outputs, null); | |||
| }); | |||
| return outputs; | |||
| } | |||
/// <summary>
/// Records connectivity for a call by constructing a <see cref="Node"/> for this layer
/// with the given outputs.
/// </summary>
/// <param name="inputs">Input tensor of the call; currently not stored on the node.</param>
/// <param name="outputs">Output tensor of the call.</param>
/// <returns><paramref name="outputs"/>, unchanged.</returns>
private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs)
{
    /*var returnOutputs = new List<Tensor>();
    foreach(var x in outputs)
    {
        if (inputs.Contains(x))
        {
        }
        returnOutputs.Add(x);
    }*/

    // NOTE(review): the instance is discarded; presumably the Node constructor registers
    // itself with the layer it is given - confirm against Node.
    new Node(this, new NodeArgs
    {
        Outputs = outputs
    });

    //_add_inbound_node(input_tensors: inputs, output_tensors: outputs);
    return outputs;
}
| private void _handle_activity_regularization(Tensor inputs, Tensor[] outputs) | |||
| private void _handle_activity_regularization(Tensor inputs, Tensor outputs) | |||
| { | |||
| //if(_activity_regularizer != null) | |||
| { | |||
| @@ -218,7 +189,7 @@ namespace Tensorflow.Keras.Engine | |||
| } | |||
| } | |||
/// <summary>
/// Placeholder for mask propagation; the body is empty, so no mask metadata is set.
/// </summary>
/// <param name="inputs">Input tensor of the call (unused).</param>
/// <param name="outputs">Output tensor of the call (unused).</param>
/// <param name="previous_mask">Mask from the previous layer (unused).</param>
private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask)
{
}
| @@ -228,7 +199,7 @@ namespace Tensorflow.Keras.Engine | |||
| return null; | |||
| } | |||
/// <summary>
/// Layer computation hook. The base implementation always throws; subclasses override it.
/// </summary>
/// <param name="inputs">Input tensor.</param>
/// <param name="is_training">Training-mode flag passed through from the caller.</param>
/// <param name="state">Optional state tensor.</param>
/// <returns>The layer output (in overriding implementations).</returns>
/// <exception cref="NotImplementedException">Always thrown by this base implementation.</exception>
protected virtual Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
{
    throw new NotImplementedException("");
}
| @@ -238,15 +209,15 @@ namespace Tensorflow.Keras.Engine | |||
| return Name; | |||
| } | |||
/// <summary>
/// Builds the layer on first use: fills in the dtype from the input when none was set,
/// calls <c>build</c> with the input's shape, and marks the layer as built.
/// Does nothing when the layer is already built.
/// </summary>
/// <param name="inputs">Tensor whose dtype and shape drive the build.</param>
protected void MaybeBuild(Tensor inputs)
{
    // Check input assumptions set before layer building, e.g. input rank.
    if (built)
    {
        return;
    }

    // Inherit the dtype from the first input seen when the layer has none.
    if (DType == TF_DataType.DtInvalid)
    {
        args.DType = inputs.dtype;
    }

    build(inputs.TensorShape);
    built = true;
}
| @@ -35,13 +35,13 @@ namespace Tensorflow.Keras.Engine | |||
| public int[] node_indices; | |||
| public int[] tensor_indices; | |||
| public Tensor[] input_tensors; | |||
| public Tensor[] Outputs => args.Outputs; | |||
| public Tensor input_tensors; | |||
| public Tensor Outputs => args.Outputs; | |||
| public TensorShape[] input_shapes; | |||
| public TensorShape[] output_shapes; | |||
| List<Layer> kerasInputs; | |||
| public Node(InputLayer layer, NodeArgs args) | |||
| public Node(Layer layer, NodeArgs args) | |||
| { | |||
| this.args = args; | |||
| @@ -25,9 +25,7 @@ namespace Tensorflow.Keras.Engine | |||
| #pragma warning disable CS0649 // Field 'Sequential._is_graph_network' is never assigned to, and will always have its default value false | |||
| bool _is_graph_network; | |||
| #pragma warning restore CS0649 // Field 'Sequential._is_graph_network' is never assigned to, and will always have its default value false | |||
| #pragma warning disable CS0169 // The field 'Sequential.outputs' is never used | |||
| Tensor[] outputs; | |||
| #pragma warning restore CS0169 // The field 'Sequential.outputs' is never used | |||
| Tensor outputs; | |||
| bool computeOutputAndMaskJointly; | |||
| bool autoTrackSubLayers; | |||
| @@ -51,6 +49,11 @@ namespace Tensorflow.Keras.Engine | |||
| } | |||
/// <summary>
/// Overload accepting a tensor. The body is empty, so the argument is currently ignored.
/// </summary>
/// <param name="layer">Tensor to add (unused).</param>
public void add(Tensor layer)
{
}
| /// <summary> | |||
| /// Adds a layer instance on top of the layer stack. | |||
| /// </summary> | |||
| @@ -71,7 +74,7 @@ namespace Tensorflow.Keras.Engine | |||
| { | |||
| // Instantiate an input layer. | |||
| var x = tf.keras.Input( | |||
| batch_shape: layer.BatchInputShape, | |||
| shape: layer.BatchInputShape, | |||
| dtype: layer.DType, | |||
| name: layer.Name + "_input"); | |||
| @@ -86,7 +89,7 @@ namespace Tensorflow.Keras.Engine | |||
| if (set_inputs) | |||
| { | |||
| // If an input layer (placeholder) is available. | |||
| // outputs = layer.inbound_nodes; | |||
| outputs = layer.InboundNodes[^1].Outputs; | |||
| } | |||
| } | |||
| @@ -14,15 +14,35 @@ namespace Tensorflow | |||
| { | |||
| public KerasDataset datasets { get; } = new KerasDataset(); | |||
| public Initializers initializers { get; } = new Initializers(); | |||
| public Layers layers { get; } = new Layers(); | |||
| public LayersApi layers { get; } = new LayersApi(); | |||
| public Activations activations { get; } = new Activations(); | |||
| public BackendImpl backend { get; } = new BackendImpl(); | |||
| public Models models { get; } = new Models(); | |||
/// <summary>
/// Creates a new <see cref="Sequential"/> instance using its parameterless constructor.
/// </summary>
/// <returns>The new instance.</returns>
public Sequential Sequential()
    => new Sequential();
| public Tensor[] Input(int[] batch_shape = null, | |||
| /// <summary> | |||
| /// Instantiate a Keras tensor. | |||
| /// </summary> | |||
| /// <param name="shape"></param> | |||
| /// <param name="batch_size"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <param name="name"></param> | |||
| /// <param name="sparse"> | |||
| /// A boolean specifying whether the placeholder to be created is sparse. | |||
| /// </param> | |||
| /// <param name="ragged"> | |||
| /// A boolean specifying whether the placeholder to be created is ragged. | |||
| /// </param> | |||
| /// <param name="tensor"> | |||
| /// Optional existing tensor to wrap into the `Input` layer. | |||
| /// If set, the layer will not create a placeholder tensor. | |||
| /// </param> | |||
| /// <returns></returns> | |||
| public Tensor Input(TensorShape shape = null, | |||
| int batch_size = -1, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| string name = null, | |||
| @@ -33,7 +53,7 @@ namespace Tensorflow | |||
| var args = new InputLayerArgs | |||
| { | |||
| Name = name, | |||
| BatchInputShape = batch_shape, | |||
| InputShape = shape, | |||
| BatchSize = batch_size, | |||
| DType = dtype, | |||
| Sparse = sparse, | |||
| @@ -43,7 +63,7 @@ namespace Tensorflow | |||
| var layer = new InputLayer(args); | |||
| return layer.InboundNodes[0].Outputs; | |||
| return layer.InboundNodes[0].Outputs[0]; | |||
| } | |||
| public static Embedding Embedding(int input_dim, | |||
| @@ -55,7 +75,7 @@ namespace Tensorflow | |||
| embeddings_initializer, | |||
| mask_zero); | |||
| public class Layers | |||
| public class LayersApi | |||
| { | |||
| public Layer Dense(int units, | |||
| Activation activation = null, | |||
| @@ -143,15 +143,15 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensor[] call(Tensor[] inputs, bool is_training = false, Tensor state = null) | |||
| protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||
| { | |||
| Tensor outputs = null; | |||
| if (fused) | |||
| { | |||
| Tensor training = tf.convert_to_tensor(is_training); | |||
| outputs = _fused_batch_norm(inputs[0], training: training); | |||
| return new[] { outputs, outputs }; | |||
| outputs = _fused_batch_norm(inputs, training: training); | |||
| return outputs; | |||
| } | |||
| throw new NotImplementedException("BatchNormalization call"); | |||
| @@ -108,9 +108,9 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensor[] call(Tensor[] inputs, bool training = false, Tensor state = null) | |||
| protected override Tensor call(Tensor inputs, bool training = false, Tensor state = null) | |||
| { | |||
| var outputs = _convolution_op.__call__(inputs[0], kernel); | |||
| var outputs = _convolution_op.__call__(inputs, kernel); | |||
| if (use_bias) | |||
| { | |||
| if (data_format == "channels_first") | |||
| @@ -126,7 +126,7 @@ namespace Tensorflow.Keras.Layers | |||
| if (activation != null) | |||
| outputs = activation.Activate(outputs); | |||
| return new[] { outputs, outputs }; | |||
| return outputs; | |||
| } | |||
| } | |||
| } | |||
| @@ -65,17 +65,17 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensor[] call(Tensor[] inputs, bool training = false, Tensor state = null) | |||
| protected override Tensor call(Tensor inputs, bool training = false, Tensor state = null) | |||
| { | |||
| Tensor outputs = null; | |||
| var rank = inputs[0].rank; | |||
| var rank = inputs.rank; | |||
| if(rank > 2) | |||
| { | |||
| throw new NotImplementedException("call rank > 2"); | |||
| } | |||
| else | |||
| { | |||
| outputs = gen_math_ops.mat_mul(inputs[0], kernel.AsTensor()); | |||
| outputs = gen_math_ops.mat_mul(inputs, kernel.AsTensor()); | |||
| } | |||
| if (args.UseBias) | |||
| @@ -83,7 +83,7 @@ namespace Tensorflow.Keras.Layers | |||
| if (args.Activation != null) | |||
| outputs = activation(outputs); | |||
| return new[] { outputs }; | |||
| return outputs; | |||
| } | |||
| } | |||
| } | |||
| @@ -57,14 +57,14 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
/// <summary>
/// Looks up the embedding rows selected by <paramref name="inputs"/>.
/// </summary>
/// <param name="inputs">Index tensor; cast to int32 unless it is already int32 or int64.</param>
/// <param name="is_training">Unused by this layer.</param>
/// <param name="state">Unused by this layer.</param>
/// <returns>The result of <c>embedding_ops.embedding_lookup</c> on the embeddings table.</returns>
protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
{
    var dtype = inputs.dtype;
    if (dtype != tf.int32 && dtype != tf.int64)
    {
        inputs = math_ops.cast(inputs, tf.int32);
    }

    // Pass the whole index tensor. `inputs[0]` was correct while `inputs` was a Tensor[];
    // on a single Tensor it would index into the tensor instead.
    var outputs = embedding_ops.embedding_lookup(embeddings, inputs);
    return outputs;
}
| } | |||
| } | |||
| @@ -84,8 +84,8 @@ namespace Tensorflow.Keras.Layers | |||
| // input_tensor._keras_mask = None | |||
| new Node(this, new NodeArgs | |||
| { | |||
| InputTensors = new Tensor[] { args.InputTensor }, | |||
| Outputs = new Tensor[] { args.InputTensor } | |||
| InputTensors = args.InputTensor, | |||
| Outputs = args.InputTensor | |||
| }); | |||
| typeSpec = new TensorSpec(args.InputTensor.TensorShape, | |||
| @@ -45,7 +45,7 @@ namespace Tensorflow.Keras.Layers | |||
| this.input_spec = new InputSpec(ndim: 4); | |||
| } | |||
| protected override Tensor[] call(Tensor[] inputs, bool is_training = false, Tensor state = null) | |||
| protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||
| { | |||
| int[] pool_shape; | |||
| if (data_format == "channels_last") | |||
| @@ -60,13 +60,13 @@ namespace Tensorflow.Keras.Layers | |||
| } | |||
| var outputs = pool_function.Apply( | |||
| inputs[0], | |||
| inputs, | |||
| ksize: pool_shape, | |||
| strides: strides, | |||
| padding: padding.ToUpper(), | |||
| data_format: conv_utils.convert_data_format(data_format, 4)); | |||
| return new[] { outputs, outputs }; | |||
| return outputs; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,13 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.Engine; | |||
namespace Tensorflow.Keras
{
    /// <summary>
    /// Factory entry point for model types.
    /// </summary>
    public class Models
    {
        /// <summary>
        /// Creates a new <see cref="Engine.Sequential"/> instance using its parameterless constructor.
        /// </summary>
        /// <returns>The new instance.</returns>
        public Sequential Sequential()
            => new Sequential();
    }
}
| @@ -39,7 +39,7 @@ namespace Tensorflow.Keras.Optimizers | |||
/// <summary>
/// Applies a single (gradient, variable) pair by forwarding to <c>apply_gradients</c>.
/// </summary>
/// <param name="grads_and_vars">The gradient and the variable it applies to.</param>
/// <param name="name">Optional name, forwarded unchanged.</param>
/// <param name="experimental_aggregate_gradients">Forwarded unchanged.</param>
public void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
    string name = null,
    bool experimental_aggregate_gradients = true)
    // NOTE(review): the argument has exactly this overload's own parameter type, so unless
    // another overload is a better match this call binds back to itself and never returns.
    // The earlier form wrapped the pair in an array before forwarding - confirm which
    // collection overload is intended and wrap accordingly.
    => apply_gradients(grads_and_vars,
        name: name,
        experimental_aggregate_gradients: experimental_aggregate_gradients);
| @@ -77,7 +77,7 @@ namespace Tensorflow.Keras.Optimizers | |||
| _resource_apply_dense(var, grad, apply_state); | |||
| } | |||
| protected virtual Operation _resource_apply_dense(ResourceVariable var, | |||
| protected virtual Operation _resource_apply_dense(IVariableV1 var, | |||
| EagerTensor grad, | |||
| Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) | |||
| { | |||
| @@ -107,7 +107,7 @@ namespace Tensorflow.Keras.Optimizers | |||
| return grads_and_vars.Select(x => x.Item1).ToArray(); | |||
| } | |||
| Dictionary<DeviceDType, Dictionary<string, Tensor>> _prepare(ResourceVariable[] var_list) | |||
| Dictionary<DeviceDType, Dictionary<string, Tensor>> _prepare(IVariableV1[] var_list) | |||
| { | |||
| var _apply_state = new Dictionary<DeviceDType, Dictionary<string, Tensor>>(); | |||
| var keys = var_list.Select(x => new DeviceDType | |||
| @@ -151,7 +151,7 @@ namespace Tensorflow.Keras.Optimizers | |||
| return math_ops.cast(value, dtype); | |||
| } | |||
| void _create_all_weights(ResourceVariable[] var_list) | |||
| void _create_all_weights(IVariableV1[] var_list) | |||
| { | |||
| if(_iterations == null) | |||
| { | |||
| @@ -190,7 +190,7 @@ namespace Tensorflow.Keras.Optimizers | |||
| _hypers_created = true; | |||
| } | |||
| void _create_slots(ResourceVariable[] var_list) | |||
| void _create_slots(IVariableV1[] var_list) | |||
| { | |||
| if(_momentum) | |||
| { | |||
| @@ -40,7 +40,7 @@ namespace Tensorflow.Keras.Optimizers | |||
| _get_hyper("momentum", device_dtype.DType)); | |||
| } | |||
| protected override Operation _resource_apply_dense(ResourceVariable var, EagerTensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) | |||
| protected override Operation _resource_apply_dense(IVariableV1 var, EagerTensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) | |||
| { | |||
| if (_momentum) | |||
| { | |||
| @@ -88,9 +88,8 @@ namespace Tensorflow.Layers | |||
| { | |||
| _current_scope = scope2; | |||
| // Actually call layer | |||
| outputs = base.__call__(new Tensor[] { inputs }, | |||
| training: training, | |||
| state: state); | |||
| /*outputs = base.Apply(new Tensor[] { inputs }, | |||
| is_training: training);*/ | |||
| }); | |||
| @@ -74,7 +74,7 @@ namespace Tensorflow | |||
| /// <param name="training"></param> | |||
| /// <param name="state"></param> | |||
| /// <returns></returns> | |||
| protected override Tensor[] call(Tensor[] inputs, bool is_training = false, Tensor state = null) | |||
| protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||
| { | |||
| var one = constant_op.constant(1, dtype: dtypes.int32); | |||
| // Parameters of gates are concatenated into one multiply for efficiency. | |||
| @@ -87,7 +87,7 @@ namespace Tensorflow | |||
| // array_ops.split(value: state, num_or_size_splits: 2, axis: one); | |||
| throw new NotImplementedException("BasicLstmCell call"); | |||
| } | |||
| var gate_inputs = math_ops.matmul(array_ops.concat(new[] { inputs[0], h }, 1), _kernel as RefVariable); | |||
| var gate_inputs = math_ops.matmul(array_ops.concat(new[] { inputs, h }, 1), _kernel as RefVariable); | |||
| gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable); | |||
| // i = input_gate, j = new_input, f = forget_gate, o = output_gate | |||
| @@ -105,9 +105,9 @@ namespace Tensorflow | |||
| if (_state_is_tuple) | |||
| return new[] { new_c, new_h }; | |||
| return new_c; | |||
| else | |||
| return new[] { array_ops.concat(new[] { new_c, new_h }, 1) }; | |||
| return array_ops.concat(new[] { new_c, new_h }, 1); | |||
| } | |||
| public override object get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
| @@ -67,14 +67,14 @@ namespace Tensorflow | |||
| built = true; | |||
| } | |||
/// <summary>
/// Most basic RNN step: output = new_state = act(W * input + U * state + B).
/// </summary>
/// <param name="inputs">Input tensor for this step.</param>
/// <param name="is_training">Unused by this cell.</param>
/// <param name="state">Previous state; concatenated with the input along axis 1.</param>
/// <returns>The activated output of the step.</returns>
protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null)
{
    // A single matmul over [inputs, state] applies the kernel to both at once.
    var concat = array_ops.concat(new[] { inputs, state }, 1);
    var gate_inputs = math_ops.matmul(concat, _kernel as RefVariable);
    gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable);
    return _activation(gate_inputs, null);
}
| } | |||
| } | |||
| @@ -108,6 +108,17 @@ namespace Tensorflow.Operations | |||
| string data_format = null, | |||
| string name = null) | |||
| { | |||
| if (tf.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "BiasAdd", name, | |||
| null, | |||
| value, bias, | |||
| "data_format", data_format); | |||
| return results[0]; | |||
| } | |||
| if (data_format == null) | |||
| data_format = "NHWC"; | |||
| @@ -125,6 +136,17 @@ namespace Tensorflow.Operations | |||
| string data_format = "NHWC", | |||
| string name = null) | |||
| { | |||
| if (tf.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "BiasAddGrad", name, | |||
| null, | |||
| out_backprop, | |||
| "data_format", data_format); | |||
| return results[0]; | |||
| } | |||
| if (data_format == null) | |||
| data_format = "NHWC"; | |||
| @@ -460,6 +482,16 @@ namespace Tensorflow.Operations | |||
| /// </remarks> | |||
| public static (Tensor loss, Tensor backprop) sparse_softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string name = "SparseSoftmaxCrossEntropyWithLogits") | |||
| { | |||
| if (tf.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "SparseSoftmaxCrossEntropyWithLogits", name, | |||
| null, | |||
| features, labels); | |||
| return (results[0], results[1]); | |||
| } | |||
| var op = tf.OpDefLib._apply_op_helper("SparseSoftmaxCrossEntropyWithLogits", name: name, args: new { features, labels }); | |||
| int _idx = 0; | |||
| var loss = op.outputs[_idx++]; | |||
| @@ -475,7 +507,7 @@ namespace Tensorflow.Operations | |||
| /// <returns>A `Tensor`. Has the same type as `features`.</returns> | |||
| public static Tensor relu(Tensor features, string name = null) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| if (tf.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Relu", name, | |||
| @@ -31,8 +31,47 @@ namespace Tensorflow | |||
/// <summary>
/// Forwards to <c>gen_array_ops.placeholder_with_default</c> with the same arguments.
/// </summary>
/// <param name="input">Default value.</param>
/// <param name="shape">Shape passed to the op.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The tensor returned by the generated op wrapper.</returns>
public static Tensor placeholder_with_default<T>(T input, int[] shape, string name = null)
    => gen_array_ops.placeholder_with_default(input, shape, name);
/// <summary>
/// An identity op that triggers an error if a gradient is requested.
/// </summary>
/// <param name="input">
/// any tensor.
/// </param>
/// <param name="message">
/// Will be printed in the error when anyone tries to differentiate
/// this operation.
/// </param>
/// <param name="name">
/// If specified, the created operation in the graph will be this one, otherwise it will be named 'PreventGradient'.
/// </param>
/// <returns>
/// the same input tensor.
/// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result.
/// </returns>
/// <remarks>
/// When executed in a graph, this op outputs its input tensor as-is.
///
/// When building ops to compute gradients, the TensorFlow gradient system
/// will return an error when trying to lookup the gradient of this op,
/// because no gradient must ever be registered for this function. This
/// op exists to prevent subtle bugs from silently returning unimplemented
/// gradients in some corner cases.
/// </remarks>
public static Tensor prevent_gradient(Tensor input, string message = "", string name = null)
{
    // Eager mode: run the op directly through the fast path.
    if (tf.executing_eagerly())
    {
        var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
            "PreventGradient", name,
            null,
            input,
            "message", message);
        return results[0];
    }

    // Graph mode: add a PreventGradient node and return its output.
    var op = tf.OpDefLib._apply_op_helper("PreventGradient", name: name, args: new { input, message });
    return op.output;
}
| internal static Tensor constant(object value, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| @@ -186,38 +186,6 @@ namespace Tensorflow | |||
| return _op.output; | |||
| } | |||
| /// <summary> | |||
| /// An identity op that triggers an error if a gradient is requested. | |||
| /// </summary> | |||
| /// <param name="input"> | |||
| /// any tensor. | |||
| /// </param> | |||
| /// <param name="name"> | |||
| /// If specified, the created operation in the graph will be this one, otherwise it will be named 'PreventGradient'. | |||
| /// </param> | |||
| /// <param name="message"> | |||
| /// Will be printed in the error when anyone tries to differentiate | |||
| /// this operation. | |||
| /// </param> | |||
| /// <returns> | |||
| /// the same input tensor. | |||
| /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. | |||
| /// </returns> | |||
| /// <remarks> | |||
| /// When executed in a graph, this op outputs its input tensor as-is. | |||
| /// | |||
| /// When building ops to compute gradients, the TensorFlow gradient system | |||
| /// will return an error when trying to lookup the gradient of this op, | |||
| /// because no gradient must ever be registered for this function. This | |||
| /// op exists to prevent subtle bugs from silently returning unimplemented | |||
| /// gradients in some corner cases. | |||
| /// </remarks> | |||
| public static Tensor prevent_gradient(Tensor input, string message = "", string name = null) | |||
| { | |||
| var op = tf.OpDefLib._apply_op_helper("PreventGradient", name: name, args: new { input, message }); | |||
| return op.output; | |||
| } | |||
| /// <summary> | |||
| /// Return a tensor with the same shape and contents as the input tensor or value. | |||
| /// </summary> | |||
| @@ -45,6 +45,7 @@ namespace Tensorflow | |||
| public Operation Initializer => initializer_op; | |||
| public Operation Op => handle.op; | |||
| public Graph Graph => handle.graph; | |||
| public string Device => ""; | |||
| public BaseResourceVariable() | |||
| { | |||
| @@ -148,6 +149,6 @@ namespace Tensorflow | |||
| { | |||
| } | |||
/// <summary>
/// Returns the variable as a tensor by calling <c>read_value()</c>.
/// </summary>
public Tensor AsTensor() => read_value();
| } | |||
| } | |||
| @@ -33,6 +33,7 @@ namespace Tensorflow | |||
| { | |||
| public string Name { get; } | |||
| public Tensor Handle { get; } | |||
| public string Device { get; } | |||
| public Operation Initializer { get; } | |||
| public Operation Op { get; } | |||
| public Tensor GraphElement { get; } | |||
| @@ -48,6 +48,7 @@ namespace Tensorflow | |||
| public TF_DataType dtype => _variable.dtype; | |||
| public TensorShape shape => tensor_util.to_shape(_variable.shape); | |||
| public string Device => ""; | |||
| public string Name => _variable.name; | |||
| @@ -51,7 +51,6 @@ namespace Tensorflow | |||
| { | |||
| Status = new Status(); | |||
| Context = new Context(new ContextOptions(), Status); | |||
| enable_eager_execution(); | |||
| OpDefLib = new OpDefLibrary(); | |||
| ConstructThreadingObjects(); | |||
| InitGradientEnvironment(); | |||
| @@ -16,6 +16,13 @@ namespace TensorFlowNET.UnitTest.Keras | |||
| [TestClass, Ignore] | |||
| public class LayersTest : GraphModeTestBase | |||
| { | |||
/// <summary>
/// Smoke test: creates a Sequential model and adds a Keras input tensor to it.
/// Makes no assertions; it only checks that the calls complete.
/// </summary>
[TestMethod]
public void Sequential()
{
    var model = tf.keras.models.Sequential();
    model.add(tf.keras.Input(shape: 16));
}
| [TestMethod] | |||
| public void Embedding() | |||
| { | |||