| @@ -26,6 +26,20 @@ namespace Tensorflow | |||||
| name: name); | name: name); | ||||
| public static IActivation relu => new relu(); | public static IActivation relu => new relu(); | ||||
| public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x, | |||||
| RefVariable scale, | |||||
| RefVariable offset, | |||||
| Tensor mean = null, | |||||
| Tensor variance = null, | |||||
| float epsilon = 0.001f, | |||||
| string data_format = "NHWC", | |||||
| bool is_training = true, | |||||
| string name = null) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance, | |||||
| epsilon: epsilon, | |||||
| data_format: data_format, | |||||
| is_training: is_training, | |||||
| name: name); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -6,7 +6,10 @@ namespace Tensorflow.Framework | |||||
| { | { | ||||
| public class smart_module | public class smart_module | ||||
| { | { | ||||
| public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null) | |||||
| public static object smart_cond(Tensor pred, | |||||
| Func<(Tensor, Tensor, Tensor)> true_fn = null, | |||||
| Func<(Tensor, Tensor, Tensor)> false_fn = null, | |||||
| string name = null) | |||||
| { | { | ||||
| return control_flow_ops.cond(pred, | return control_flow_ops.cond(pred, | ||||
| true_fn: true_fn, | true_fn: true_fn, | ||||
| @@ -8,7 +8,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class Graph | public partial class Graph | ||||
| { | { | ||||
| public Context _control_flow_context; | |||||
| public IControlFlowContext _control_flow_context; | |||||
| private Queue<_ControlDependenciesController> _graph_control_dependencies_stack = new Queue<_ControlDependenciesController>(); | private Queue<_ControlDependenciesController> _graph_control_dependencies_stack = new Queue<_ControlDependenciesController>(); | ||||
| public Queue<_ControlDependenciesController> _control_dependencies_stack | public Queue<_ControlDependenciesController> _control_dependencies_stack | ||||
| @@ -72,7 +72,7 @@ namespace Tensorflow | |||||
| /// Returns the current control flow context. | /// Returns the current control flow context. | ||||
| /// </summary> | /// </summary> | ||||
| /// <returns>A context object.</returns> | /// <returns>A context object.</returns> | ||||
| public Context _get_control_flow_context() | |||||
| public IControlFlowContext _get_control_flow_context() | |||||
| { | { | ||||
| return _control_flow_context; | return _control_flow_context; | ||||
| } | } | ||||
| @@ -81,7 +81,7 @@ namespace Tensorflow | |||||
| /// Sets the current control flow context. | /// Sets the current control flow context. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="ctx">a context object.</param> | /// <param name="ctx">a context object.</param> | ||||
| public void _set_control_flow_context(Context ctx) | |||||
| public void _set_control_flow_context(IControlFlowContext ctx) | |||||
| { | { | ||||
| _control_flow_context = ctx; | _control_flow_context = ctx; | ||||
| } | } | ||||
| @@ -15,7 +15,7 @@ namespace Tensorflow | |||||
| private List<ITensorOrOperation> _seen_nodes; | private List<ITensorOrOperation> _seen_nodes; | ||||
| private Queue<_ControlDependenciesController> _old_stack; | private Queue<_ControlDependenciesController> _old_stack; | ||||
| private bool _new_stack; | private bool _new_stack; | ||||
| private Context _old_control_flow_context; | |||||
| private IControlFlowContext _old_control_flow_context; | |||||
| public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); | public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); | ||||
| @@ -142,14 +142,27 @@ namespace Tensorflow.Keras.Layers | |||||
| var beta = this.beta; | var beta = this.beta; | ||||
| var gamma = this.gamma; | var gamma = this.gamma; | ||||
| Action _fused_batch_norm_training = () => | |||||
| Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_training = () => | |||||
| { | { | ||||
| return tf.nn.fused_batch_norm( | |||||
| inputs, | |||||
| gamma, | |||||
| beta, | |||||
| epsilon: epsilon, | |||||
| data_format: _data_format); | |||||
| }; | }; | ||||
| Action _fused_batch_norm_inference = () => | |||||
| Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_inference = () => | |||||
| { | { | ||||
| return tf.nn.fused_batch_norm( | |||||
| inputs, | |||||
| gamma, | |||||
| beta, | |||||
| mean: moving_mean, | |||||
| variance: moving_variance, | |||||
| epsilon: epsilon, | |||||
| is_training: false, | |||||
| data_format: _data_format); | |||||
| }; | }; | ||||
| tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference); | tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference); | ||||
| @@ -18,7 +18,10 @@ namespace Tensorflow.Keras.Utils | |||||
| return true; | return true; | ||||
| } | } | ||||
| public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null) | |||||
| public static object smart_cond(Tensor pred, | |||||
| Func<(Tensor, Tensor, Tensor)> true_fn = null, | |||||
| Func<(Tensor, Tensor, Tensor)> false_fn = null, | |||||
| string name = null) | |||||
| { | { | ||||
| return smart_module.smart_cond(pred, | return smart_module.smart_cond(pred, | ||||
| true_fn: true_fn, | true_fn: true_fn, | ||||
| @@ -0,0 +1,76 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Operations | |||||
| { | |||||
| /// <summary> | |||||
| /// The context for the conditional construct. | |||||
| /// </summary> | |||||
| public class CondContext : ControlFlowContext | |||||
| { | |||||
| private string _name; | |||||
| /// <summary> | |||||
| /// The boolean tensor for the cond predicate | |||||
| /// </summary> | |||||
| private Tensor _pred; | |||||
| /// <summary> | |||||
| /// The predicate tensor in this branch | |||||
| /// </summary> | |||||
| private Tensor _pivot; | |||||
| /// <summary> | |||||
| /// 0 or 1 representing this branch | |||||
| /// </summary> | |||||
| private int _branch; | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| private List<string> _values = new List<string>(); | |||||
| private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>(); | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="pred">The `boolean` tensor for the conditional predicate.</param> | |||||
| /// <param name="pivot">The predicate tensor in this branch.</param> | |||||
| /// <param name="branch">0 or 1 representing this branch.</param> | |||||
| /// <param name="name">Name of the `CondContext` python object.</param> | |||||
| /// <param name="context_def"></param> | |||||
| /// <param name="import_scope"></param> | |||||
| public CondContext(Tensor pred, | |||||
| Tensor pivot, | |||||
| int branch, | |||||
| string name = "cond_text", | |||||
| object context_def = null, | |||||
| string import_scope = null) | |||||
| { | |||||
| _name = ops.get_default_graph().unique_name(name); | |||||
| if (context_def != null) | |||||
| throw new NotImplementedException("CondContext context_def is not null"); | |||||
| else | |||||
| { | |||||
| // Initializes the default fields. | |||||
| base.__init__(); | |||||
| _pred = pred; | |||||
| _pivot = pivot; | |||||
| // Values considered to have been already seen in this context. pred is not | |||||
| // included in this context. | |||||
| _values.Add(pred.name); | |||||
| _external_values[pred.name] = pred; | |||||
| _values.Add(pivot.name); | |||||
| pivot.op._set_control_flow_context(this); | |||||
| } | |||||
| } | |||||
| public (Tensor, Tensor, Tensor) BuildCondBranch(Func<(Tensor, Tensor, Tensor)> fn) | |||||
| { | |||||
| // Add the subgraph defined by fn() to the graph. | |||||
| var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
| var original_result = fn(); | |||||
| var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
| return original_result; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,46 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Operations | |||||
| { | |||||
| public abstract class ControlFlowContext : IPython, IControlFlowContext | |||||
| { | |||||
| protected Stack<IControlFlowContext> _context_stack; | |||||
| public ControlFlowContext() | |||||
| { | |||||
| _context_stack = new Stack<IControlFlowContext>(); | |||||
| } | |||||
| public void __init__() | |||||
| { | |||||
| } | |||||
| public void __enter__() | |||||
| { | |||||
| } | |||||
| public virtual void Enter() | |||||
| { | |||||
| var graph = ops.get_default_graph(); | |||||
| _context_stack.Push(graph._get_control_flow_context()); | |||||
| graph._set_control_flow_context(this); | |||||
| } | |||||
| public void Exit() | |||||
| { | |||||
| var graph = ops.get_default_graph(); | |||||
| var last_context = _context_stack.Pop(); | |||||
| graph._set_control_flow_context(last_context); | |||||
| } | |||||
| public void __exit__() | |||||
| { | |||||
| } | |||||
| public void Dispose() | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public interface IControlFlowContext | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Operations | |||||
| { | |||||
| public class WhileContext : ControlFlowContext | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -52,5 +52,30 @@ namespace Tensorflow.Operations | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static (Tensor, Tensor, Tensor) _fused_batch_norm(Tensor x, | |||||
| Tensor scale, | |||||
| Tensor offset, | |||||
| Tensor mean, | |||||
| Tensor variance, | |||||
| float epsilon = 0.0001f, | |||||
| string data_format = "NHWC", | |||||
| bool is_training = true, | |||||
| string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("FusedBatchNorm", name: name, args: new | |||||
| { | |||||
| x, | |||||
| scale, | |||||
| offset, | |||||
| mean, | |||||
| variance, | |||||
| epsilon, | |||||
| data_format, | |||||
| is_training | |||||
| }); | |||||
| return (_op.outputs[0], _op.outputs[1], _op.outputs[2]); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,11 +1,14 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Operations; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public partial class Operation | public partial class Operation | ||||
| { | { | ||||
| private CondContext _control_flow_context; | |||||
| /// <summary> | /// <summary> | ||||
| /// Add this op to its control flow context. | /// Add this op to its control flow context. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -24,5 +27,10 @@ namespace Tensorflow | |||||
| c_api.TF_AddControlInput(graph, op); | c_api.TF_AddControlInput(graph, op); | ||||
| } | } | ||||
| } | } | ||||
| public void _set_control_flow_context(CondContext ctx) | |||||
| { | |||||
| _control_flow_context = ctx; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -2,6 +2,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Operations; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -136,9 +137,9 @@ namespace Tensorflow | |||||
| return gen_array_ops.identity(data, name: name); | return gen_array_ops.identity(data, name: name); | ||||
| } | } | ||||
| public static (Tensor, Tensor) cond(Tensor pred, | |||||
| Action true_fn = null, | |||||
| Action false_fn = null, | |||||
| public static (Tensor, Tensor) cond(Tensor pred, | |||||
| Func<(Tensor, Tensor, Tensor)> true_fn = null, | |||||
| Func<(Tensor, Tensor, Tensor)> false_fn = null, | |||||
| bool strict = false, | bool strict = false, | ||||
| string name = null) | string name = null) | ||||
| { | { | ||||
| @@ -154,6 +155,22 @@ namespace Tensorflow | |||||
| foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred }) | foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred }) | ||||
| tensor.op.graph.prevent_fetching(tensor.op); | tensor.op.graph.prevent_fetching(tensor.op); | ||||
| // Build the graph for the true branch in a new context. | |||||
| var context_t = new CondContext(pred, pivot_1, branch: 1); | |||||
| context_t.Enter(); | |||||
| var res_t = context_t.BuildCondBranch(true_fn); | |||||
| context_t.Exit(); | |||||
| // Build the graph for the false branch in a new context. | |||||
| var context_f = new CondContext(pred, pivot_2, branch: 0); | |||||
| context_f.Enter(); | |||||
| var res_f = context_f.BuildCondBranch(false_fn); | |||||
| context_f.Exit(); | |||||
| var res_t_flat = new Tensor[] { res_t.Item1, res_t.Item2, res_t.Item3 }; | |||||
| var res_f_flat = new Tensor[] { res_f.Item1, res_f.Item2, res_f.Item3 }; | |||||
| return (p_2, p_1); | return (p_2, p_1); | ||||
| }); | }); | ||||
| } | } | ||||
| @@ -1,6 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Operations; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -44,5 +45,36 @@ namespace Tensorflow | |||||
| return (mean, variance); | return (mean, variance); | ||||
| }); | }); | ||||
| } | } | ||||
| public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x, | |||||
| RefVariable scale, | |||||
| RefVariable offset, | |||||
| Tensor mean, | |||||
| Tensor variance, | |||||
| float epsilon = 0.001f, | |||||
| string data_format = "NHWC", | |||||
| bool is_training = true, | |||||
| string name = null) | |||||
| { | |||||
| x = ops.convert_to_tensor(x, name: "input"); | |||||
| var scale_tensor = ops.convert_to_tensor(scale, name: "scale"); | |||||
| var offset_tensor = ops.convert_to_tensor(offset, name: "offset"); | |||||
| if (mean == null) | |||||
| mean = constant_op.constant(new float[0]); | |||||
| if(variance == null) | |||||
| variance = constant_op.constant(new float[0]); | |||||
| var min_epsilon = 1.001e-5f; | |||||
| epsilon = epsilon > min_epsilon ? epsilon : min_epsilon; | |||||
| return gen_nn_ops._fused_batch_norm(x, | |||||
| scale_tensor, | |||||
| offset_tensor, | |||||
| mean, | |||||
| variance, | |||||
| epsilon, | |||||
| data_format, | |||||
| is_training, | |||||
| name); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -107,6 +107,9 @@ namespace Tensorflow | |||||
| case float floatVal: | case float floatVal: | ||||
| nparray = floatVal; | nparray = floatVal; | ||||
| break; | break; | ||||
| case float[] floatVals: | |||||
| nparray = floatVals; | |||||
| break; | |||||
| case double doubleVal: | case double doubleVal: | ||||
| nparray = doubleVal; | nparray = doubleVal; | ||||
| break; | break; | ||||
| @@ -44,6 +44,9 @@ namespace Tensorflow | |||||
| /// Key to collect update_ops | /// Key to collect update_ops | ||||
| /// </summary> | /// </summary> | ||||
| public static string UPDATE_OPS = "update_ops"; | public static string UPDATE_OPS = "update_ops"; | ||||
| // Used to store v2 summary names. | |||||
| public static string _SUMMARY_COLLECTION = "_SUMMARY_V2"; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||