You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

helper.py 1.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. def get_encoder_decoder_hp(model='gin', decoder=None, decoupled=False):
  2. if model == 'gin':
  3. model_hp = {
  4. # hp from model
  5. "num_layers": 2,
  6. "hidden": [64],
  7. "dropout": 0.5,
  8. "act": "relu",
  9. "eps": "False",
  10. "mlp_layers": 2
  11. }
  12. if model == 'gat':
  13. if decoupled:
  14. model_hp = {
  15. "num_layers": 2,
  16. "hidden": [8],
  17. "num_hidden_heads": 8,
  18. "num_output_heads": 1,
  19. "dropout": 0.6,
  20. "act": "elu"
  21. }
  22. else:
  23. model_hp = {
  24. "num_layers": 2,
  25. "hidden": [8],
  26. "heads": 8,
  27. "dropout": 0.6,
  28. "act": "elu"
  29. }
  30. elif model == 'gcn':
  31. model_hp = {
  32. "num_layers": 2,
  33. "hidden": [16],
  34. "dropout": 0.5,
  35. "act": "relu"
  36. }
  37. elif model == 'sage':
  38. model_hp = {
  39. "num_layers": 2,
  40. "hidden": [64],
  41. "dropout": 0.5,
  42. "act": "relu",
  43. "agg": "mean",
  44. }
  45. else:
  46. model_hp = {}
  47. return model_hp, {}