| @@ -61,7 +61,7 @@ namespace Tensorflow | |||||
| public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value); | public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value); | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, uint proto_len, SafeStatusHandle status); | |||||
| public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Set `num_dims` to -1 to represent "unknown rank". | /// Set `num_dims` to -1 to represent "unknown rank". | ||||
| @@ -33,5 +33,8 @@ namespace Tensorflow.Contexts | |||||
| public Action EnterContextFn { get; set; } | public Action EnterContextFn { get; set; } | ||||
| public string DeviceStack { get; set; } | public string DeviceStack { get; set; } | ||||
| public override string ToString() | |||||
| => $"EagerMode: {EagerMode}, IsBuildingFunction: {IsBuildingFunction}"; | |||||
| } | } | ||||
| } | } | ||||
| @@ -168,7 +168,7 @@ namespace Tensorflow | |||||
| if (op_def == null) | if (op_def == null) | ||||
| op_def = g.GetOpDef(node_def.Op); | op_def = g.GetOpDef(node_def.Op); | ||||
| (_handle, OpDesc) = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray()); | |||||
| (_handle, OpDesc) = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray(), op_def); | |||||
| _is_stateful = op_def.IsStateful; | _is_stateful = op_def.IsStateful; | ||||
| // Initialize self._outputs. | // Initialize self._outputs. | ||||
| @@ -190,13 +190,9 @@ namespace Tensorflow | |||||
| // Add attrs | // Add attrs | ||||
| foreach (var attr in node_def.Attr) | foreach (var attr in node_def.Attr) | ||||
| { | { | ||||
| var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||||
| var protoHandle = Marshal.AllocHGlobal(bytes.Length); | |||||
| Marshal.Copy(bytes, 0, protoHandle, bytes.Length); | |||||
| uint len = (uint)bytes.Length; | |||||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status.Handle); | |||||
| var bytes = attr.Value.ToByteArray(); | |||||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status.Handle); | |||||
| status.Check(true); | status.Check(true); | ||||
| Marshal.FreeHGlobal(protoHandle); | |||||
| } | } | ||||
| var c_op = c_api.TF_FinishOperation(op_desc, status.Handle); | var c_op = c_api.TF_FinishOperation(op_desc, status.Handle); | ||||