|
|
|
@@ -37,15 +37,15 @@ parser.add_argument('--device_target', type=str, default="Ascend", choices=['Asc |
|
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
if args.data_name == "ag":
|
|
|
|
from src.config import config_ag as config
|
|
|
|
from src.config import config_ag as config_ascend
|
|
|
|
from src.config import config_ag_gpu as config_gpu
|
|
|
|
target_label1 = ['0', '1', '2', '3']
|
|
|
|
elif args.data_name == 'dbpedia':
|
|
|
|
from src.config import config_db as config
|
|
|
|
from src.config import config_db as config_ascend
|
|
|
|
from src.config import config_db_gpu as config_gpu
|
|
|
|
target_label1 = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13']
|
|
|
|
elif args.data_name == 'yelp_p':
|
|
|
|
from src.config import config_yelpp as config
|
|
|
|
from src.config import config_yelpp as config_ascend
|
|
|
|
from src.config import config_yelpp_gpu as config_gpu
|
|
|
|
target_label1 = ['0', '1']
|
|
|
|
context.set_context(
|
|
|
|
|