| @@ -9,4 +9,4 @@ community_bridge: # Replace with a single Community Bridge project-name e.g., cl | |||||
| liberapay: # Replace with a single Liberapay username | liberapay: # Replace with a single Liberapay username | ||||
| issuehunt: # Replace with a single IssueHunt username | issuehunt: # Replace with a single IssueHunt username | ||||
| otechie: # Replace with a single Otechie username | otechie: # Replace with a single Otechie username | ||||
| custom: ['https://paypal.me/pools/c/8fK9eKwbbL']# Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] | |||||
| custom: ['https://bit.ly/2op1mu5']# Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] | |||||
| @@ -43,7 +43,8 @@ namespace Tensorflow | |||||
| public ExponentialMovingAverage ExponentialMovingAverage(float decay) | public ExponentialMovingAverage ExponentialMovingAverage(float decay) | ||||
| => new ExponentialMovingAverage(decay); | => new ExponentialMovingAverage(decay); | ||||
| public Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list); | |||||
| public Saver Saver(VariableV1[] var_list = null, int max_to_keep = 5) | |||||
| => new Saver(var_list: var_list, max_to_keep: max_to_keep); | |||||
| public string write_graph(Graph graph, string logdir, string name, bool as_text = true) | public string write_graph(Graph graph, string logdir, string name, bool as_text = true) | ||||
| => graph_io.write_graph(graph, logdir, name, as_text); | => graph_io.write_graph(graph, logdir, name, as_text); | ||||
| @@ -54,7 +55,7 @@ namespace Tensorflow | |||||
| clear_devices, | clear_devices, | ||||
| import_scope).Item1; | import_scope).Item1; | ||||
| public (MetaGraphDef, Dictionary<string, RefVariable>) export_meta_graph(string filename = "", | |||||
| public (MetaGraphDef, Dictionary<string, VariableV1>) export_meta_graph(string filename = "", | |||||
| bool as_text = false, | bool as_text = false, | ||||
| bool clear_devices = false, | bool clear_devices = false, | ||||
| bool clear_extraneous_savers = false, | bool clear_extraneous_savers = false, | ||||
| @@ -167,7 +167,7 @@ namespace Tensorflow | |||||
| /// <param name="strip_default_attrs"></param> | /// <param name="strip_default_attrs"></param> | ||||
| /// <param name="meta_info_def"></param> | /// <param name="meta_info_def"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static (MetaGraphDef, Dictionary<string, RefVariable>) export_scoped_meta_graph(string filename = "", | |||||
| public static (MetaGraphDef, Dictionary<string, VariableV1>) export_scoped_meta_graph(string filename = "", | |||||
| GraphDef graph_def = null, | GraphDef graph_def = null, | ||||
| bool as_text = false, | bool as_text = false, | ||||
| string unbound_inputs_col_name = "unbound_inputs", | string unbound_inputs_col_name = "unbound_inputs", | ||||
| @@ -179,8 +179,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| var graph = ops.get_default_graph(); | var graph = ops.get_default_graph(); | ||||
| var var_list = new Dictionary<string, RefVariable>(); | |||||
| var variables = graph.get_collection<RefVariable>(tf.GraphKeys.GLOBAL_VARIABLES); | |||||
| var var_list = new Dictionary<string, VariableV1>(); | |||||
| var variables = graph.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES); | |||||
| if (variables != null) | if (variables != null) | ||||
| { | { | ||||
| @@ -190,6 +190,26 @@ namespace Tensorflow.Gradients | |||||
| return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null }; | return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null }; | ||||
| } | } | ||||
| [RegisterGradient("Pad")] | |||||
| public static Tensor[] _PadGrad(Operation op, Tensor[] grads) | |||||
| { | |||||
| var grad = grads[0]; | |||||
| var x = op.inputs[0]; | |||||
| var a = op.inputs[1]; | |||||
| var size = array_ops.stack(new object[] { array_ops.rank(x), 1 }); | |||||
| var pad_before = array_ops.slice(a, new[] { 0, 0 }, size); | |||||
| // Make it a 1-D tensor. | |||||
| var begin = array_ops.reshape(pad_before, new[] { -1 }); | |||||
| var sizes = array_ops.shape(x); | |||||
| var x_grad = array_ops.slice(grad, begin, sizes); | |||||
| if (len(op.inputs) == 3) | |||||
| return new Tensor[] { x_grad, null, null }; | |||||
| else | |||||
| return new Tensor[] { x_grad, null }; | |||||
| } | |||||
| [RegisterGradient("Squeeze")] | [RegisterGradient("Squeeze")] | ||||
| public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads) | public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads) | ||||
| { | { | ||||
| @@ -36,56 +36,54 @@ namespace Tensorflow.Gradients | |||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [RegisterGradient("Switch")] | [RegisterGradient("Switch")] | ||||
| public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads) | |||||
| public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads) | |||||
| { | { | ||||
| var grad = grads[0]; | |||||
| var graph = ops.get_default_graph(); | |||||
| var op_ctxt = op._get_control_flow_context(); | |||||
| var grad_ctxt = graph._get_control_flow_context(); | |||||
| switch (op_ctxt) | |||||
| { | |||||
| case WhileContext cwhile: | |||||
| throw new NotImplementedException("_SwitchGrad WhileContext"); | |||||
| case CondContext ccond: | |||||
| { | |||||
| var zero_grad = grads[1 - op_ctxt.branch]; | |||||
| // At this point, we have created zero_grad guarded by the right switch. | |||||
| // Unfortunately, we may still get None here for not trainable data types. | |||||
| if(zero_grad == null) | |||||
| { | |||||
| throw new NotImplementedException("_SwitchGrad CondContext zero_grad"); | |||||
| } | |||||
| return new Tensor[] | |||||
| { | |||||
| merge(grads, name: "cond_grad")[0], | |||||
| null | |||||
| }; | |||||
| } | |||||
| default: | |||||
| throw new NotImplementedException("_SwitchGrad WhileContext"); | |||||
| } | |||||
| throw new NotImplementedException("_SwitchGrad"); | throw new NotImplementedException("_SwitchGrad"); | ||||
| //graph = ops.get_default_graph() | |||||
| //# pylint: disable=protected-access | |||||
| //op_ctxt = op._get_control_flow_context() | |||||
| //grad_ctxt = graph._get_control_flow_context() | |||||
| //# pylint: enable=protected-access | |||||
| //if isinstance(op_ctxt, WhileContext): | |||||
| // merge_grad = grad_ctxt.grad_state.switch_map.get(op) | |||||
| // if merge_grad is not None: | |||||
| // # This is the second time this Switch is visited. It comes from | |||||
| // # the non-exit branch of the Switch, so update the second input | |||||
| // # to the Merge. | |||||
| // # TODO(yuanbyu): Perform shape inference with this new input. | |||||
| // if grad[1] is not None: | |||||
| // # pylint: disable=protected-access | |||||
| // control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1], | |||||
| // enforce_shape_invariant=False) | |||||
| // # pylint: enable=protected-access | |||||
| // return None, None | |||||
| // elif grad[0] is not None: | |||||
| // # This is the first time this Switch is visited. It comes from | |||||
| // # the Exit branch, which is grad[0]. grad[1] is empty at this point. | |||||
| // # Use grad[0] for both inputs to merge for now, but update the second | |||||
| // # input of merge when we see this Switch the second time. | |||||
| // merge_grad = merge([grad[0], grad[0]], name="b_switch")[0] | |||||
| // grad_ctxt.grad_state.switch_map[op] = merge_grad | |||||
| // return merge_grad, None | |||||
| // else: | |||||
| // # This is the first time this Switch is visited. It comes from the | |||||
| // # Identity branch. Such a Switch has `None` gradient for the Exit branch, | |||||
| // # meaning the output is not differentiable. | |||||
| // return None, None | |||||
| //elif isinstance(op_ctxt, CondContext): | |||||
| // zero_grad = grad[1 - op_ctxt.branch] | |||||
| // # At this point, we have created zero_grad guarded by the right switch. | |||||
| // # Unfortunately, we may still get None here for not trainable data types. | |||||
| // if zero_grad is None: | |||||
| // # For resource variables we get None always on the other branch, so bypass | |||||
| // # this. | |||||
| // if op.inputs[0].dtype == dtypes.resource: | |||||
| // return merge( | |||||
| // [grad[op_ctxt.branch]] * 2, name="cond_resource_grad")[0], None | |||||
| // return None, None | |||||
| // return merge(grad, name="cond_grad")[0], None | |||||
| //else: | |||||
| // false_grad = switch(grad[0], op.inputs[1])[0] | |||||
| // true_grad = switch(grad[1], op.inputs[1])[1] | |||||
| // return merge([false_grad, true_grad])[0], None | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns the value of an available element of `inputs`. | |||||
| /// </summary> | |||||
| /// <param name="inputs"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| internal static Tensor[] merge(Tensor[] inputs, string name = null) | |||||
| { | |||||
| return tf_with(ops.name_scope(name, "Merge", inputs), scope => | |||||
| { | |||||
| name = scope; | |||||
| if (inputs.Count(x => x.dtype.is_ref_dtype()) == inputs.Length) | |||||
| return gen_control_flow_ops.ref_merge(inputs, name: name); | |||||
| else | |||||
| return gen_control_flow_ops.merge(inputs, name: name); | |||||
| }); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -108,10 +108,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| // generate gradient subgraph for op. | // generate gradient subgraph for op. | ||||
| var op = queue.Dequeue(); | var op = queue.Dequeue(); | ||||
| if(tf.get_default_graph()._nodes_by_name.Count > 18505) | |||||
| { | |||||
| } | |||||
| _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); | _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops); | ||||
| //if (loop_state != null) | //if (loop_state != null) | ||||
| //loop_state.EnterGradWhileContext(op, before: true); | //loop_state.EnterGradWhileContext(op, before: true); | ||||
| @@ -216,7 +213,7 @@ namespace Tensorflow | |||||
| in_grad.Tag == null && // maybe a IndexedSlice | in_grad.Tag == null && // maybe a IndexedSlice | ||||
| t_in.dtype != TF_DataType.TF_RESOURCE) | t_in.dtype != TF_DataType.TF_RESOURCE) | ||||
| { | { | ||||
| in_grad.shape = t_in.shape; | |||||
| in_grad.set_shape(t_in.TensorShape); | |||||
| } | } | ||||
| _SetGrad(grads, t_in, in_grad); | _SetGrad(grads, t_in, in_grad); | ||||
| @@ -0,0 +1,54 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Framework; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Gradients | |||||
| { | |||||
| [RegisterGradient("image_grad")] | |||||
| public class image_grad | |||||
| { | |||||
| [RegisterGradient("ResizeNearestNeighbor")] | |||||
| public static Tensor[] _ResizeNearestNeighborGrad(Operation op, Tensor[] grads) | |||||
| { | |||||
| var grad = grads[0]; | |||||
| var image = op.inputs[0]; | |||||
| var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray()); | |||||
| Tensor image_shape = null; | |||||
| if (shape.is_fully_defined()) | |||||
| throw new NotImplementedException("_ResizeNearestNeighborGrad shape.is_fully_defined"); | |||||
| else | |||||
| image_shape = array_ops.shape(image)["1:3"]; | |||||
| grad = gen_image_ops.resize_nearest_neighbor_grad( | |||||
| grad, | |||||
| image_shape, | |||||
| align_corners: op.get_attr<bool>("align_corners"), | |||||
| half_pixel_centers: op.get_attr<bool>("half_pixel_centers")); | |||||
| return new Tensor[] | |||||
| { | |||||
| grad, | |||||
| null | |||||
| }; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -166,6 +166,94 @@ namespace Tensorflow.Gradients | |||||
| }; | }; | ||||
| } | } | ||||
| [RegisterGradient("FusedBatchNorm")] | |||||
| public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads) | |||||
| => _BaseFusedBatchNormGrad(op, 0, grads); | |||||
| /// <summary> | |||||
| /// Return the gradients for the 3 inputs of BatchNorm. | |||||
| /// </summary> | |||||
| /// <param name="op"></param> | |||||
| /// <param name="version"></param> | |||||
| /// <param name="grads"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor[] _BaseFusedBatchNormGrad(Operation op, int version, Tensor[] grads) | |||||
| { | |||||
| var x = op.inputs[0]; | |||||
| var grad_y = grads[0]; | |||||
| var scale = op.inputs[1]; | |||||
| var epsilon = op.get_attr<float>("epsilon"); | |||||
| var data_format = op.get_attr<string>("data_format"); | |||||
| var is_training = op.get_attr<bool>("is_training"); | |||||
| Func<FusedBatchNormParams, Tensor[]> grad_fun = null; | |||||
| switch (version) | |||||
| { | |||||
| case 2: | |||||
| throw new NotImplementedException(""); | |||||
| case 1: | |||||
| throw new NotImplementedException(""); | |||||
| default: | |||||
| grad_fun = gen_nn_ops.fused_batch_norm_grad; | |||||
| break; | |||||
| } | |||||
| if (is_training) | |||||
| { | |||||
| return grad_fun(new FusedBatchNormParams | |||||
| { | |||||
| YBackprop = grad_y, | |||||
| X = x, | |||||
| Scale = scale, | |||||
| ReserveSpace1 = op.outputs[3], | |||||
| ReserveSpace2 = op.outputs[4], | |||||
| ReserveSpace3 = version == 2 ? op.outputs[5] : null, | |||||
| Epsilon = epsilon, | |||||
| DataFormat = data_format, | |||||
| IsTraining = is_training | |||||
| }); | |||||
| } | |||||
| else | |||||
| { | |||||
| var pop_mean = op.inputs[3]; | |||||
| var pop_var = op.inputs[4]; | |||||
| if (data_format == "NCHW") | |||||
| throw new NotImplementedException(""); | |||||
| var results = grad_fun(new FusedBatchNormParams | |||||
| { | |||||
| YBackprop = grad_y, | |||||
| X = x, | |||||
| Scale = scale, | |||||
| ReserveSpace1 = op.outputs[3], | |||||
| ReserveSpace2 = op.outputs[4], | |||||
| ReserveSpace3 = version == 2 ? op.outputs[5] : null, | |||||
| Epsilon = epsilon, | |||||
| DataFormat = data_format, | |||||
| IsTraining = is_training | |||||
| }); | |||||
| var (dx, dscale, doffset) = (results[0], results[1], results[2]); | |||||
| if (data_format == "NCHW") | |||||
| throw new NotImplementedException(""); | |||||
| return new Tensor[] | |||||
| { | |||||
| dx, | |||||
| dscale, | |||||
| doffset, | |||||
| null, | |||||
| null | |||||
| }; | |||||
| } | |||||
| } | |||||
| [RegisterGradient("BatchNormWithGlobalNormalization")] | |||||
| public static Tensor _BatchNormWithGlobalNormalizationGrad(Operation op, Tensor[] grads) | |||||
| { | |||||
| throw new NotImplementedException("BatchNormWithGlobalNormalization"); | |||||
| } | |||||
| private static bool IsZero(Tensor g) | private static bool IsZero(Tensor g) | ||||
| { | { | ||||
| if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type)) | if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type)) | ||||
| @@ -440,6 +440,9 @@ namespace Tensorflow | |||||
| case List<VariableV1> list: | case List<VariableV1> list: | ||||
| t = list.Select(x => (T)(object)x).ToList(); | t = list.Select(x => (T)(object)x).ToList(); | ||||
| break; | break; | ||||
| case List<ResourceVariable> list: | |||||
| t = list.Select(x => (T)(object)x).ToList(); | |||||
| break; | |||||
| case List<RefVariable> list: | case List<RefVariable> list: | ||||
| t = list.Select(x => (T)(object)x).ToList(); | t = list.Select(x => (T)(object)x).ToList(); | ||||
| break; | break; | ||||
| @@ -27,20 +27,6 @@ namespace Tensorflow.Operations | |||||
| /// </summary> | /// </summary> | ||||
| public class CondContext : ControlFlowContext, IProtoBuf<CondContextDef, CondContext> | public class CondContext : ControlFlowContext, IProtoBuf<CondContextDef, CondContext> | ||||
| { | { | ||||
| /// <summary> | |||||
| /// The boolean tensor for the cond predicate | |||||
| /// </summary> | |||||
| private Tensor _pred; | |||||
| public Tensor pred => _pred; | |||||
| /// <summary> | |||||
| /// 0 or 1 representing this branch | |||||
| /// </summary> | |||||
| private int _branch; | |||||
| private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>(); | private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>(); | ||||
| /// <summary> | /// <summary> | ||||
| @@ -45,10 +45,19 @@ namespace Tensorflow.Operations | |||||
| /// The predicate tensor in this branch | /// The predicate tensor in this branch | ||||
| /// </summary> | /// </summary> | ||||
| protected Tensor _pivot; | protected Tensor _pivot; | ||||
| public Tensor pivot | |||||
| { | |||||
| get => _pivot; | |||||
| } | |||||
| public Tensor pivot => _pivot; | |||||
| /// <summary> | |||||
| /// The boolean tensor for the cond predicate | |||||
| /// </summary> | |||||
| protected Tensor _pred; | |||||
| public Tensor pred => _pred; | |||||
| /// <summary> | |||||
| /// 0 or 1 representing this branch | |||||
| /// </summary> | |||||
| protected int _branch; | |||||
| public int branch => _branch; | |||||
| protected Stack<ControlFlowContext> _context_stack; | protected Stack<ControlFlowContext> _context_stack; | ||||
| protected ControlFlowContext _outer_context; | protected ControlFlowContext _outer_context; | ||||
| @@ -0,0 +1,27 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Operations | |||||
| { | |||||
| public class FusedBatchNormParams | |||||
| { | |||||
| public string Name { get; set; } | |||||
| public Tensor YBackprop { get; set; } | |||||
| public Tensor X { get; set; } | |||||
| public Tensor Scale { get; set; } | |||||
| public Tensor ReserveSpace1 { get; set; } | |||||
| public Tensor ReserveSpace2 { get; set; } | |||||
| public Tensor ReserveSpace3 { get; set; } | |||||
| public float Epsilon { get; set; } | |||||
| public string DataFormat { get; set; } | |||||
| public bool IsTraining { get; set; } | |||||
| public FusedBatchNormParams() | |||||
| { | |||||
| Epsilon = 0.0001f; | |||||
| DataFormat = "NHWC"; | |||||
| IsTraining = true; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -156,6 +156,35 @@ namespace Tensorflow.Operations | |||||
| return op.output; | return op.output; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Gradient for batch normalization. | |||||
| /// </summary> | |||||
| /// <param name="y_backprop"></param> | |||||
| /// <param name="x"></param> | |||||
| /// <param name="scale"></param> | |||||
| /// <param name="reserve_space_1"></param> | |||||
| /// <param name="reserve_space_2"></param> | |||||
| /// <param name="epsilon"></param> | |||||
| /// <param name="data_format"></param> | |||||
| /// <param name="is_training"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor[] fused_batch_norm_grad(FusedBatchNormParams @params) | |||||
| { | |||||
| var op = _op_def_lib._apply_op_helper("FusedBatchNormGrad", name: @params.Name, args: new | |||||
| { | |||||
| y_backprop = @params.YBackprop, | |||||
| x = @params.X, | |||||
| scale = @params.Scale, | |||||
| reserve_space_1 = @params.ReserveSpace1, | |||||
| reserve_space_2 = @params.ReserveSpace2, | |||||
| epsilon = @params.Epsilon, | |||||
| data_format = @params.DataFormat, | |||||
| is_training = @params.IsTraining | |||||
| }); | |||||
| return op.outputs; | |||||
| } | |||||
| public static Tensor[] fused_batch_norm(Tensor x, | public static Tensor[] fused_batch_norm(Tensor x, | ||||
| Tensor scale, | Tensor scale, | ||||
| Tensor offset, | Tensor offset, | ||||
| @@ -53,7 +53,7 @@ namespace Tensorflow | |||||
| for (int i = 0; i < NumInputs; i++) | for (int i = 0; i < NumInputs; i++) | ||||
| { | { | ||||
| var tf_output = Input(i); | var tf_output = Input(i); | ||||
| var op = new Operation(tf_output.oper); | |||||
| var op = GetOperation(tf_output.oper); | |||||
| retval[i] = op.outputs[tf_output.index]; | retval[i] = op.outputs[tf_output.index]; | ||||
| } | } | ||||
| @@ -0,0 +1,41 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class Operation | |||||
| { | |||||
| // cache the mapping between managed and unmanaged op | |||||
| // some data is stored in managed instance, so when | |||||
| // create Operation by IntPtr, it will lost some data. | |||||
| private static Dictionary<IntPtr, Operation> OpInstances = new Dictionary<IntPtr, Operation>(); | |||||
| /// <summary> | |||||
| /// Get operation by handle | |||||
| /// </summary> | |||||
| /// <param name="handle"></param> | |||||
| /// <returns></returns> | |||||
| public Operation GetOperation(IntPtr handle) | |||||
| { | |||||
| return OpInstances.ContainsKey(handle) ? | |||||
| OpInstances[handle] : | |||||
| new Operation(handle); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -84,9 +84,10 @@ namespace Tensorflow | |||||
| _control_flow_context = _graph._get_control_flow_context(); | _control_flow_context = _graph._get_control_flow_context(); | ||||
| // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. | // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. | ||||
| OpInstances[_handle] = this; | |||||
| } | } | ||||
| public Operation(Graph g, string opType, string oper_name) | |||||
| /*public Operation(Graph g, string opType, string oper_name) | |||||
| { | { | ||||
| _graph = g; | _graph = g; | ||||
| @@ -102,7 +103,7 @@ namespace Tensorflow | |||||
| // Dict mapping op name to file and line information for op colocation | // Dict mapping op name to file and line information for op colocation | ||||
| // context managers. | // context managers. | ||||
| _control_flow_context = graph._get_control_flow_context(); | _control_flow_context = graph._get_control_flow_context(); | ||||
| } | |||||
| }*/ | |||||
| /// <summary> | /// <summary> | ||||
| /// Creates an `Operation`. | /// Creates an `Operation`. | ||||
| @@ -151,11 +152,6 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| if(node_def.Name == "define_loss/conv_lobj_branch/batch_normalization/cond/FusedBatchNorm_1") | |||||
| { | |||||
| } | |||||
| // Dict mapping op name to file and line information for op colocation | // Dict mapping op name to file and line information for op colocation | ||||
| // context managers. | // context managers. | ||||
| _control_flow_context = graph._get_control_flow_context(); | _control_flow_context = graph._get_control_flow_context(); | ||||
| @@ -164,7 +160,7 @@ namespace Tensorflow | |||||
| if (op_def == null) | if (op_def == null) | ||||
| op_def = g.GetOpDef(node_def.Op); | op_def = g.GetOpDef(node_def.Op); | ||||
| var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | |||||
| var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | |||||
| _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | ||||
| // Initialize self._outputs. | // Initialize self._outputs. | ||||
| @@ -180,6 +176,8 @@ namespace Tensorflow | |||||
| if (_handle != IntPtr.Zero) | if (_handle != IntPtr.Zero) | ||||
| _control_flow_post_processing(); | _control_flow_post_processing(); | ||||
| OpInstances[_handle] = this; | |||||
| } | } | ||||
| public void run(FeedItem[] feed_dict = null, Session session = null) | public void run(FeedItem[] feed_dict = null, Session session = null) | ||||
| @@ -220,6 +218,9 @@ namespace Tensorflow | |||||
| return grouped_inputs.ToArray(); | return grouped_inputs.ToArray(); | ||||
| } | } | ||||
| public T get_attr<T>(string name) | |||||
| => (T)get_attr(name); | |||||
| public object get_attr(string name) | public object get_attr(string name) | ||||
| { | { | ||||
| AttrValue x = null; | AttrValue x = null; | ||||
| @@ -611,7 +611,7 @@ namespace Tensorflow | |||||
| }); | }); | ||||
| } | } | ||||
| public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null) | |||||
| public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null) | |||||
| => gen_array_ops.slice(input, begin, size, name: name); | => gen_array_ops.slice(input, begin, size, name: name); | ||||
| public static Tensor stack(object values, int axis = 0, string name = "stack") | public static Tensor stack(object values, int axis = 0, string name = "stack") | ||||
| @@ -518,7 +518,7 @@ namespace Tensorflow | |||||
| inputs = inputs.Select(inp => | inputs = inputs.Select(inp => | ||||
| ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true)) | ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref: true)) | ||||
| .ToArray(); | .ToArray(); | ||||
| return gen_control_flow_ops.merge(inputs, name).Item1; | |||||
| return gen_control_flow_ops.merge(inputs, name)[0]; | |||||
| }); | }); | ||||
| } | } | ||||
| @@ -557,8 +557,31 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("ZerosLikeOutsideLoop"); | throw new NotImplementedException("ZerosLikeOutsideLoop"); | ||||
| return array_ops.zeros_like(val, optimize: false); | return array_ops.zeros_like(val, optimize: false); | ||||
| } | } | ||||
| throw new NotImplementedException("ZerosLikeOutsideLoop"); | |||||
| else | |||||
| { | |||||
| var op_ctxt = op._get_control_flow_context(); | |||||
| if(op_ctxt != null) | |||||
| { | |||||
| // We are in a cond context. Use a switch to create zeros only when needed. | |||||
| var pred = op_ctxt.pred; | |||||
| var branch = op_ctxt.branch; | |||||
| var switch_val = @switch(op.inputs[0], pred)[1 - branch]; | |||||
| var pivot = array_ops.identity(switch_val); | |||||
| if (val.dtype == dtypes.resource) | |||||
| throw new NotImplementedException(""); | |||||
| var zeros_shape = array_ops.shape_internal(switch_val, optimize: false); | |||||
| // Ensure ops created within array_ops.zeros are dominated by switch in | |||||
| // cond context. | |||||
| return tf_with(ops.control_dependencies(new[] { pivot }), delegate | |||||
| { | |||||
| return array_ops.zeros(zeros_shape, dtype: val.dtype); | |||||
| }); | |||||
| } | |||||
| else | |||||
| { | |||||
| return array_ops.zeros_like(val, optimize: false); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -475,7 +475,7 @@ namespace Tensorflow | |||||
| return op.output; | return op.output; | ||||
| } | } | ||||
| public static Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null) | |||||
| public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null) | |||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }); | var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size }); | ||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| @@ -148,11 +148,18 @@ namespace Tensorflow | |||||
| return new []{_op.outputs[0], _op.outputs[1]}; | return new []{_op.outputs[0], _op.outputs[1]}; | ||||
| } | } | ||||
| public static (Tensor, Tensor) merge(Tensor[] inputs, string name = null) | |||||
| public static Tensor[] ref_merge(Tensor[] inputs, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("RefMerge", name, new { inputs }); | |||||
| return _op.outputs; | |||||
| } | |||||
| public static Tensor[] merge(Tensor[] inputs, string name = null) | |||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs }); | var _op = _op_def_lib._apply_op_helper("Merge", name, new { inputs }); | ||||
| return (_op.outputs[0], _op.outputs[1]); | |||||
| return _op.outputs; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -183,5 +183,19 @@ namespace Tensorflow | |||||
| return op.output; | return op.output; | ||||
| } | } | ||||
| public static Tensor resize_nearest_neighbor_grad<Tsize>(Tensor grads, Tsize size, bool align_corners = false, | |||||
| bool half_pixel_centers = false, string name = null) | |||||
| { | |||||
| var op = _op_def_lib._apply_op_helper("ResizeNearestNeighborGrad", name: name, args: new | |||||
| { | |||||
| grads, | |||||
| size, | |||||
| align_corners, | |||||
| half_pixel_centers | |||||
| }); | |||||
| return op.output; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -105,10 +105,13 @@ namespace Tensorflow | |||||
| if (_handle == IntPtr.Zero) | if (_handle == IntPtr.Zero) | ||||
| { | { | ||||
| var status = new Status(); | |||||
| c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); | |||||
| status.Check(); | |||||
| } else | |||||
| using (var status = new Status()) | |||||
| { | |||||
| c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); | |||||
| status.Check(); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | { | ||||
| for (int i = 0; i < rank; i++) | for (int i = 0; i < rank; i++) | ||||
| dims[i] = c_api.TF_Dim(_handle, i); | dims[i] = c_api.TF_Dim(_handle, i); | ||||
| @@ -119,14 +122,15 @@ namespace Tensorflow | |||||
| set | set | ||||
| { | { | ||||
| var status = new Status(); | |||||
| if (value == null) | |||||
| c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); | |||||
| else | |||||
| c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); | |||||
| using (var status = new Status()) | |||||
| { | |||||
| if (value == null) | |||||
| c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); | |||||
| else | |||||
| c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); | |||||
| status.Check(true); | |||||
| status.Check(true); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -142,16 +146,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public void set_shape(TensorShape shape) | public void set_shape(TensorShape shape) | ||||
| { | { | ||||
| this.shape = (int[]) shape.dims.Clone(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Updates the shape of this tensor. | |||||
| /// </summary> | |||||
| [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] | |||||
| public void SetShape(TensorShape shape) | |||||
| { | |||||
| this.shape = (int[]) shape.dims.Clone(); | |||||
| this.shape = shape.rank > 0 ? shape.dims : null; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -33,6 +33,7 @@ namespace Tensorflow | |||||
| public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? | public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? | ||||
| public static TF_DataType float16 = TF_DataType.TF_HALF; | public static TF_DataType float16 = TF_DataType.TF_HALF; | ||||
| public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | ||||
| public static TF_DataType resource = TF_DataType.TF_RESOURCE; | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| @@ -227,29 +227,30 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("_create_c_op"); | throw new NotImplementedException("_create_c_op"); | ||||
| } | } | ||||
| var status = new Status(); | |||||
| // Add control inputs | |||||
| foreach (var control_input in control_inputs) | |||||
| c_api.TF_AddControlInput(op_desc, control_input); | |||||
| // Add attrs | |||||
| foreach (var attr in node_def.Attr) | |||||
| using (var status = new Status()) | |||||
| { | { | ||||
| var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||||
| var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak | |||||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | |||||
| uint len = (uint) bytes.Length; | |||||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | |||||
| // Add control inputs | |||||
| foreach (var control_input in control_inputs) | |||||
| c_api.TF_AddControlInput(op_desc, control_input); | |||||
| status.Check(true); | |||||
| } | |||||
| // Add attrs | |||||
| foreach (var attr in node_def.Attr) | |||||
| { | |||||
| var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||||
| var protoHandle = Marshal.AllocHGlobal(bytes.Length); | |||||
| Marshal.Copy(bytes, 0, protoHandle, bytes.Length); | |||||
| uint len = (uint)bytes.Length; | |||||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status); | |||||
| status.Check(true); | |||||
| Marshal.FreeHGlobal(protoHandle); | |||||
| } | |||||
| var c_op = c_api.TF_FinishOperation(op_desc, status); | |||||
| var c_op = c_api.TF_FinishOperation(op_desc, status); | |||||
| status.Check(true); | |||||
| status.Check(true); | |||||
| return c_op; | |||||
| return c_op; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -54,6 +54,8 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||||
| List<RefVariable> first_stage_trainable_var_list; | List<RefVariable> first_stage_trainable_var_list; | ||||
| Operation train_op_with_frozen_variables; | Operation train_op_with_frozen_variables; | ||||
| Operation train_op_with_all_variables; | Operation train_op_with_all_variables; | ||||
| Saver loader; | |||||
| Saver saver; | |||||
| #endregion | #endregion | ||||
| public bool Run() | public bool Run() | ||||
| @@ -74,7 +76,9 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||||
| public void Train(Session sess) | public void Train(Session sess) | ||||
| { | { | ||||
| sess.run(tf.global_variables_initializer()); | |||||
| print($"=> Restoring weights from: {cfg.TRAIN.INITIAL_WEIGHT} ... "); | |||||
| loader.restore(sess, cfg.TRAIN.INITIAL_WEIGHT); | |||||
| } | } | ||||
| public void Test(Session sess) | public void Test(Session sess) | ||||
| @@ -184,6 +188,21 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||||
| }); | }); | ||||
| }); | }); | ||||
| tf_with(tf.name_scope("loader_and_saver"), delegate | |||||
| { | |||||
| loader = tf.train.Saver(net_var); | |||||
| saver = tf.train.Saver(tf.global_variables(), max_to_keep: 10); | |||||
| }); | |||||
| tf_with(tf.name_scope("summary"), delegate | |||||
| { | |||||
| tf.summary.scalar("learn_rate", learn_rate); | |||||
| tf.summary.scalar("giou_loss", giou_loss); | |||||
| tf.summary.scalar("conf_loss", conf_loss); | |||||
| tf.summary.scalar("prob_loss", prob_loss); | |||||
| tf.summary.scalar("total_loss", loss); | |||||
| }); | |||||
| return graph; | return graph; | ||||
| } | } | ||||
| @@ -60,7 +60,7 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO | |||||
| public TrainConfig(string root) | public TrainConfig(string root) | ||||
| { | { | ||||
| _root = root; | _root = root; | ||||
| INITIAL_WEIGHT = Path.Combine(_root, "data", "checkpoint", "yolov3_coco_demo.ckpt"); | |||||
| INITIAL_WEIGHT = Path.Combine(_root, "checkpoint", "yolov3_coco_demo.ckpt"); | |||||
| ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_train.txt"); | ANNOT_PATH = Path.Combine(_root, "data", "dataset", "voc_train.txt"); | ||||
| } | } | ||||
| } | } | ||||