| @@ -15,6 +15,7 @@ | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow.Eager; | |||
| namespace Tensorflow.Contexts | |||
| @@ -87,6 +88,29 @@ namespace Tensorflow.Contexts | |||
| context_switches.Pop(); | |||
| } | |||
| public Tensor RunInAutoMode(Func<Tensor> graphAction, Func<Tensor> eagerAction, params Tensor[] tensors) | |||
| { | |||
| var shouldRunInEager = executing_eagerly() | |||
| && tensors.Count(x => x.IsEagerTensor) == tensors.Length; | |||
| if (shouldRunInEager) | |||
| return eagerAction(); | |||
| else | |||
| { | |||
| if (executing_eagerly()) | |||
| { | |||
| graph_mode(); | |||
| var result = graphAction(); | |||
| restore_mode(); | |||
| return result; | |||
| } | |||
| else | |||
| { | |||
| return graphAction(); | |||
| } | |||
| } | |||
| } | |||
| public void Dispose() | |||
| => Handle.Dispose(); | |||
| } | |||
| @@ -0,0 +1,12 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class ZeroPadding2DArgs : LayerArgs | |||
| { | |||
| public NDArray Padding { get; set; } | |||
| } | |||
| } | |||
| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using static Tensorflow.Binding; | |||
| @@ -121,6 +122,38 @@ namespace Tensorflow.Keras | |||
| _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0); | |||
| } | |||
| /// <summary> | |||
| /// Pads the 2nd and 3rd dimensions of a 4D tensor. | |||
| /// </summary> | |||
| /// <param name="x"></param> | |||
| /// <param name="padding"></param> | |||
| /// <param name="data_format"></param> | |||
| /// <returns></returns> | |||
| public Tensor spatial_2d_padding(Tensor x, NDArray padding = null, string data_format = null) | |||
| { | |||
| if (padding == null) | |||
| padding = new[,] { { 1, 1 }, { 1, 1 } }; | |||
| NDArray pattern; | |||
| if (data_format == "channels_first") | |||
| pattern = new int[,] | |||
| { | |||
| { 0, 0 }, | |||
| { 0, 0 }, | |||
| { padding[0][0], padding[0][1] }, | |||
| { padding[1][0], padding[1][1] } | |||
| }; | |||
| else | |||
| pattern = new int[,] | |||
| { | |||
| { 0, 0 }, | |||
| { padding[0][0], padding[0][1] }, | |||
| { padding[1][0], padding[1][1] }, | |||
| { 0, 0 } | |||
| }; | |||
| return array_ops.pad(x, pattern); | |||
| } | |||
| public class _DummyEagerGraph | |||
| { } | |||
| @@ -0,0 +1,47 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Security.Cryptography.X509Certificates; | |||
| using System.Text; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| public class BaseLayerUtils | |||
| { | |||
| public static Layer[] CreateKerasHistoryHelper(Tensors tensors) | |||
| { | |||
| var processed_ops = new List<Operation>(); | |||
| var created_layers = new List<Layer>(); | |||
| foreach (var tensor in tensors) | |||
| { | |||
| if (tensor.KerasHistory != null) | |||
| continue; | |||
| var op = tensor.op; | |||
| if (!processed_ops.Contains(op)) | |||
| { | |||
| var layer_inputs = new List<Tensor>(); | |||
| foreach (var (i, op_input) in enumerate(op.inputs._inputs)) | |||
| { | |||
| if (uses_keras_history(op_input)) | |||
| layer_inputs.Add(op_input); | |||
| else | |||
| { | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return created_layers.ToArray(); | |||
| } | |||
| static bool uses_keras_history(Tensor op_input) | |||
| { | |||
| return Layer.KerasHistories.Any(x => x.tensor == op_input); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,5 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Security.Cryptography.X509Certificates; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| @@ -47,12 +49,15 @@ namespace Tensorflow.Keras.Engine | |||
| // A graph network does not autocast inputs, as its layers will cast them instead. | |||
| _autocast = false; | |||
| if (outputs.Any(x => x.KerasHistory == null)) | |||
| BaseLayerUtils.CreateKerasHistoryHelper(outputs); | |||
| // Build self._output_layers: | |||
| foreach(var x in outputs) | |||
| foreach (var x in outputs) | |||
| { | |||
| var (layer, node_index, tensor_index) = x.KerasHistory; | |||
| _output_layers.append(layer); | |||
| _output_coordinates.append(new KerasHistory(layer, node_index, tensor_index)); | |||
| _output_coordinates.append(new KerasHistory(layer, node_index, tensor_index, x)); | |||
| } | |||
| // Build self._input_layers: | |||
| @@ -60,8 +65,9 @@ namespace Tensorflow.Keras.Engine | |||
| { | |||
| var (layer, node_index, tensor_index) = x.KerasHistory; | |||
| _input_layers.append(layer); | |||
| _input_coordinates.append(new KerasHistory(layer, node_index, tensor_index)); | |||
| _input_coordinates.append(new KerasHistory(layer, node_index, tensor_index, x)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -12,12 +12,15 @@ namespace Tensorflow.Keras.Engine | |||
| Layer layer; | |||
| int node_index; | |||
| int tensor_index; | |||
| public Tensor tensor; | |||
| public KerasHistory(Layer layer, int node_index, int tensor_index) | |||
| public KerasHistory(Layer layer, int node_index, int tensor_index, Tensor tensor) | |||
| { | |||
| this.layer = layer; | |||
| this.node_index = node_index; | |||
| this.tensor_index = tensor_index; | |||
| this.tensor = tensor; | |||
| Console.WriteLine(tensor.name); | |||
| } | |||
| public void Deconstruct(out Layer layer, out int node_index, out int tensor_index) | |||
| @@ -27,6 +30,9 @@ namespace Tensorflow.Keras.Engine | |||
| tensor_index = this.tensor_index; | |||
| } | |||
| public override string ToString() | |||
| => $"{layer.GetType().Name} {layer.Name} {tensor.name}"; | |||
| public static implicit operator Layer(KerasHistory history) | |||
| => history.layer; | |||
| } | |||
| @@ -0,0 +1,18 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| public partial class Layer | |||
| { | |||
| /// <summary> | |||
| /// Loads all layer weights, either from a TensorFlow or an HDF5 weight file. | |||
| /// </summary> | |||
| /// <param name="filepath"></param> | |||
| public void load_weights(string filepath) | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -56,6 +56,7 @@ namespace Tensorflow.Keras.Engine | |||
| /// Provides information about which inputs are compatible with the layer. | |||
| /// </summary> | |||
| protected InputSpec inputSpec; | |||
| bool dynamic = true; | |||
| public bool SupportsMasking { get; set; } | |||
| protected List<IVariableV1> trainableWeights; | |||
| public List<IVariableV1> trainable_variables | |||
| @@ -88,6 +89,7 @@ namespace Tensorflow.Keras.Engine | |||
| ThreadLocal<CallContext> callContext; | |||
| public CallContext CallContext => callContext.Value; | |||
| public static List<KerasHistory> KerasHistories = new List<KerasHistory>(); | |||
| public Layer(LayerArgs args) | |||
| { | |||
| @@ -129,6 +131,11 @@ namespace Tensorflow.Keras.Engine | |||
| Value = new CallContext() | |||
| }; | |||
| var history = inputs.Where(x => x.KerasHistory != null | |||
| && !KerasHistories.Contains(x.KerasHistory)) | |||
| .Select(x => x.KerasHistory); | |||
| KerasHistories.AddRange(history); | |||
| if (_in_functional_construction_mode(inputs)) | |||
| return _functional_construction_call(inputs); | |||
| @@ -166,7 +173,8 @@ namespace Tensorflow.Keras.Engine | |||
| bool _in_functional_construction_mode(Tensors inputs) | |||
| { | |||
| return inputs.Count(x => !x.IsEagerTensor) == inputs.Count(); | |||
| return tf.Context.executing_eagerly() | |||
| && inputs.Count(x => !x.IsEagerTensor) == inputs.Count(); | |||
| } | |||
| Tensors _functional_construction_call(Tensors inputs) | |||
| @@ -191,6 +199,15 @@ namespace Tensorflow.Keras.Engine | |||
| { | |||
| MaybeBuild(inputs); | |||
| // Wrapping `call` function in autograph to allow for dynamic control | |||
| // flow and control dependencies in call. We are limiting this to | |||
| // subclassed layers as autograph is strictly needed only for | |||
| // subclassed layers and models. | |||
| // tf_convert will respect the value of autograph setting in the | |||
| // enclosing tf.function, if any. | |||
| if (!dynamic) | |||
| throw new NotImplementedException(""); | |||
| outputs = call(inputs); | |||
| outputs = _set_connectivity_metadata_(inputs, outputs); | |||
| @@ -243,6 +260,13 @@ namespace Tensorflow.Keras.Engine | |||
| return null; | |||
| } | |||
| /// <summary> | |||
| /// Subclass has to override this method. | |||
| /// </summary> | |||
| /// <param name="inputs"></param> | |||
| /// <param name="state"></param> | |||
| /// <param name="is_training"></param> | |||
| /// <returns></returns> | |||
| protected virtual Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| @@ -263,9 +287,9 @@ namespace Tensorflow.Keras.Engine | |||
| tf.init_scope(); | |||
| //tf.Context.eager_mode(); | |||
| tf.Context.eager_mode(); | |||
| build(inputs.shape); | |||
| //tf.Context.restore_mode(); | |||
| tf.Context.restore_mode(); | |||
| built = true; | |||
| } | |||
| @@ -282,18 +306,14 @@ namespace Tensorflow.Keras.Engine | |||
| protected virtual IVariableV1 add_weight(string name, | |||
| TensorShape shape, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
| IInitializer initializer = null, | |||
| IRegularizer regularizer = null, | |||
| bool? trainable = null, | |||
| VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
| VariableAggregation aggregation = VariableAggregation.None, | |||
| bool trainable = true, | |||
| Func<VariableArgs, IVariableV1> getter = null) | |||
| { | |||
| if (dtype == TF_DataType.DtInvalid) | |||
| dtype = TF_DataType.TF_FLOAT; | |||
| if (trainable == null) | |||
| trainable = true; | |||
| // Initialize variable when no initializer provided | |||
| if (initializer == null) | |||
| { | |||
| @@ -306,6 +326,9 @@ namespace Tensorflow.Keras.Engine | |||
| throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}"); | |||
| } | |||
| if (synchronization == VariableSynchronization.OnRead) | |||
| trainable = false; | |||
| var args = new VariableArgs | |||
| { | |||
| Name = name, | |||
| @@ -314,7 +337,9 @@ namespace Tensorflow.Keras.Engine | |||
| Getter = getter ?? base_layer_utils.make_variable, | |||
| Overwrite = true, | |||
| Initializer = initializer, | |||
| Trainable = trainable.Value | |||
| Synchronization = synchronization, | |||
| Aggregation = aggregation, | |||
| Trainable = trainable | |||
| }; | |||
| var variable = _add_variable_with_custom_getter(args); | |||
| @@ -58,7 +58,7 @@ namespace Tensorflow.Keras.Engine | |||
| // Set metadata on outputs. | |||
| var node_index = layer.InboundNodes.Count - 1; | |||
| foreach (var (i, tensor) in enumerate(Outputs)) | |||
| tensor.KerasHistory = new KerasHistory(layer, node_index, i); | |||
| tensor.KerasHistory = new KerasHistory(layer, node_index, i, tensor); | |||
| } | |||
| } | |||
| } | |||
| @@ -38,8 +38,8 @@ namespace Tensorflow.Keras.Layers | |||
| string _data_format; | |||
| IInitializer beta_initializer => args.BetaInitializer; | |||
| IInitializer gamma_initializer => args.GammaInitializer; | |||
| IInitializer moving_mean_initializer; | |||
| IInitializer moving_variance_initializer; | |||
| IInitializer moving_mean_initializer => args.MovingMeanInitializer; | |||
| IInitializer moving_variance_initializer => args.MovingVarianceInitializer; | |||
| IRegularizer gamma_regularizer => args.GammaRegularizer; | |||
| IVariableV1 gamma; | |||
| IVariableV1 beta; | |||
| @@ -101,13 +101,17 @@ namespace Tensorflow.Keras.Layers | |||
| param_shape, | |||
| dtype: param_dtype, | |||
| initializer: moving_mean_initializer, | |||
| synchronization: VariableSynchronization.OnRead, | |||
| aggregation: VariableAggregation.Mean, | |||
| trainable: false); | |||
| moving_variance = add_weight("moving_variance", | |||
| shape: param_shape, | |||
| dtype: param_dtype, | |||
| initializer: moving_variance_initializer, | |||
| trainable: false); | |||
| shape: param_shape, | |||
| dtype: param_dtype, | |||
| initializer: moving_variance_initializer, | |||
| synchronization: VariableSynchronization.OnRead, | |||
| aggregation: VariableAggregation.Mean, | |||
| trainable: false); | |||
| if (renorm) | |||
| throw new NotImplementedException("build when renorm is true"); | |||
| @@ -131,6 +135,12 @@ namespace Tensorflow.Keras.Layers | |||
| private Tensor _fused_batch_norm(Tensor inputs, Tensor training) | |||
| { | |||
| TensorShape input_batch_size = null; | |||
| var use_fused_avg_updates = true; | |||
| float exponential_avg_factor = 0; | |||
| if (use_fused_avg_updates) | |||
| exponential_avg_factor = 1.0f - momentum; | |||
| var beta = this.beta; | |||
| var gamma = this.gamma; | |||
| @@ -146,17 +156,22 @@ namespace Tensorflow.Keras.Layers | |||
| Func<Tensor[]> _fused_batch_norm_inference = () => | |||
| { | |||
| var moving_mean_tensor = moving_mean.AsTensor(); | |||
| var moving_variance_tensor = moving_variance.AsTensor(); | |||
| return tf.nn.fused_batch_norm( | |||
| inputs, | |||
| gamma, | |||
| beta, | |||
| mean: moving_mean.AsTensor(), | |||
| variance: moving_variance.AsTensor(), | |||
| mean: moving_mean_tensor, | |||
| variance: moving_variance_tensor, | |||
| epsilon: epsilon, | |||
| is_training: false, | |||
| data_format: _data_format); | |||
| }; | |||
| if (use_fused_avg_updates && input_batch_size != null) | |||
| throw new NotImplementedException(""); | |||
| var results = tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference); | |||
| var (output, mean, variance) = (results[0], results[1], results[2]); | |||
| var training_value = tf_utils.constant_value(training); | |||
| @@ -1,4 +1,5 @@ | |||
| using System; | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| @@ -33,6 +34,7 @@ namespace Tensorflow.Keras.Layers | |||
| DataFormat = data_format, | |||
| DilationRate = dilation_rate == null ? (1, 1) : dilation_rate, | |||
| Groups = groups, | |||
| UseBias = use_bias, | |||
| KernelRegularizer = kernel_regularizer, | |||
| KernelInitializer = kernel_initializer == null ? tf.glorot_uniform_initializer : kernel_initializer, | |||
| BiasInitializer = bias_initializer == null ? tf.zeros_initializer : bias_initializer, | |||
| @@ -129,6 +131,17 @@ namespace Tensorflow.Keras.Layers | |||
| InputShape = input_shape | |||
| }); | |||
| /// <summary> | |||
| /// Zero-padding layer for 2D input (e.g. picture). | |||
| /// </summary> | |||
| /// <param name="padding"></param> | |||
| /// <returns></returns> | |||
| public ZeroPadding2D ZeroPadding2D(NDArray padding) | |||
| => new ZeroPadding2D(new ZeroPadding2DArgs | |||
| { | |||
| Padding = padding | |||
| }); | |||
| Activation GetActivationByName(string name) | |||
| => name switch | |||
| { | |||
| @@ -0,0 +1,39 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Utils; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| /// <summary> | |||
| /// Zero-padding layer for 2D input (e.g. picture). | |||
| /// | |||
| /// This layer can add rows and columns of zeros | |||
| /// at the top, bottom, left and right side of an image tensor. | |||
| /// </summary> | |||
| public class ZeroPadding2D : Layer | |||
| { | |||
| string data_format; | |||
| NDArray padding; | |||
| InputSpec input_spec; | |||
| public ZeroPadding2D(ZeroPadding2DArgs args, string data_format = null) | |||
| : base(args) | |||
| { | |||
| this.data_format = conv_utils.normalize_data_format(data_format); | |||
| this.padding = args.Padding; | |||
| this.input_spec = new InputSpec(ndim: 4); | |||
| } | |||
| protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| { | |||
| return tf.keras.backend.spatial_2d_padding(inputs, | |||
| padding: padding, | |||
| data_format: data_format); | |||
| } | |||
| } | |||
| } | |||
| @@ -127,7 +127,7 @@ namespace Tensorflow.Layers | |||
| int[] shape, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| IInitializer initializer = null, | |||
| bool? trainable = null, | |||
| bool trainable = true, | |||
| VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
| VariableAggregation aggregation = VariableAggregation.None) | |||
| { | |||
| @@ -137,8 +137,6 @@ namespace Tensorflow.Layers | |||
| if (synchronization == VariableSynchronization.OnRead) | |||
| trainable = false; | |||
| else if (!trainable.HasValue) | |||
| trainable = true; | |||
| if (default_graph.building_function) | |||
| { | |||
| @@ -56,20 +56,24 @@ namespace Tensorflow.Operations | |||
| var strides = _get_sequence(args.Strides, num_spatial_dims, channel_index); | |||
| Tensor result = null; | |||
| tf_with(ops.name_scope(name, default_name: null, (input, filters)), scope => | |||
| tf_with(ops.name_scope(name, default_name: null), scope => | |||
| { | |||
| name = scope; | |||
| if (num_spatial_dims == 2) | |||
| { | |||
| var filters_tensor = filters.AsTensor(); | |||
| result = gen_nn_ops.conv2d(new Conv2dParams | |||
| { | |||
| Input = input, | |||
| Filter = filters.AsTensor(), | |||
| Filter = filters_tensor, | |||
| Strides = strides, | |||
| Padding = padding, | |||
| DataFormat = data_format, | |||
| Dilations = dilations, | |||
| Name = name | |||
| }); | |||
| } | |||
| else | |||
| throw new NotImplementedException(""); | |||
| }); | |||
| @@ -263,7 +263,7 @@ namespace Tensorflow | |||
| List<TF_DataType> types, | |||
| List<TF_DataType> base_types, | |||
| List<TF_DataType> input_types, | |||
| dynamic values) | |||
| object values) | |||
| { | |||
| var input_name = input_arg.Name; | |||
| @@ -73,6 +73,16 @@ namespace Tensorflow | |||
| return _op.output; | |||
| } | |||
| public static Tensor concat_v2(Tensor[] values, int axis, string name = null) | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("ConcatV2", name: name, | |||
| args: new { values, axis }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "ConcatV2", name, | |||
| null, | |||
| values, axis).FirstOrDefault(), | |||
| values); | |||
| private static Tensor concat_v2_eager_fallback<T1, T2>(T1[] values, T2 axis, string name, Context ctx) | |||
| { | |||
| var _attr_N = len(values); | |||
| @@ -293,20 +303,13 @@ namespace Tensorflow | |||
| } | |||
| public static Tensor reshape<T>(Tensor tensor, T shape, string name = null) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Reshape", name, | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Reshape", name, | |||
| null, | |||
| tensor, shape); | |||
| return results[0]; | |||
| } | |||
| var _op = tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }); | |||
| return _op.output; | |||
| } | |||
| tensor, shape).FirstOrDefault(), | |||
| tensor); | |||
| public static Tensor reshape(Tensor tensor, int[] shape, string name = null) | |||
| { | |||
| @@ -399,21 +402,15 @@ namespace Tensorflow | |||
| } | |||
| public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Shape", name, | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Shape", name, | |||
| new { input, out_type }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Shape", name, | |||
| null, | |||
| input, | |||
| "out_type", out_type); | |||
| return results[0]; | |||
| } | |||
| var _op = tf.OpDefLib._apply_op_helper("Shape", name, new { input, out_type }); | |||
| return _op.outputs[0]; | |||
| } | |||
| "out_type", out_type).FirstOrDefault(), | |||
| input); | |||
| /// <summary> | |||
| /// Returns shape of tensors. | |||
| @@ -460,20 +457,13 @@ namespace Tensorflow | |||
| } | |||
| public static Tensor tile<T>(Tensor input, T multiples, string name = null) | |||
| { | |||
| if (tf.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Tile", name, | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Tile", name, | |||
| null, | |||
| input, multiples); | |||
| return results[0]; | |||
| } | |||
| var _op = tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }); | |||
| return _op.outputs[0]; | |||
| } | |||
| input, multiples).FirstOrDefault(), | |||
| input); | |||
| public static Tensor transpose<T1, T2>(T1 x, T2 perm, string name = null) | |||
| { | |||
| @@ -510,37 +500,29 @@ namespace Tensorflow | |||
| int new_axis_mask = 0, | |||
| int shrink_axis_mask = 0, | |||
| string name = null) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "StridedSlice", name, | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("StridedSlice", name, new | |||
| { | |||
| input, | |||
| begin, | |||
| end, | |||
| strides, | |||
| begin_mask, | |||
| end_mask, | |||
| ellipsis_mask, | |||
| new_axis_mask, | |||
| shrink_axis_mask | |||
| }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "StridedSlice", name, | |||
| null, | |||
| input, begin, end, strides, | |||
| "begin_mask", begin_mask, | |||
| "end_mask", end_mask, | |||
| "ellipsis_mask", ellipsis_mask, | |||
| "new_axis_mask", new_axis_mask, | |||
| "shrink_axis_mask", shrink_axis_mask); | |||
| return results[0]; | |||
| } | |||
| var _op = tf.OpDefLib._apply_op_helper("StridedSlice", name, new | |||
| { | |||
| input, | |||
| begin, | |||
| end, | |||
| strides, | |||
| begin_mask, | |||
| end_mask, | |||
| ellipsis_mask, | |||
| new_axis_mask, | |||
| shrink_axis_mask | |||
| }); | |||
| return _op.outputs[0]; | |||
| } | |||
| "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | |||
| input, begin, end, strides); | |||
| public static Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] strides, | |||
| int begin_mask = 0, | |||
| @@ -319,21 +319,13 @@ namespace Tensorflow | |||
| /// Specifically, <c>y = 1 / (1 + exp(-x))</c>. | |||
| /// </remarks> | |||
| public static Tensor sigmoid(Tensor x, string name = "Sigmoid") | |||
| { | |||
| if (tf.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Sigmoid", name, | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Sigmoid", name: name, new { x }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Sigmoid", name, | |||
| null, | |||
| x); | |||
| return results[0]; | |||
| } | |||
| var op = tf.OpDefLib._apply_op_helper("Sigmoid", name: name, new { x }); | |||
| return op.output; | |||
| } | |||
| x).FirstOrDefault(), | |||
| x); | |||
| /// <summary> | |||
| /// Computes the gradient of the sigmoid of <c>x</c> wrt its input. | |||
| @@ -668,11 +660,13 @@ namespace Tensorflow | |||
| /// <param name="name"> A name for the operation (optional).</param> | |||
| /// <returns> A `Tensor`. Has the same type as `x`.</returns> | |||
| public static Tensor exp(Tensor x, string name = null) | |||
| { | |||
| var _op = tf.OpDefLib._apply_op_helper("Exp", name, args: new { x }); | |||
| return _op.outputs[0]; | |||
| } | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Exp", name, args: new { x }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Exp", name, | |||
| null, | |||
| x).FirstOrDefault(), | |||
| x); | |||
| /// <summary> | |||
| /// Computes natural logarithm of x element-wise. | |||
| @@ -698,22 +692,14 @@ namespace Tensorflow | |||
| } | |||
| public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate= false, string name= null) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Cast", name, | |||
| null, | |||
| x, | |||
| "DstT", DstT, "Truncate", Truncate); | |||
| return results[0]; | |||
| } | |||
| var _op = tf.OpDefLib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }); | |||
| return _op.outputs[0]; | |||
| } | |||
| "DstT", DstT, "Truncate", Truncate).FirstOrDefault(), | |||
| x); | |||
| public static Tensor neg(Tensor x, string name = null) | |||
| { | |||
| @@ -1151,20 +1137,13 @@ namespace Tensorflow | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor range(Tensor start, Tensor limit, Tensor delta, string name = null) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Range", name, | |||
| => tf.Context.RunInAutoMode(() | |||
| => tf.OpDefLib._apply_op_helper("Range", name, new { start, limit, delta }).output, () | |||
| => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "Range", name, | |||
| null, | |||
| start, limit, delta); | |||
| return results[0]; | |||
| } | |||
| var _op = tf.OpDefLib._apply_op_helper("Range", name, new { start, limit, delta }); | |||
| return _op.outputs[0]; | |||
| } | |||
| start, limit, delta).FirstOrDefault(), | |||
| start, limit, delta); | |||
| /// <summary> | |||
| /// Rounds the values of a tensor to the nearest integer, element-wise. | |||
| @@ -225,14 +225,12 @@ namespace Tensorflow | |||
| public static string name_from_scope_name(string name) | |||
| { | |||
| if (name.EndsWith("/")) | |||
| { | |||
| if (name == null) | |||
| return null; | |||
| else if (name.EndsWith("/")) | |||
| return name.Substring(0, name.Length - 1); | |||
| } | |||
| else | |||
| { | |||
| return name; | |||
| } | |||
| } | |||
| /// <summary> | |||
| @@ -444,7 +442,12 @@ namespace Tensorflow | |||
| case NDArray nd: | |||
| return constant_op.constant(nd, dtype: dtype, name: name); | |||
| case EagerTensor tensor: | |||
| return tf.executing_eagerly() ? tensor : tensor.AsPlaceholder(name: name); | |||
| if (tf.executing_eagerly()) | |||
| return tensor; | |||
| else | |||
| return tensor.dtype == TF_DataType.TF_RESOURCE | |||
| ? tensor.AsPlaceholder(name: name) | |||
| : tensor.AsContatnt(name: name); | |||
| case Tensor tensor: | |||
| return tensor; | |||
| case Tensor[] tensors: | |||
| @@ -48,13 +48,13 @@ namespace Tensorflow | |||
| public void __enter__() | |||
| { | |||
| _name = _name ?? _default_name; | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| (scope_name, old_scope_name) = enter_eager_name_scope(tf.Context, _name); | |||
| } | |||
| else | |||
| { | |||
| _name = _name ?? _default_name; | |||
| Graph g = null; | |||
| if (_values is List<Tensor> vList) | |||
| @@ -72,7 +72,8 @@ namespace Tensorflow | |||
| private (string, string) enter_eager_name_scope(Context ctx, string name) | |||
| { | |||
| if (name == null) | |||
| return (null, null); | |||
| /*if (name == null) | |||
| name = ""; | |||
| var scope_name = name; | |||
| @@ -87,7 +88,7 @@ namespace Tensorflow | |||
| } | |||
| ctx.ScopeName = scope_name; | |||
| return (scope_name, old_name); | |||
| return (scope_name, old_name);*/ | |||
| } | |||
| [DebuggerHidden] | |||