|
|
|
@@ -25,17 +25,17 @@ namespace Tensorflow |
|
|
|
|
|
|
|
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs) |
|
|
|
{ |
|
|
|
var op_desc = c_api.TF_NewOperation(graph, node_def.Op, node_def.Name); |
|
|
|
var op_desc = graph.NewOperation(node_def.Op, node_def.Name); |
|
|
|
|
|
|
|
// Add inputs |
|
|
|
if(inputs != null && inputs.Count > 0) |
|
|
|
{ |
|
|
|
foreach (var op_input in inputs) |
|
|
|
/*foreach (var op_input in inputs) |
|
|
|
{ |
|
|
|
c_api.TF_AddInput(op_desc, op_input._as_tf_output()); |
|
|
|
} |
|
|
|
}*/ |
|
|
|
|
|
|
|
//c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count); |
|
|
|
c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count); |
|
|
|
} |
|
|
|
|
|
|
|
var status = new Status(); |
|
|
|
@@ -48,9 +48,10 @@ namespace Tensorflow |
|
|
|
var bytes = attr.Value.ToByteArray(); |
|
|
|
var proto = Marshal.AllocHGlobal(bytes.Length); |
|
|
|
Marshal.Copy(bytes, 0, proto, bytes.Length); |
|
|
|
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status); |
|
|
|
|
|
|
|
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status); |
|
|
|
|
|
|
|
if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message); |
|
|
|
status.Check(true); |
|
|
|
} |
|
|
|
|
|
|
|
var c_op = c_api.TF_FinishOperation(op_desc, status); |
|
|
|
@@ -60,6 +61,11 @@ namespace Tensorflow |
|
|
|
return c_op; |
|
|
|
} |
|
|
|
|
|
|
|
public static OpDef _get_op_def(Graph graph, string type) |
|
|
|
{ |
|
|
|
return graph.GetOpDef(type); |
|
|
|
} |
|
|
|
|
|
|
|
public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null) |
|
|
|
{ |
|
|
|
var node_def = new node_def_pb2.NodeDef(); |
|
|
|
|