You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

PythonTest.cs 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Newtonsoft.Json.Linq;
  3. using Tensorflow.NumPy;
  4. using System;
  5. using System.Collections;
  6. using System.Linq;
  7. using Tensorflow;
  8. using static Tensorflow.Binding;
  9. namespace TensorFlowNET.UnitTest
  10. {
  11. /// <summary>
  12. /// Use as base class for test classes to get additional assertions
  13. /// </summary>
  14. public class PythonTest
  15. {
  16. #region python compatibility layer
  17. protected PythonTest self { get => this; }
  18. protected int None => -1;
  19. #endregion
  20. #region pytest assertions
  21. public void assertItemsEqual(ICollection given, ICollection expected)
  22. {
  23. if (given is Hashtable && expected is Hashtable)
  24. {
  25. Assert.AreEqual(JObject.FromObject(expected).ToString(), JObject.FromObject(given).ToString());
  26. return;
  27. }
  28. Assert.IsNotNull(expected);
  29. Assert.IsNotNull(given);
  30. var e = expected.OfType<object>().ToArray();
  31. var g = given.OfType<object>().ToArray();
  32. Assert.AreEqual(e.Length, g.Length, $"The collections differ in length expected {e.Length} but got {g.Length}");
  33. for (int i = 0; i < e.Length; i++)
  34. {
  35. /*if (g[i] is NDArray && e[i] is NDArray)
  36. assertItemsEqual((g[i] as NDArray).GetData<object>(), (e[i] as NDArray).GetData<object>());
  37. else*/
  38. if (e[i] is ICollection && g[i] is ICollection)
  39. assertEqual(g[i], e[i]);
  40. else
  41. Assert.AreEqual(e[i], g[i], $"Items differ at index {i}, expected {e[i]} but got {g[i]}");
  42. }
  43. }
  44. public void assertAllEqual(ICollection given, ICollection expected)
  45. {
  46. assertItemsEqual(given, expected);
  47. }
  48. public void assertFloat32Equal(float expected, float actual, string msg)
  49. {
  50. float eps = 1e-6f;
  51. Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}");
  52. }
  53. public void assertFloat64Equal(double expected, double actual, string msg)
  54. {
  55. double eps = 1e-16f;
  56. Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}");
  57. }
  58. public void AssetSequenceEqual(float[] expected, float[] actual)
  59. {
  60. float eps = 1e-5f;
  61. for (int i = 0; i < expected.Length; i++)
  62. Assert.IsTrue(Math.Abs(expected[i] - actual[i]) < eps * Math.Max(1.0f, Math.Abs(expected[i])), $"expected {expected} vs actual {actual}");
  63. }
  64. public void AssetSequenceEqual(double[] expected, double[] actual)
  65. {
  66. double eps = 1e-5f;
  67. for (int i = 0; i < expected.Length; i++)
  68. Assert.IsTrue(Math.Abs(expected[i] - actual[i]) < eps * Math.Max(1.0f, Math.Abs(expected[i])), $"expected {expected} vs actual {actual}");
  69. }
  70. public void assertEqual(object given, object expected)
  71. {
  72. /*if (given is NDArray && expected is NDArray)
  73. {
  74. assertItemsEqual((given as NDArray).GetData<object>(), (expected as NDArray).GetData<object>());
  75. return;
  76. }*/
  77. if (given is Hashtable && expected is Hashtable)
  78. {
  79. Assert.AreEqual(JObject.FromObject(expected).ToString(), JObject.FromObject(given).ToString());
  80. return;
  81. }
  82. if (given is ICollection && expected is ICollection)
  83. {
  84. assertItemsEqual(given as ICollection, expected as ICollection);
  85. return;
  86. }
  87. if (given is float && expected is float)
  88. {
  89. assertFloat32Equal((float)expected, (float)given, "");
  90. return;
  91. }
  92. if (given is double && expected is double)
  93. {
  94. assertFloat64Equal((double)expected, (double)given, "");
  95. return;
  96. }
  97. Assert.AreEqual(expected, given);
  98. }
  99. public void assertEquals(object given, object expected)
  100. {
  101. assertEqual(given, expected);
  102. }
  103. public void assert(object given)
  104. {
  105. if (given is bool)
  106. Assert.IsTrue((bool)given);
  107. Assert.IsNotNull(given);
  108. }
  109. public void assertIsNotNone(object given)
  110. {
  111. Assert.IsNotNull(given);
  112. }
  113. public void assertFalse(bool cond)
  114. {
  115. Assert.IsFalse(cond);
  116. }
  117. public void assertTrue(bool cond)
  118. {
  119. Assert.IsTrue(cond);
  120. }
  121. public void assertAllClose(NDArray array1, NDArray array2, double eps = 1e-5)
  122. {
  123. Assert.IsTrue(np.allclose(array1, array2, rtol: eps));
  124. }
  125. public void assertAllClose(double value, NDArray array2, double eps = 1e-5)
  126. {
  127. var array1 = np.ones_like(array2) * value;
  128. Assert.IsTrue(np.allclose(array1, array2, rtol: eps));
  129. }
  130. private class CollectionComparer : System.Collections.IComparer
  131. {
  132. private readonly double _epsilon;
  133. public CollectionComparer(double eps = 1e-06) {
  134. _epsilon = eps;
  135. }
  136. public int Compare(object x, object y)
  137. {
  138. var a = (double)x;
  139. var b = (double)y;
  140. double delta = Math.Abs(a - b);
  141. if (delta < _epsilon)
  142. {
  143. return 0;
  144. }
  145. return a.CompareTo(b);
  146. }
  147. }
  148. public void assertAllCloseAccordingToType<T>(
  149. T[] expected,
  150. T[] given,
  151. double eps = 1e-6,
  152. float float_eps = 1e-6f)
  153. {
  154. // TODO: check if any of arguments is not double and change toletance
  155. CollectionAssert.AreEqual(expected, given, new CollectionComparer(eps));
  156. }
  157. public void assertProtoEquals(object toProto, object o)
  158. {
  159. throw new NotImplementedException();
  160. }
  161. #endregion
  162. #region tensor evaluation and test session
  163. private Session _cached_session = null;
  164. private Graph _cached_graph = null;
  165. private object _cached_config = null;
  166. private bool _cached_force_gpu = false;
  167. private void _ClearCachedSession()
  168. {
  169. if (self._cached_session != null)
  170. {
  171. self._cached_session.Dispose();
  172. self._cached_session = null;
  173. }
  174. }
  175. //protected object _eval_helper(Tensor[] tensors)
  176. //{
  177. // if (tensors == null)
  178. // return null;
  179. // return nest.map_structure(self._eval_tensor, tensors);
  180. //}
  181. protected object _eval_tensor(object tensor)
  182. {
  183. if (tensor == null)
  184. return None;
  185. //else if (callable(tensor))
  186. // return self._eval_helper(tensor())
  187. else
  188. {
  189. try
  190. {
  191. //TODO:
  192. // if sparse_tensor.is_sparse(tensor):
  193. // return sparse_tensor.SparseTensorValue(tensor.indices, tensor.values,
  194. // tensor.dense_shape)
  195. //return (tensor as Tensor).numpy();
  196. }
  197. catch (Exception)
  198. {
  199. throw new ValueError("Unsupported type: " + tensor.GetType());
  200. }
  201. return null;
  202. }
  203. }
  204. /// <summary>
  205. /// This function is used in many original tensorflow unit tests to evaluate tensors
  206. /// in a test session with special settings (for instance constant folding off)
  207. ///
  208. /// </summary>
  209. public T evaluate<T>(Tensor tensor)
  210. {
  211. object result = null;
  212. // if context.executing_eagerly():
  213. // return self._eval_helper(tensors)
  214. // else:
  215. {
  216. var sess = tf.Session();
  217. var ndarray = tensor.eval(sess);
  218. if (typeof(T) == typeof(double))
  219. {
  220. double x = ndarray;
  221. result = x;
  222. }
  223. else if (typeof(T) == typeof(int))
  224. {
  225. int x = ndarray;
  226. result = x;
  227. }
  228. else
  229. {
  230. result = ndarray;
  231. }
  232. return (T)result;
  233. }
  234. }
  235. ///Returns a TensorFlow Session for use in executing tests.
  236. public Session cached_session(
  237. Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
  238. {
  239. // This method behaves differently than self.session(): for performance reasons
  240. // `cached_session` will by default reuse the same session within the same
  241. // test.The session returned by this function will only be closed at the end
  242. // of the test(in the TearDown function).
  243. // Use the `use_gpu` and `force_gpu` options to control where ops are run.If
  244. // `force_gpu` is True, all ops are pinned to `/ device:GPU:0`. Otherwise, if
  245. // `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
  246. // possible.If both `force_gpu and `use_gpu` are False, all ops are pinned to
  247. // the CPU.
  248. // Example:
  249. // python
  250. // class MyOperatorTest(test_util.TensorFlowTestCase) :
  251. // def testMyOperator(self):
  252. // with self.cached_session() as sess:
  253. // valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
  254. // result = MyOperator(valid_input).eval()
  255. // self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
  256. // invalid_input = [-1.0, 2.0, 7.0]
  257. // with self.assertRaisesOpError("negative input not supported"):
  258. // MyOperator(invalid_input).eval()
  259. // Args:
  260. // graph: Optional graph to use during the returned session.
  261. // config: An optional config_pb2.ConfigProto to use to configure the
  262. // session.
  263. // use_gpu: If True, attempt to run as many ops as possible on GPU.
  264. // force_gpu: If True, pin all ops to `/device:GPU:0`.
  265. // Yields:
  266. // A Session object that should be used as a context manager to surround
  267. // the graph building and execution code in a test case.
  268. // TODO:
  269. // if context.executing_eagerly():
  270. // return self._eval_helper(tensors)
  271. // else:
  272. {
  273. var sess = self._get_cached_session(
  274. graph, config, force_gpu, crash_if_inconsistent_args: true);
  275. using var cached = self._constrain_devices_and_set_default(sess, use_gpu, force_gpu);
  276. return cached;
  277. }
  278. }
  279. //Returns a TensorFlow Session for use in executing tests.
  280. public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
  281. {
  282. //Note that this will set this session and the graph as global defaults.
  283. //Use the `use_gpu` and `force_gpu` options to control where ops are run.If
  284. //`force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
  285. //`use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
  286. //possible.If both `force_gpu and `use_gpu` are False, all ops are pinned to
  287. //the CPU.
  288. //Example:
  289. //```python
  290. //class MyOperatorTest(test_util.TensorFlowTestCase):
  291. // def testMyOperator(self):
  292. // with self.session(use_gpu= True):
  293. // valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
  294. // result = MyOperator(valid_input).eval()
  295. // self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
  296. // invalid_input = [-1.0, 2.0, 7.0]
  297. // with self.assertRaisesOpError("negative input not supported"):
  298. // MyOperator(invalid_input).eval()
  299. //```
  300. //Args:
  301. // graph: Optional graph to use during the returned session.
  302. // config: An optional config_pb2.ConfigProto to use to configure the
  303. // session.
  304. // use_gpu: If True, attempt to run as many ops as possible on GPU.
  305. // force_gpu: If True, pin all ops to `/device:GPU:0`.
  306. //Yields:
  307. // A Session object that should be used as a context manager to surround
  308. // the graph building and execution code in a test case.
  309. Session s = null;
  310. //if (context.executing_eagerly())
  311. // yield None
  312. //else
  313. //{
  314. s = self._create_session(graph, config, force_gpu);
  315. //}
  316. return s.as_default();
  317. }
  318. private Session _constrain_devices_and_set_default(Session sess, bool use_gpu, bool force_gpu)
  319. {
  320. // Set the session and its graph to global default and constrain devices."""
  321. if (tf.executing_eagerly())
  322. return null;
  323. else
  324. {
  325. sess.graph.as_default();
  326. sess.as_default();
  327. {
  328. if (force_gpu)
  329. {
  330. // TODO:
  331. // Use the name of an actual device if one is detected, or
  332. // '/device:GPU:0' otherwise
  333. /* var gpu_name = gpu_device_name();
  334. if (!gpu_name)
  335. gpu_name = "/device:GPU:0"
  336. using (sess.graph.device(gpu_name)) {
  337. yield return sess;
  338. }*/
  339. return sess;
  340. }
  341. else if (use_gpu)
  342. return sess;
  343. else
  344. using (sess.graph.device("/device:CPU:0"))
  345. return sess;
  346. }
  347. }
  348. }
  349. // See session() for details.
  350. private Session _create_session(Graph graph, object cfg, bool forceGpu)
  351. {
  352. var prepare_config = new Func<object, object>((config) =>
  353. {
  354. // """Returns a config for sessions.
  355. // Args:
  356. // config: An optional config_pb2.ConfigProto to use to configure the
  357. // session.
  358. // Returns:
  359. // A config_pb2.ConfigProto object.
  360. //TODO: config
  361. // # use_gpu=False. Currently many tests rely on the fact that any device
  362. // # will be used even when a specific device is supposed to be used.
  363. // allow_soft_placement = not force_gpu
  364. // if config is None:
  365. // config = config_pb2.ConfigProto()
  366. // config.allow_soft_placement = allow_soft_placement
  367. // config.gpu_options.per_process_gpu_memory_fraction = 0.3
  368. // elif not allow_soft_placement and config.allow_soft_placement:
  369. // config_copy = config_pb2.ConfigProto()
  370. // config_copy.CopyFrom(config)
  371. // config = config_copy
  372. // config.allow_soft_placement = False
  373. // # Don't perform optimizations for tests so we don't inadvertently run
  374. // # gpu ops on cpu
  375. // config.graph_options.optimizer_options.opt_level = -1
  376. // # Disable Grappler constant folding since some tests & benchmarks
  377. // # use constant input and become meaningless after constant folding.
  378. // # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE
  379. // # GRAPPLER TEAM.
  380. // config.graph_options.rewrite_options.constant_folding = (
  381. // rewriter_config_pb2.RewriterConfig.OFF)
  382. // config.graph_options.rewrite_options.pin_to_host_optimization = (
  383. // rewriter_config_pb2.RewriterConfig.OFF)
  384. return config;
  385. });
  386. //TODO: use this instead of normal session
  387. //return new ErrorLoggingSession(graph = graph, config = prepare_config(config))
  388. return new Session(graph);//, config = prepare_config(config))
  389. }
  390. private Session _get_cached_session(
  391. Graph graph = null,
  392. object config = null,
  393. bool force_gpu = false,
  394. bool crash_if_inconsistent_args = true)
  395. {
  396. // See cached_session() for documentation.
  397. if (self._cached_session == null)
  398. {
  399. var sess = self._create_session(graph, config, force_gpu);
  400. self._cached_session = sess;
  401. self._cached_graph = graph;
  402. self._cached_config = config;
  403. self._cached_force_gpu = force_gpu;
  404. return sess;
  405. }
  406. else
  407. {
  408. if (crash_if_inconsistent_args && !self._cached_graph.Equals(graph))
  409. throw new ValueError(@"The graph used to get the cached session is
  410. different than the one that was used to create the
  411. session. Maybe create a new session with
  412. self.session()");
  413. if (crash_if_inconsistent_args && !self._cached_config.Equals(config))
  414. {
  415. throw new ValueError(@"The config used to get the cached session is
  416. different than the one that was used to create the
  417. session. Maybe create a new session with
  418. self.session()");
  419. }
  420. if (crash_if_inconsistent_args && !self._cached_force_gpu.Equals(force_gpu))
  421. {
  422. throw new ValueError(@"The force_gpu value used to get the cached session is
  423. different than the one that was used to create the
  424. session. Maybe create a new session with
  425. self.session()");
  426. }
  427. return _cached_session;
  428. }
  429. }
  430. [TestCleanup]
  431. public void Cleanup()
  432. {
  433. _ClearCachedSession();
  434. }
  435. #endregion
  436. public void AssetSequenceEqual<T>(T[] a, T[] b)
  437. {
  438. Assert.IsTrue(Enumerable.SequenceEqual(a, b));
  439. }
  440. }
  441. }