You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

helper.py 1.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. def get_encoder_decoder_hp(model='gin', decoder=None):
  2. if model == 'gin':
  3. model_hp = {
  4. "num_layers": 5,
  5. "hidden": [64],
  6. "act": "relu",
  7. "eps": "False",
  8. "mlp_layers": 2,
  9. "neighbor_pooling_type": "sum"
  10. }
  11. elif model == 'gat':
  12. model_hp = {
  13. # hp from model
  14. "num_layers": 2,
  15. "hidden": [8],
  16. "heads": 8,
  17. "dropout": 0.6,
  18. "act": "relu",
  19. }
  20. elif model == 'gcn':
  21. model_hp = {
  22. "num_layers": 2,
  23. "hidden": [16],
  24. "dropout": 0.5,
  25. "act": "relu"
  26. }
  27. elif model == 'sage':
  28. model_hp = {
  29. "num_layers": 2,
  30. "hidden": [64],
  31. "dropout": 0.5,
  32. "act": "relu",
  33. "agg": "gcn",
  34. }
  35. elif model == 'topk':
  36. model_hp = {
  37. "num_layers": 5,
  38. "hidden": [64, 64, 64, 64]
  39. }
  40. return model_hp, {}