You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_gemm_omp.cc 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. #include <cblas.h>
  2. #include <stdio.h>
  3. #include <stdlib.h>
  4. #include <math.h>
  5. #include <vector>
  6. //------------------------------------------------------------------------------
  7. void fill_rand( int m, int n, double* A, int ld )
  8. {
  9. for (int j = 0; j < n; ++j) {
  10. for (int i = 0; i < m; ++i) {
  11. A[ i + j*ld ] = rand() / double(RAND_MAX);
  12. }
  13. }
  14. }
  15. //------------------------------------------------------------------------------
  16. inline double max_nan( double x, double y )
  17. {
  18. return (isnan(y) || (y) >= (x) ? (y) : (x));
  19. }
  20. //------------------------------------------------------------------------------
  21. int main( int argc, char** argv )
  22. {
  23. int batch_size = 1000;
  24. int n = 50;
  25. if (argc > 1)
  26. batch_size = atoi( argv[1] );
  27. if (argc > 2)
  28. n = atoi( argv[2] );
  29. printf( "batch_size %d, n %d\n", batch_size, n );
  30. int ld = n;
  31. double alpha = 3.1416;
  32. double beta = 2.7183;
  33. printf( "init\n" );
  34. std::vector<double*> A_array( batch_size ),
  35. B_array( batch_size ),
  36. C_array( batch_size ),
  37. D_array( batch_size );
  38. for (int i = 0; i < batch_size; ++i) {
  39. A_array[ i ] = new double[ ld*n ];
  40. B_array[ i ] = new double[ ld*n ];
  41. C_array[ i ] = new double[ ld*n ];
  42. D_array[ i ] = new double[ ld*n ];
  43. fill_rand( n, n, A_array[ i ], ld );
  44. fill_rand( n, n, B_array[ i ], ld );
  45. fill_rand( n, n, C_array[ i ], ld );
  46. std::copy( C_array[ i ], C_array[ i ] + ld*n, D_array[ i ] );
  47. }
  48. printf( "test\n" );
  49. for (int i = 0; i < batch_size; ++i) {
  50. cblas_dgemm( CblasColMajor, CblasNoTrans, CblasNoTrans, n, n, n,
  51. alpha, A_array[ i ], ld, B_array[ i ], ld,
  52. beta, C_array[ i ], ld );
  53. }
  54. printf( "test OpenMP\n" );
  55. #pragma omp parallel for
  56. for (int i = 0; i < batch_size; ++i) {
  57. cblas_dgemm( CblasColMajor, CblasNoTrans, CblasNoTrans, n, n, n,
  58. alpha, A_array[ i ], ld, B_array[ i ], ld,
  59. beta, D_array[ i ], ld );
  60. }
  61. printf( "compare\n" );
  62. double max_error = 0;
  63. for (int i = 0; i < batch_size; ++i) {
  64. // norm( D - C )
  65. cblas_daxpy( ld*n, -1.0, C_array[ i ], 1, D_array[ i ], 1 );
  66. double error = cblas_dnrm2( ld*n, D_array[ i ], 1 );
  67. max_error = max_nan( error, max_error );
  68. }
  69. printf( "max error %.2e\n", max_error );
  70. printf( "delete\n" );
  71. for (int i = 0; i < batch_size; ++i) {
  72. delete [] A_array[ i ];
  73. delete [] B_array[ i ];
  74. delete [] C_array[ i ];
  75. }
  76. printf( "done\n" );
  77. return 0;
  78. }