You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_gemm_omp.cc 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. #include <cblas.h>
  2. #include <stdio.h>
  3. #include <stdlib.h>
  4. #include <math.h>
  5. #include <vector>
  6. //------------------------------------------------------------------------------
  // Fills the m-by-n matrix A (column-major storage, leading dimension ld)
  // with pseudo-random values in [0, 1] drawn from rand().
  // Entries are written one column at a time, top to bottom; elements of A
  // that lie between row m and row ld of a column are left untouched.
  void fill_rand( int m, int n, double* A, int ld )
  {
      const double scale = double(RAND_MAX);
      for (int col = 0; col < n; ++col) {
          int row = 0;
          while (row < m) {
              const int index = row + col*ld;
              A[ index ] = rand() / scale;
              ++row;
          }
      }
  }
  15. //------------------------------------------------------------------------------
  // Returns the larger of x and y, propagating NaN: if either argument is NaN
  // the result is NaN. When the two compare equal, y is returned.
  inline double max_nan( double x, double y )
  {
      // A NaN in y wins outright.
      if (isnan(y))
          return y;
      // If x is NaN this comparison is false, so x (the NaN) is returned below.
      if (y >= x)
          return y;
      return x;
  }
  20. //------------------------------------------------------------------------------
  21. int main( int argc, char** argv )
  22. {
  23. int batch_size = 1000;
  24. int n = 50;
  25. if (argc > 1)
  26. batch_size = atoi( argv[1] );
  27. if (argc > 2)
  28. n = atoi( argv[2] );
  29. printf( "batch_size %d, n %d\n", batch_size, n );
  30. int ld = n;
  31. double alpha = 3.1416;
  32. double beta = 2.7183;
  33. for (int loops = 0; loops <20; loops ++) {
  34. printf( "init %d\n", loops );
  35. std::vector<double*> A_array( batch_size ),
  36. B_array( batch_size ),
  37. C_array( batch_size ),
  38. D_array( batch_size );
  39. for (int i = 0; i < batch_size; ++i) {
  40. A_array[ i ] = new double[ ld*n ];
  41. B_array[ i ] = new double[ ld*n ];
  42. C_array[ i ] = new double[ ld*n ];
  43. D_array[ i ] = new double[ ld*n ];
  44. fill_rand( n, n, A_array[ i ], ld );
  45. fill_rand( n, n, B_array[ i ], ld );
  46. fill_rand( n, n, C_array[ i ], ld );
  47. std::copy( C_array[ i ], C_array[ i ] + ld*n, D_array[ i ] );
  48. }
  49. printf( "test\n" );
  50. for (int i = 0; i < batch_size; ++i) {
  51. cblas_dgemm( CblasColMajor, CblasNoTrans, CblasNoTrans, n, n, n,
  52. alpha, A_array[ i ], ld, B_array[ i ], ld,
  53. beta, C_array[ i ], ld );
  54. }
  55. printf( "test OpenMP\n" );
  56. #pragma omp parallel for
  57. for (int i = 0; i < batch_size; ++i) {
  58. cblas_dgemm( CblasColMajor, CblasNoTrans, CblasNoTrans, n, n, n,
  59. alpha, A_array[ i ], ld, B_array[ i ], ld,
  60. beta, D_array[ i ], ld );
  61. }
  62. printf( "compare\n" );
  63. double max_error = 0;
  64. for (int i = 0; i < batch_size; ++i) {
  65. // norm( D - C )
  66. cblas_daxpy( ld*n, -1.0, C_array[ i ], 1, D_array[ i ], 1 );
  67. double error = cblas_dnrm2( ld*n, D_array[ i ], 1 );
  68. max_error = max_nan( error, max_error );
  69. }
  70. printf( "max error %.2e\n", max_error );
  71. printf( "delete\n" );
  72. for (int i = 0; i < batch_size; ++i) {
  73. delete [] A_array[ i ];
  74. delete [] B_array[ i ];
  75. delete [] C_array[ i ];
  76. }
  77. printf( "done %d\n", loops );
  78. }
  79. printf( "all done\n");
  80. return 0;
  81. }