| @@ -26,6 +26,20 @@ namespace Tensorflow | |||
| name: name); | |||
| public static IActivation relu => new relu(); | |||
| public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x, | |||
| RefVariable scale, | |||
| RefVariable offset, | |||
| Tensor mean = null, | |||
| Tensor variance = null, | |||
| float epsilon = 0.001f, | |||
| string data_format = "NHWC", | |||
| bool is_training = true, | |||
| string name = null) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance, | |||
| epsilon: epsilon, | |||
| data_format: data_format, | |||
| is_training: is_training, | |||
| name: name); | |||
| } | |||
| } | |||
| } | |||
| @@ -6,7 +6,10 @@ namespace Tensorflow.Framework | |||
| { | |||
| public class smart_module | |||
| { | |||
| public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null) | |||
| public static object smart_cond(Tensor pred, | |||
| Func<(Tensor, Tensor, Tensor)> true_fn = null, | |||
| Func<(Tensor, Tensor, Tensor)> false_fn = null, | |||
| string name = null) | |||
| { | |||
| return control_flow_ops.cond(pred, | |||
| true_fn: true_fn, | |||
| @@ -8,7 +8,7 @@ namespace Tensorflow | |||
| { | |||
| public partial class Graph | |||
| { | |||
| public Context _control_flow_context; | |||
| public IControlFlowContext _control_flow_context; | |||
| private Queue<_ControlDependenciesController> _graph_control_dependencies_stack = new Queue<_ControlDependenciesController>(); | |||
| public Queue<_ControlDependenciesController> _control_dependencies_stack | |||
| @@ -72,7 +72,7 @@ namespace Tensorflow | |||
| /// Returns the current control flow context. | |||
| /// </summary> | |||
| /// <returns>A context object.</returns> | |||
| public Context _get_control_flow_context() | |||
| public IControlFlowContext _get_control_flow_context() | |||
| { | |||
| return _control_flow_context; | |||
| } | |||
| @@ -81,7 +81,7 @@ namespace Tensorflow | |||
| /// Sets the current control flow context. | |||
| /// </summary> | |||
| /// <param name="ctx">a context object.</param> | |||
| public void _set_control_flow_context(Context ctx) | |||
| public void _set_control_flow_context(IControlFlowContext ctx) | |||
| { | |||
| _control_flow_context = ctx; | |||
| } | |||
| @@ -15,7 +15,7 @@ namespace Tensorflow | |||
| private List<ITensorOrOperation> _seen_nodes; | |||
| private Queue<_ControlDependenciesController> _old_stack; | |||
| private bool _new_stack; | |||
| private Context _old_control_flow_context; | |||
| private IControlFlowContext _old_control_flow_context; | |||
| public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); | |||
| @@ -142,14 +142,27 @@ namespace Tensorflow.Keras.Layers | |||
| var beta = this.beta; | |||
| var gamma = this.gamma; | |||
| Action _fused_batch_norm_training = () => | |||
| Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_training = () => | |||
| { | |||
| return tf.nn.fused_batch_norm( | |||
| inputs, | |||
| gamma, | |||
| beta, | |||
| epsilon: epsilon, | |||
| data_format: _data_format); | |||
| }; | |||
| Action _fused_batch_norm_inference = () => | |||
| Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_inference = () => | |||
| { | |||
| return tf.nn.fused_batch_norm( | |||
| inputs, | |||
| gamma, | |||
| beta, | |||
| mean: moving_mean, | |||
| variance: moving_variance, | |||
| epsilon: epsilon, | |||
| is_training: false, | |||
| data_format: _data_format); | |||
| }; | |||
| tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference); | |||
| @@ -18,7 +18,10 @@ namespace Tensorflow.Keras.Utils | |||
| return true; | |||
| } | |||
| public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null) | |||
| public static object smart_cond(Tensor pred, | |||
| Func<(Tensor, Tensor, Tensor)> true_fn = null, | |||
| Func<(Tensor, Tensor, Tensor)> false_fn = null, | |||
| string name = null) | |||
| { | |||
| return smart_module.smart_cond(pred, | |||
| true_fn: true_fn, | |||
| @@ -0,0 +1,76 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| /// <summary> | |||
| /// The context for the conditional construct. | |||
| /// </summary> | |||
| public class CondContext : ControlFlowContext | |||
| { | |||
| private string _name; | |||
| /// <summary> | |||
| /// The boolean tensor for the cond predicate | |||
| /// </summary> | |||
| private Tensor _pred; | |||
| /// <summary> | |||
| /// The predicate tensor in this branch | |||
| /// </summary> | |||
| private Tensor _pivot; | |||
| /// <summary> | |||
| /// 0 or 1 representing this branch | |||
| /// </summary> | |||
| private int _branch; | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| private List<string> _values = new List<string>(); | |||
| private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>(); | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="pred">The `boolean` tensor for the conditional predicate.</param> | |||
| /// <param name="pivot">The predicate tensor in this branch.</param> | |||
| /// <param name="branch">0 or 1 representing this branch.</param> | |||
| /// <param name="name">Name of the `CondContext` python object.</param> | |||
| /// <param name="context_def"></param> | |||
| /// <param name="import_scope"></param> | |||
| public CondContext(Tensor pred, | |||
| Tensor pivot, | |||
| int branch, | |||
| string name = "cond_text", | |||
| object context_def = null, | |||
| string import_scope = null) | |||
| { | |||
| _name = ops.get_default_graph().unique_name(name); | |||
| if (context_def != null) | |||
| throw new NotImplementedException("CondContext context_def is not null"); | |||
| else | |||
| { | |||
| // Initializes the default fields. | |||
| base.__init__(); | |||
| _pred = pred; | |||
| _pivot = pivot; | |||
| // Values considered to have been already seen in this context. pred is not | |||
| // included in this context. | |||
| _values.Add(pred.name); | |||
| _external_values[pred.name] = pred; | |||
| _values.Add(pivot.name); | |||
| pivot.op._set_control_flow_context(this); | |||
| } | |||
| } | |||
| public (Tensor, Tensor, Tensor) BuildCondBranch(Func<(Tensor, Tensor, Tensor)> fn) | |||
| { | |||
| // Add the subgraph defined by fn() to the graph. | |||
| var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
| var original_result = fn(); | |||
| var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
| return original_result; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,46 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| public abstract class ControlFlowContext : IPython, IControlFlowContext | |||
| { | |||
| protected Stack<IControlFlowContext> _context_stack; | |||
| public ControlFlowContext() | |||
| { | |||
| _context_stack = new Stack<IControlFlowContext>(); | |||
| } | |||
| public void __init__() | |||
| { | |||
| } | |||
| public void __enter__() | |||
| { | |||
| } | |||
| public virtual void Enter() | |||
| { | |||
| var graph = ops.get_default_graph(); | |||
| _context_stack.Push(graph._get_control_flow_context()); | |||
| graph._set_control_flow_context(this); | |||
| } | |||
| public void Exit() | |||
| { | |||
| var graph = ops.get_default_graph(); | |||
| var last_context = _context_stack.Pop(); | |||
| graph._set_control_flow_context(last_context); | |||
| } | |||
| public void __exit__() | |||
| { | |||
| } | |||
| public void Dispose() | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public interface IControlFlowContext | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| public class WhileContext : ControlFlowContext | |||
| { | |||
| } | |||
| } | |||
| @@ -52,5 +52,30 @@ namespace Tensorflow.Operations | |||
| return _op.outputs[0]; | |||
| } | |||
| public static (Tensor, Tensor, Tensor) _fused_batch_norm(Tensor x, | |||
| Tensor scale, | |||
| Tensor offset, | |||
| Tensor mean, | |||
| Tensor variance, | |||
| float epsilon = 0.0001f, | |||
| string data_format = "NHWC", | |||
| bool is_training = true, | |||
| string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("FusedBatchNorm", name: name, args: new | |||
| { | |||
| x, | |||
| scale, | |||
| offset, | |||
| mean, | |||
| variance, | |||
| epsilon, | |||
| data_format, | |||
| is_training | |||
| }); | |||
| return (_op.outputs[0], _op.outputs[1], _op.outputs[2]); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,11 +1,14 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class Operation | |||
| { | |||
| private CondContext _control_flow_context; | |||
| /// <summary> | |||
| /// Add this op to its control flow context. | |||
| /// </summary> | |||
| @@ -24,5 +27,10 @@ namespace Tensorflow | |||
| c_api.TF_AddControlInput(graph, op); | |||
| } | |||
| } | |||
| public void _set_control_flow_context(CondContext ctx) | |||
| { | |||
| _control_flow_context = ctx; | |||
| } | |||
| } | |||
| } | |||
| @@ -2,6 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -136,9 +137,9 @@ namespace Tensorflow | |||
| return gen_array_ops.identity(data, name: name); | |||
| } | |||
| public static (Tensor, Tensor) cond(Tensor pred, | |||
| Action true_fn = null, | |||
| Action false_fn = null, | |||
| public static (Tensor, Tensor) cond(Tensor pred, | |||
| Func<(Tensor, Tensor, Tensor)> true_fn = null, | |||
| Func<(Tensor, Tensor, Tensor)> false_fn = null, | |||
| bool strict = false, | |||
| string name = null) | |||
| { | |||
| @@ -154,6 +155,22 @@ namespace Tensorflow | |||
| foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred }) | |||
| tensor.op.graph.prevent_fetching(tensor.op); | |||
| // Build the graph for the true branch in a new context. | |||
| var context_t = new CondContext(pred, pivot_1, branch: 1); | |||
| context_t.Enter(); | |||
| var res_t = context_t.BuildCondBranch(true_fn); | |||
| context_t.Exit(); | |||
| // Build the graph for the false branch in a new context. | |||
| var context_f = new CondContext(pred, pivot_2, branch: 0); | |||
| context_f.Enter(); | |||
| var res_f = context_f.BuildCondBranch(false_fn); | |||
| context_f.Exit(); | |||
| var res_t_flat = new Tensor[] { res_t.Item1, res_t.Item2, res_t.Item3 }; | |||
| var res_f_flat = new Tensor[] { res_f.Item1, res_f.Item2, res_f.Item3 }; | |||
| return (p_2, p_1); | |||
| }); | |||
| } | |||
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -44,5 +45,36 @@ namespace Tensorflow | |||
| return (mean, variance); | |||
| }); | |||
| } | |||
| public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x, | |||
| RefVariable scale, | |||
| RefVariable offset, | |||
| Tensor mean, | |||
| Tensor variance, | |||
| float epsilon = 0.001f, | |||
| string data_format = "NHWC", | |||
| bool is_training = true, | |||
| string name = null) | |||
| { | |||
| x = ops.convert_to_tensor(x, name: "input"); | |||
| var scale_tensor = ops.convert_to_tensor(scale, name: "scale"); | |||
| var offset_tensor = ops.convert_to_tensor(offset, name: "offset"); | |||
| if (mean == null) | |||
| mean = constant_op.constant(new float[0]); | |||
| if(variance == null) | |||
| variance = constant_op.constant(new float[0]); | |||
| var min_epsilon = 1.001e-5f; | |||
| epsilon = epsilon > min_epsilon ? epsilon : min_epsilon; | |||
| return gen_nn_ops._fused_batch_norm(x, | |||
| scale_tensor, | |||
| offset_tensor, | |||
| mean, | |||
| variance, | |||
| epsilon, | |||
| data_format, | |||
| is_training, | |||
| name); | |||
| } | |||
| } | |||
| } | |||
| @@ -107,6 +107,9 @@ namespace Tensorflow | |||
| case float floatVal: | |||
| nparray = floatVal; | |||
| break; | |||
| case float[] floatVals: | |||
| nparray = floatVals; | |||
| break; | |||
| case double doubleVal: | |||
| nparray = doubleVal; | |||
| break; | |||
| @@ -44,6 +44,9 @@ namespace Tensorflow | |||
| /// Key to collect update_ops | |||
| /// </summary> | |||
| public static string UPDATE_OPS = "update_ops"; | |||
| // Used to store v2 summary names. | |||
| public static string _SUMMARY_COLLECTION = "_SUMMARY_V2"; | |||
| } | |||
| } | |||
| } | |||