Browse Source

Use safe AttrValue to invoke TF_SetAttrValueProto.

tags/v0.30
Oceania2018 5 years ago
parent
commit
aea62f641b
4 changed files with 7 additions and 8 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Attributes/c_api.ops.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Contexts/ContextSwitch.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.cs
  4. +2
    -6
      src/TensorFlowNET.Core/ops.cs

+ 1
- 1
src/TensorFlowNET.Core/Attributes/c_api.ops.cs View File

@@ -61,7 +61,7 @@ namespace Tensorflow
public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value);

[DllImport(TensorFlowLibName)]
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, uint proto_len, SafeStatusHandle status);
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status);

/// <summary>
/// Set `num_dims` to -1 to represent "unknown rank".


+ 3
- 0
src/TensorFlowNET.Core/Contexts/ContextSwitch.cs View File

@@ -33,5 +33,8 @@ namespace Tensorflow.Contexts
public Action EnterContextFn { get; set; }

public string DeviceStack { get; set; }

public override string ToString()
=> $"EagerMode: {EagerMode}, IsBuildingFunction: {IsBuildingFunction}";
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -168,7 +168,7 @@ namespace Tensorflow
if (op_def == null)
op_def = g.GetOpDef(node_def.Op);

(_handle, OpDesc) = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray());
(_handle, OpDesc) = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray(), op_def);
_is_stateful = op_def.IsStateful;

// Initialize self._outputs.


+ 2
- 6
src/TensorFlowNET.Core/ops.cs View File

@@ -190,13 +190,9 @@ namespace Tensorflow
// Add attrs
foreach (var attr in node_def.Attr)
{
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream.
var protoHandle = Marshal.AllocHGlobal(bytes.Length);
Marshal.Copy(bytes, 0, protoHandle, bytes.Length);
uint len = (uint)bytes.Length;
c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status.Handle);
var bytes = attr.Value.ToByteArray();
c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status.Handle);
status.Check(true);
Marshal.FreeHGlobal(protoHandle);
}

var c_op = c_api.TF_FinishOperation(op_desc, status.Handle);


Loading…
Cancel
Save