From aea62f641b8a0f74f2191898e4a90be415f5de1e Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 5 Dec 2020 07:41:43 -0600 Subject: [PATCH] Use safe AttrValue to invoke TF_SetAttrValueProto. --- src/TensorFlowNET.Core/Attributes/c_api.ops.cs | 2 +- src/TensorFlowNET.Core/Contexts/ContextSwitch.cs | 3 +++ src/TensorFlowNET.Core/Operations/Operation.cs | 2 +- src/TensorFlowNET.Core/ops.cs | 8 ++------ 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/TensorFlowNET.Core/Attributes/c_api.ops.cs b/src/TensorFlowNET.Core/Attributes/c_api.ops.cs index 1476d4d3..1815b477 100644 --- a/src/TensorFlowNET.Core/Attributes/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Attributes/c_api.ops.cs @@ -61,7 +61,7 @@ namespace Tensorflow public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value); [DllImport(TensorFlowLibName)] - public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, uint proto_len, SafeStatusHandle status); + public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); /// /// Set `num_dims` to -1 to represent "unknown rank". diff --git a/src/TensorFlowNET.Core/Contexts/ContextSwitch.cs b/src/TensorFlowNET.Core/Contexts/ContextSwitch.cs index 02f2fd9c..4046e877 100644 --- a/src/TensorFlowNET.Core/Contexts/ContextSwitch.cs +++ b/src/TensorFlowNET.Core/Contexts/ContextSwitch.cs @@ -33,5 +33,8 @@ namespace Tensorflow.Contexts public Action EnterContextFn { get; set; } public string DeviceStack { get; set; } + + public override string ToString() + => $"EagerMode: {EagerMode}, IsBuildingFunction: {IsBuildingFunction}"; } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 69f92c30..01afd489 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -168,7 +168,7 @@ namespace Tensorflow if (op_def == null) op_def = g.GetOpDef(node_def.Op); - (_handle, OpDesc) = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray()); + (_handle, OpDesc) = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray(), op_def); _is_stateful = op_def.IsStateful; // Initialize self._outputs. diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 67926caf..573eda16 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -190,13 +190,9 @@ namespace Tensorflow // Add attrs foreach (var attr in node_def.Attr) { - var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. - var protoHandle = Marshal.AllocHGlobal(bytes.Length); - Marshal.Copy(bytes, 0, protoHandle, bytes.Length); - uint len = (uint)bytes.Length; - c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status.Handle); + var bytes = attr.Value.ToByteArray(); + c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status.Handle); status.Check(true); - Marshal.FreeHGlobal(protoHandle); } var c_op = c_api.TF_FinishOperation(op_desc, status.Handle);