diff --git a/src/TensorFlowNET.Core/APIs/c_api.customize.cs b/src/TensorFlowNET.Core/APIs/c_api.customize.cs new file mode 100644 index 00000000..173bdbe2 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/c_api.customize.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow +{ + public partial class c_api + { + [DllImport(TensorFlowLibName)] + public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); + } +} diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index 8524f724..fbebd4d6 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -223,7 +223,7 @@ namespace Tensorflow.Functions { input_tangents = new TangentInfo(); } - if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER) + if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER || tf.Runner.MustRecordGradient()) { if(input_tangents.Indices is not null || executing_eagerly) { diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index ca00710c..4261d72b 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -317,27 +317,18 @@ namespace Tensorflow { Debug.Assert(types.Length == shapes.Length); int orig_num_outputs = this.outputs.Length; - //var new_outputs = new List(_outputs); - - var old_outputs = _outputs; - _outputs = new Tensor[orig_num_outputs + types.Length]; - for(int i = 0; i < orig_num_outputs; i++) - { - _outputs[i] = old_outputs[i]; - } + var new_outputs = new List(_outputs); // Since the `_outputs` is defined as `Array`, when we add new output, we // have to create a new array, which brings some performance concerns. // In the future maybe the type of `outputs` should be reconsidered. for(int i = 0; i < types.Length; i++) { - var t = new Tensor(this, orig_num_outputs + 1, types[i]); - _outputs[i] = t; - //t = tf.ensure_shape(t, shapes[i]); + var t = new Tensor(this, orig_num_outputs + i, types[i]); t.shape = shapes[i]; - //new_outputs.Add(t); + new_outputs.Add(t); } - //_outputs = new_outputs.ToArray(); + _outputs = new_outputs.ToArray(); } internal void _set_func_attr(string attr_name, string func_name) @@ -372,23 +363,9 @@ namespace Tensorflow internal void _set_attr_with_buf(string attr_name, Buffer attr_buf) { - //if(_op_desc is null) - //{ - // //var new_node_def = NodeDef.Parser.ParseFrom(node_def.ToByteArray()); - // //new_node_def.Name += "_temp"; - // //var op = new Operation(new_node_def, graph, inputs, _output_types, control_inputs, _input_types); - // //Status status = new(); - // //c_api.TF_SetAttrBool(op._op_desc, "trainable", true); - // ////c_api.TF_SetAttrValueProto(op._op_desc, attr_name, attr_buf.ToArray(), attr_buf.Length, status); - // //status.Check(true); - // // TODO(Rinne): deal with it. Give a warning or make the Operation always contains `op_desc`. - //} - //else - //{ - // //Status status = new(); - // //c_api.TF_SetAttrValueProto(_op_desc, attr_name, attr_buf.ToArray(), attr_buf.Length, status); - // //status.Check(true); - //} + Status status = new(); + c_api.TFC_SetAttr(graph, _handle, attr_name, attr_buf, status); + status.Check(true); } } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 6ca65a07..c0e5d435 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -135,7 +135,7 @@ namespace Tensorflow protected virtual void SetShapeInternal(Shape value) { - if (value == null) + if (value is null || value.ndim == 0 || value.ndim == -1) c_api.TF_GraphSetTensorShape(op.graph.c_graph, _as_tf_output(), null, -1, tf.Status); else c_api.TF_GraphSetTensorShape(op.graph.c_graph, _as_tf_output(), value.dims, value.ndim, tf.Status);