You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Operation.cs 9.6 kB

7 years ago
7 years ago
7 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
7 years ago
6 years ago
7 years ago
7 years ago
7 years ago
6 years ago
6 years ago
7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. using Google.Protobuf.Collections;
  2. //using Newtonsoft.Json;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Runtime.InteropServices;
  7. using System.Text;
  namespace Tensorflow
  {
      /// <summary>
      /// Represents a graph node that performs computation on tensors.
      ///
      /// An `Operation` is a node in a TensorFlow `Graph` that takes zero or
      /// more `Tensor` objects as input, and produces zero or more `Tensor`
      /// objects as output. Objects of type `Operation` are created by
      /// calling an op constructor (such as `tf.matmul`) or `tf.Graph.create_op`.
      ///
      /// For example `c = tf.matmul(a, b)` creates an `Operation` of type
      /// "MatMul" that takes tensors `a` and `b` as input, and produces `c`
      /// as output.
      ///
      /// After the graph has been launched in a session, an `Operation` can
      /// be executed by passing it to `tf.Session.run`.
      /// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
      ///
      /// Equality and hashing are based on the wrapped native handle, so two
      /// wrapper objects around the same native operation compare equal.
      /// </summary>
      public partial class Operation : ITensorOrOperation
      {
          // Native operation handle passed to the c_api.TF_Operation* calls (`_c_op` in Python).
          private readonly IntPtr _handle;
          // Native operation description used while building the op; stays IntPtr.Zero
          // when the instance merely wraps an existing handle.
          private readonly IntPtr _operDesc;
          private Graph _graph;

          /// <summary>The graph this operation belongs to.</summary>
          public Graph graph => _graph;

          /// <summary>Identifier of this operation (backed by <see cref="_id_value"/>).</summary>
          public int _id => _id_value;

          /// <summary>Backing value for <see cref="_id"/>; not assigned anywhere in this file.</summary>
          public int _id_value;

          /// <summary>Alias of <see cref="OpType"/>, named after the Python attribute.</summary>
          public string type => OpType;

          /// <summary>Returns this instance, mirroring <c>Tensor.op</c>.</summary>
          public Operation op => this;

          /// <summary>An operation has no single value type, so this is always <c>DtInvalid</c>.</summary>
          public TF_DataType dtype => TF_DataType.DtInvalid;

          // Shared status object reused by get_attr and the (Graph, string, string) constructor.
          private Status status = new Status();

          /// <summary>Operation name, read from the native handle.</summary>
          public string name => c_api.StringPiece(c_api.TF_OperationName(_handle));

          /// <summary>Op type (for example "MatMul"), read from the native handle.</summary>
          public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle));

          /// <summary>Device string of the operation, read from the native handle.</summary>
          public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));

          private NodeDef _node_def;

          /// <summary>
          /// The <c>NodeDef</c> describing this operation. It is serialized from the
          /// native handle on first access and cached afterwards.
          /// </summary>
          public NodeDef node_def
          {
              get
              {
                  if (_node_def == null)
                      _node_def = GetNodeDef();
                  return _node_def;
              }
          }

          /// <summary>
          /// Wraps an existing native operation handle and materializes one
          /// <c>Tensor</c> per output.
          /// </summary>
          /// <param name="handle">
          /// Native operation handle. When it is <see cref="IntPtr.Zero"/> the instance
          /// is left uninitialized (no graph, no outputs).
          /// </param>
          public Operation(IntPtr handle)
          {
              if (handle == IntPtr.Zero)
                  return;

              _handle = handle;
              // NOTE(review): the owning graph is not queried from the handle; the current
              // default graph is assumed to contain it — confirm for multi-graph use.
              _graph = ops.get_default_graph();

              // Read the output count once instead of on every loop iteration.
              int num_outputs = NumOutputs;
              _outputs = new Tensor[num_outputs];
              for (int i = 0; i < num_outputs; i++)
                  _outputs[i] = new Tensor(this, i, OutputType(i));
          }

          /// <summary>
          /// Creates a new native operation of type <paramref name="opType"/> named
          /// <paramref name="oper_name"/> in <paramref name="g"/>, with its "dtype"
          /// attribute fixed to <c>TF_INT32</c>.
          /// </summary>
          /// <remarks>
          /// NOTE(review): the status written by TF_FinishOperation is not checked and no
          /// output tensors are created here; presumably a test helper — confirm.
          /// </remarks>
          public Operation(Graph g, string opType, string oper_name)
          {
              _graph = g;
              _operDesc = c_api.TF_NewOperation(g, opType, oper_name);
              c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32);
              _handle = c_api.TF_FinishOperation(_operDesc, status);
          }

          /// <summary>
          /// Creates an `Operation`.
          /// </summary>
          /// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param>
          /// <param name="g">`Graph`. The parent graph.</param>
          /// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param>
          /// <param name="output_types">
          /// list of `DType` objects. Not used: the output types are read back from
          /// the created native operation.
          /// </param>
          /// <param name="control_inputs">
          /// list of operations or tensors from which to have a
          /// control dependency.
          /// </param>
          /// <param name="input_types">
          /// List of `DType` objects representing the
          /// types of the tensors accepted by the `Operation`. Currently unused by
          /// this constructor.
          /// </param>
          /// <param name="original_op">Currently unused by this constructor.</param>
          /// <param name="op_def">
          /// Op definition used to group <paramref name="inputs"/>; looked up from
          /// <paramref name="g"/> by <c>node_def.Op</c> when null.
          /// </param>
          /// <exception cref="NotImplementedException">
          /// A control input is neither an <see cref="Operation"/> nor a <c>Tensor</c>.
          /// </exception>
          public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
          {
              _graph = g;

              // Build the list of control inputs; tensors contribute their producing op.
              var control_input_ops = new List<Operation>();
              if (control_inputs != null)
              {
                  foreach (var c in control_inputs)
                  {
                      switch (c)
                      {
                          case Operation c1:
                              control_input_ops.Add(c1);
                              break;
                          case Tensor tensor:
                              control_input_ops.Add(tensor.op);
                              break;
                          // TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented
                          //case IndexedSlices islices:
                          //    control_input_ops.Add(islices.op);
                          //    break;
                          default:
                              throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
                      }
                  }
              }

              // Capture the graph's current control-flow context before creating the op.
              _control_flow_context = graph._get_control_flow_context();

              if (op_def == null)
                  op_def = g.GetOpDef(node_def.Op);

              var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
              (_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());

              // Materialize the outputs using the types reported by the native op.
              // Each output type is queried once (the old code queried it twice and
              // filled an output_types array that was never read).
              int num_outputs = NumOutputs;
              _outputs = new Tensor[num_outputs];
              for (int i = 0; i < num_outputs; i++)
                  _outputs[i] = new Tensor(this, i, OutputType(i));

              graph._add_op(this);

              if (_handle != IntPtr.Zero)
                  _control_flow_post_processing();
          }

          /// <summary>
          /// Runs this operation in a session (shortcut for running the op through
          /// <c>ops._run_using_default_session</c>).
          /// </summary>
          /// <param name="feed_dict">Optional feeds forwarded to the session run.</param>
          /// <param name="session">
          /// Session to use; presumably the default session is used when null — confirm
          /// in <c>ops._run_using_default_session</c>.
          /// </param>
          public void run(FeedItem[] feed_dict = null, Session session = null)
          {
              ops._run_using_default_session(this, feed_dict, graph, session);
          }

          /// <summary>
          /// Regroups the flat <paramref name="inputs"/> array according to
          /// <c>op_def.InputArg</c>: an argument with a <c>NumberAttr</c> or
          /// <c>TypeListAttr</c> receives a <c>Tensor[]</c> slice, any other argument
          /// receives a single <c>Tensor</c>.
          /// </summary>
          /// <param name="op_def">Definition whose input arguments drive the grouping.</param>
          /// <param name="inputs">Flat list of input tensors, in argument order.</param>
          /// <param name="attrs">Node attributes holding the list lengths.</param>
          /// <returns>One entry per input argument.</returns>
          private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs)
          {
              var grouped_inputs = new List<object>();
              int i = 0;
              foreach (var input_arg in op_def.InputArg)
              {
                  int input_len;
                  bool is_sequence;
                  if (!string.IsNullOrEmpty(input_arg.NumberAttr))
                  {
                      // List length comes from the integer attr named by NumberAttr.
                      input_len = (int)attrs[input_arg.NumberAttr].I;
                      is_sequence = true;
                  }
                  else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
                  {
                      // List length is the number of types in the attr named by TypeListAttr.
                      input_len = attrs[input_arg.TypeListAttr].List.Type.Count;
                      is_sequence = true;
                  }
                  else
                  {
                      input_len = 1;
                      is_sequence = false;
                  }

                  if (is_sequence)
                      grouped_inputs.Add(inputs.Skip(i).Take(input_len).ToArray());
                  else
                      grouped_inputs.Add(inputs[i]);

                  i += input_len;
              }
              return grouped_inputs.ToArray();
          }

          /// <summary>
          /// Reads attribute <paramref name="name"/> from the native operation.
          /// </summary>
          /// <param name="name">Attribute name.</param>
          /// <returns>
          /// <c>null</c> when the attribute value is unset; the UTF-8 decoded string for
          /// an <c>s</c> value; otherwise the populated field of the <c>AttrValue</c>
          /// oneof (for example <c>long</c> for <c>i</c>, the data type for <c>type</c>,
          /// the list message for <c>list</c>).
          /// </returns>
          public object get_attr(string name)
          {
              AttrValue x = null;
              using (var buf = new Buffer())
              {
                  c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
                  status.Check(true);
                  x = AttrValue.Parser.ParseFrom(buf);
              }

              // Switch on the oneof enum itself. The previous string comparisons
              // ("", "list", "type") never matched the PascalCase enum names, so an unset
              // value crashed with a NullReferenceException in the reflection lookup.
              switch (x.ValueCase)
              {
                  case AttrValue.ValueOneofCase.None:
                      return null;
                  case AttrValue.ValueOneofCase.S:
                      return x.S.ToStringUtf8();
                  case AttrValue.ValueOneofCase.I:
                      return x.I;
                  case AttrValue.ValueOneofCase.F:
                      return x.F;
                  case AttrValue.ValueOneofCase.B:
                      return x.B;
                  case AttrValue.ValueOneofCase.Type:
                      return x.Type;
                  case AttrValue.ValueOneofCase.Shape:
                      return x.Shape;
                  case AttrValue.ValueOneofCase.Tensor:
                      return x.Tensor;
                  case AttrValue.ValueOneofCase.List:
                      // List values were (and still are) returned as the raw list message;
                      // the NotImplementedException intended for them was unreachable.
                      return x.List;
                  case AttrValue.ValueOneofCase.Func:
                      return x.Func;
                  case AttrValue.ValueOneofCase.Placeholder:
                      return x.Placeholder;
                  default:
                  {
                      // Fallback for oneof cases added by newer generated protobuf code.
                      object result = x.GetType().GetProperty(x.ValueCase.ToString())?.GetValue(x);
                      if (result is Google.Protobuf.ByteString byteString)
                          return byteString.ToStringUtf8();
                      return result;
                  }
              }
          }

          /// <summary>
          /// Returns native metadata for attribute <paramref name="attr_name"/>; errors are
          /// reported through <paramref name="s"/>.
          /// </summary>
          public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)
          {
              return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
          }

          /// <summary>
          /// Serializes the native operation into a <c>NodeDef</c>; throws via
          /// <c>Status.Check</c> if the native call reports an error.
          /// </summary>
          private NodeDef GetNodeDef()
          {
              using (var s = new Status())
              using (var buffer = new Buffer())
              {
                  c_api.TF_OperationToNodeDef(_handle, buffer, s);
                  s.Check();
                  return NodeDef.Parser.ParseFrom(buffer);
              }
          }

          /// <summary>Human-readable description with the op name and type.</summary>
          public override string ToString()
          {
              return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{name}' type={OpType}";
          }

          /// <summary>Wraps a native handle in a new <see cref="Operation"/> instance.</summary>
          public static implicit operator Operation(IntPtr handle) => new Operation(handle);

          /// <summary>Exposes the wrapped native handle.</summary>
          public static implicit operator IntPtr(Operation op) => op._handle;

          /// <summary>
          /// An operation equals another <see cref="Operation"/> wrapping the same native
          /// handle, or a boxed <see cref="IntPtr"/> equal to that handle.
          /// </summary>
          public override bool Equals(object obj)
          {
              switch (obj)
              {
                  case IntPtr val:
                      return val == _handle;
                  case Operation val:
                      return val._handle == _handle;
              }
              return base.Equals(obj);
          }

          /// <summary>
          /// Hash code consistent with <see cref="Equals(object)"/>: derived from the native
          /// handle, so equal wrappers land in the same dictionary/set bucket.
          /// </summary>
          public override int GetHashCode() => _handle.GetHashCode();
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。