# GCN Example
## Description
This is an example of training GCN on the Cora and Citeseer datasets in MindSpore.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the Cora or Citeseer dataset provided by kimiyoung/planetoid on GitHub.
> Place the dataset in any path you like. The folder should contain the following files (Cora is used as the example):
```
.
└─data
    ├─ind.cora.allx
    ├─ind.cora.ally
    ├─ind.cora.graph
    ├─ind.cora.test.index
    ├─ind.cora.tx
    ├─ind.cora.ty
    ├─ind.cora.x
    └─ind.cora.y
```
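Before running the conversion step, it can help to confirm that the folder is complete. The following is a minimal sketch, not part of this repository: the helper name `missing_dataset_files` is hypothetical, and it assumes the Citeseer files follow the same `ind.<name>.<suffix>` pattern as the Cora files listed above.

```python
from pathlib import Path

# File suffixes taken from the Cora folder layout shown above.
PLANETOID_SUFFIXES = ("allx", "ally", "graph", "test.index", "tx", "ty", "x", "y")


def missing_dataset_files(data_dir, dataset_name):
    """Return the expected ind.<dataset_name>.* files that are absent from data_dir.

    Hypothetical helper: assumes every dataset uses the same suffixes as Cora.
    """
    data_dir = Path(data_dir)
    expected = [f"ind.{dataset_name}.{suffix}" for suffix in PLANETOID_SUFFIXES]
    return [name for name in expected if not (data_dir / name).is_file()]


if __name__ == "__main__":
    missing = missing_dataset_files("./data", "cora")
    if missing:
        print("Missing files:", ", ".join(missing))
    else:
        print("All expected Cora files are present.")
```

An empty list means every expected file was found. Otherwise the returned names show what still needs to be downloaded.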
> Generate the dataset in MindRecord format for Cora or Citeseer.
>> Usage
```shell
cd ./scripts
# SRC_PATH is the path of the downloaded dataset; DATASET_NAME is cora or citeseer
sh run_process_data.sh [SRC_PATH] [DATASET_NAME]
```
>> Launch
```
# Generate the dataset in MindRecord format for Cora
sh run_process_data.sh ./data cora
# Generate the dataset in MindRecord format for Citeseer
sh run_process_data.sh ./data citeseer
```
## Structure
```shell
.
└─gcn
  ├─README.md
  ├─scripts
  | ├─run_process_data.sh  # Generate dataset in mindrecord format
  | └─run_train.sh         # Launch training
  |
  ├─src
  | ├─config.py            # Parameter configuration
  | ├─dataset.py           # Data preprocessing
  | ├─gcn.py               # GCN backbone
  | └─metrics.py           # Loss and accuracy
  |
  └─train.py               # Train net
```
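For orientation on what a GCN backbone computes, here is a minimal pure-Python sketch of a two-layer GCN forward pass using the standard propagation rule `A_hat @ relu(A_hat @ X @ W1) @ W2`, where `A_hat` is the symmetrically normalized adjacency matrix with self-loops. It is an illustration only and not the code in `src/gcn.py`: all function names are hypothetical, matrices are plain nested lists, and dropout, the softmax, and training are left out.

```python
import math


def normalize_adjacency(adj):
    """Return D^-1/2 (A + I) D^-1/2 for a dense, non-negative adjacency matrix."""
    n = len(adj)
    # Add self-loops so every node also aggregates its own features.
    a_tilde = [[adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    inv_sqrt_deg = [1.0 / math.sqrt(sum(row)) for row in a_tilde]
    return [[inv_sqrt_deg[i] * a_tilde[i][j] * inv_sqrt_deg[j] for j in range(n)]
            for i in range(n)]


def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def relu(m):
    """Apply max(0, v) to every entry."""
    return [[max(0.0, v) for v in row] for row in m]


def gcn_forward(adj, features, w1, w2):
    """Two-layer GCN forward pass returning per-node class logits."""
    a_hat = normalize_adjacency(adj)
    hidden = relu(matmul(a_hat, matmul(features, w1)))  # first graph convolution
    return matmul(a_hat, matmul(hidden, w2))            # second graph convolution


if __name__ == "__main__":
    # Toy graph: two connected nodes, two features, two classes.
    adj = [[0.0, 1.0], [1.0, 0.0]]
    identity = [[1.0, 0.0], [0.0, 1.0]]
    print(gcn_forward(adj, identity, identity, identity))
```

In this sketch the number of columns of `w1` plays the role of the hidden size of the first graph convolution layer.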
## Parameter configuration
Training parameters can be set in config.py.
```
"learning_rate": 0.01,    # Learning rate
"epochs": 200,            # Number of training epochs
"hidden1": 16,            # Hidden size of the first graph convolution layer
"dropout": 0.5,           # Dropout ratio for the first graph convolution layer
"weight_decay": 5e-4,     # Weight decay for the parameters of the first graph convolution layer
"early_stopping": 10,     # Tolerance for early stopping
```
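To make the `early_stopping` tolerance more concrete, the sketch below shows one common rule: stop once the latest validation loss is higher than the mean of the previous `tolerance` validation losses. Whether `train.py` applies exactly this rule is an assumption, and the function name `should_stop_early` is hypothetical.

```python
def should_stop_early(val_losses, tolerance=10):
    """Return True when the newest validation loss exceeds the mean of the
    `tolerance` losses recorded just before it.

    Assumed rule for illustration; the actual check in train.py may differ.
    """
    # Not enough history yet: keep training.
    if len(val_losses) <= tolerance:
        return False
    previous = val_losses[-(tolerance + 1):-1]
    return val_losses[-1] > sum(previous) / len(previous)


if __name__ == "__main__":
    history = []
    # Synthetic losses: ten flat epochs followed by a jump.
    for epoch, loss in enumerate([1.0] * 10 + [1.5], start=1):
        history.append(loss)
        if should_stop_early(history, tolerance=10):
            print(f"Early stopping at epoch {epoch}")
            break
```

With a tolerance of 10, the check cannot trigger before the eleventh epoch, because it needs ten earlier losses to average.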
## Running the example
### Train
#### Usage
```
# Train on the Cora or Citeseer dataset; DATASET_NAME is cora or citeseer
sh run_train.sh [DATASET_NAME]
```
#### Launch
```bash
sh run_train.sh cora
```
#### Result
Training results are stored under the scripts path, in a folder whose name begins with "train". The log contains results like the following:
```
Epoch: 0001 train_loss= 1.95373 train_acc= 0.09286 val_loss= 1.95075 val_acc= 0.20200 time= 7.25737
Epoch: 0002 train_loss= 1.94812 train_acc= 0.32857 val_loss= 1.94717 val_acc= 0.34000 time= 0.00438
Epoch: 0003 train_loss= 1.94249 train_acc= 0.47857 val_loss= 1.94337 val_acc= 0.43000 time= 0.00428
Epoch: 0004 train_loss= 1.93550 train_acc= 0.55000 val_loss= 1.93957 val_acc= 0.46400 time= 0.00421
Epoch: 0005 train_loss= 1.92617 train_acc= 0.67143 val_loss= 1.93558 val_acc= 0.45400 time= 0.00430
...
Epoch: 0196 train_loss= 0.60326 train_acc= 0.97857 val_loss= 1.05155 val_acc= 0.78200 time= 0.00418
Epoch: 0197 train_loss= 0.60377 train_acc= 0.97143 val_loss= 1.04940 val_acc= 0.78000 time= 0.00418
Epoch: 0198 train_loss= 0.60680 train_acc= 0.95000 val_loss= 1.04847 val_acc= 0.78000 time= 0.00414
Epoch: 0199 train_loss= 0.61920 train_acc= 0.96429 val_loss= 1.04797 val_acc= 0.78400 time= 0.00413
Epoch: 0200 train_loss= 0.57948 train_acc= 0.96429 val_loss= 1.04753 val_acc= 0.78600 time= 0.00415
Optimization Finished!
Test set results: cost= 1.00983 accuracy= 0.81300 time= 0.39083
...
```
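The per-epoch lines in the log have a regular shape, so they are easy to parse for plotting or for finding the best epoch. The following is a minimal sketch that assumes the log lines look exactly like the sample above. The names `parse_epoch_line` and `best_val_epoch` are hypothetical, and the sample lines are embedded in the code because the exact log file name is not given here.

```python
import re

# Pattern for lines such as:
# Epoch: 0200 train_loss= 0.57948 train_acc= 0.96429 val_loss= 1.04753 val_acc= 0.78600 time= 0.00415
EPOCH_LINE = re.compile(
    r"Epoch:\s*(?P<epoch>\d+)\s+"
    r"train_loss=\s*(?P<train_loss>[\d.]+)\s+"
    r"train_acc=\s*(?P<train_acc>[\d.]+)\s+"
    r"val_loss=\s*(?P<val_loss>[\d.]+)\s+"
    r"val_acc=\s*(?P<val_acc>[\d.]+)\s+"
    r"time=\s*(?P<time>[\d.]+)"
)


def parse_epoch_line(line):
    """Parse one per-epoch log line into a dict, or return None if it does not match."""
    match = EPOCH_LINE.search(line)
    if match is None:
        return None
    record = {key: float(value) for key, value in match.groupdict().items() if key != "epoch"}
    record["epoch"] = int(match.group("epoch"))
    return record


def best_val_epoch(lines):
    """Return the parsed record with the highest validation accuracy, or None."""
    records = [r for r in map(parse_epoch_line, lines) if r is not None]
    return max(records, key=lambda r: r["val_acc"]) if records else None


if __name__ == "__main__":
    sample = [
        "Epoch: 0199 train_loss= 0.61920 train_acc= 0.96429 val_loss= 1.04797 val_acc= 0.78400 time= 0.00413",
        "Epoch: 0200 train_loss= 0.57948 train_acc= 0.96429 val_loss= 1.04753 val_acc= 0.78600 time= 0.00415",
        "Optimization Finished!",
    ]
    print(best_val_epoch(sample))
```

Lines that are not per-epoch records, such as `Optimization Finished!` and the final test-set summary, are skipped.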