| @@ -12,6 +12,8 @@ | |||
| - [Script Parameters](#script-parameters) | |||
| - [Training Process](#training-process) | |||
| - [Training](#training) | |||
| - [Evaluation Process](#evaluation-process) | |||
| - [Evaluation](#evaluation) | |||
| - [Model Description](#model-description) | |||
| - [Performance](#performance) | |||
| - [Description of random situation](#description-of-random-situation) | |||
| @@ -88,6 +90,9 @@ After installing MindSpore via the official website and Dataset is correctly gen | |||
| ``` | |||
| # run training example with Amazon-Beauty dataset | |||
| sh run_train_ascend.sh | |||
| # run evaluation example with Amazon-Beauty dataset | |||
| sh run_eval_ascend.sh | |||
| ``` | |||
| # [Script Description](#contents) | |||
| @@ -99,6 +104,7 @@ After installing MindSpore via the official website and Dataset is correctly gen | |||
| └─bgcf | |||
| ├─README.md | |||
| ├─scripts | |||
| | ├─run_eval_ascend.sh # Launch evaluation | |||
| | ├─run_process_data_ascend.sh # Generate dataset in mindrecord format | |||
| | └─run_train_ascend.sh # Launch training | |||
| | | |||
| @@ -110,6 +116,7 @@ After installing MindSpore via the official website and Dataset is correctly gen | |||
| | ├─metrics.py # Recommendation metrics | |||
| | └─utils.py # Utils for training bgcf | |||
| | | |||
| ├─eval.py # Evaluation net | |||
| └─train.py # Train net | |||
| ``` | |||
| @@ -118,7 +125,7 @@ After installing MindSpore via the official website and Dataset is correctly gen | |||
| Parameters for both training and evaluation can be set in config.py. | |||
| - config for BGCF dataset | |||
| ```python | |||
| "learning_rate": 0.001, # Learning rate | |||
| "num_epochs": 600, # Epoch sizes for training | |||
| @@ -130,6 +137,7 @@ Parameters for both training and evaluation can be set in config.py. | |||
| "neighbor_dropout": [0.0, 0.2, 0.3]# Dropout ratio for different aggregation layer | |||
| "num_graphs":5 # Num of sample graph | |||
| ``` | |||
| See config.py for more configuration. | |||
| ## [Training Process](#contents) | |||
| @@ -154,27 +162,54 @@ Parameters for both training and evaluation can be set in config.py. | |||
| Epoch 598 iter 12 loss 3640.7612 | |||
| Epoch 599 iter 12 loss 3654.9087 | |||
| Epoch 600 iter 12 loss 3632.4585 | |||
| epoch:600, recall_@10:0.10393, recall_@20:0.15669, ndcg_@10:0.07564, ndcg_@20:0.09343, | |||
| sedp_@10:0.01936, sedp_@20:0.01544, nov_@10:7.58599, nov_@20:7.79782 | |||
| ... | |||
| ``` | |||
| ## [Evaluation Process](#contents) | |||
| ### Evaluation | |||
| - Evaluation on Ascend | |||
| ```python | |||
| sh run_eval_ascend.sh | |||
| ``` | |||
| Evaluation results will be stored in the scripts path, in a folder whose name begins with "eval". You can find results like the | |||
| following in the log. | |||
| ```python | |||
| epoch:020, recall_@10:0.07345, recall_@20:0.11193, ndcg_@10:0.05293, ndcg_@20:0.06613, | |||
| sedp_@10:0.01393, sedp_@20:0.01126, nov_@10:6.95106, nov_@20:7.22280 | |||
| epoch:040, recall_@10:0.07410, recall_@20:0.11537, ndcg_@10:0.05387, ndcg_@20:0.06801, | |||
| sedp_@10:0.01445, sedp_@20:0.01168, nov_@10:7.34799, nov_@20:7.58883 | |||
| epoch:060, recall_@10:0.07654, recall_@20:0.11987, ndcg_@10:0.05530, ndcg_@20:0.07015, | |||
| sedp_@10:0.01474, sedp_@20:0.01206, nov_@10:7.46553, nov_@20:7.69436 | |||
| ... | |||
| epoch:560, recall_@10:0.09825, recall_@20:0.14877, ndcg_@10:0.07176, ndcg_@20:0.08883, | |||
| sedp_@10:0.01882, sedp_@20:0.01501, nov_@10:7.58045, nov_@20:7.79586 | |||
| epoch:580, recall_@10:0.09917, recall_@20:0.14970, ndcg_@10:0.07337, ndcg_@20:0.09037, | |||
| sedp_@10:0.01896, sedp_@20:0.01504, nov_@10:7.57995, nov_@20:7.79439 | |||
| epoch:600, recall_@10:0.09926, recall_@20:0.15080, ndcg_@10:0.07283, ndcg_@20:0.09016, | |||
| sedp_@10:0.01890, sedp_@20:0.01517, nov_@10:7.58277, nov_@20:7.80038 | |||
| ... | |||
| ``` | |||
| # [Model Description](#contents) | |||
| ## [Performance](#contents) | |||
| | Parameter | BGCF | | |||
| | ------------------------------------ | ----------------------------------------- | | |||
| | Resource | Ascend 910 | | |||
| | uploaded Date | 09/04/2020(month/day/year) | | |||
| | MindSpore Version | 1.0 | | |||
| | uploaded Date | | | |||
| | MindSpore Version | | | |||
| | Dataset | Amazon-Beauty | | |||
| | Training Parameter | epoch=600 | | |||
| | Optimizer | Adam | | |||
| | Loss Function | BPR loss | | |||
| | Recall@20 | 0.1534 | | |||
| | NDCG@20 | 0.0912 | | |||
| | Total time | 30min | | |||
| | Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf | | |||
| | Training Cost | 25min | | |||
| | Scripts | | | |||
| # [Description of random situation](#contents) | |||
| @@ -0,0 +1,105 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| BGCF evaluation script. | |||
| """ | |||
| import os | |||
| import datetime | |||
| import mindspore.context as context | |||
| from mindspore.train.serialization import load_checkpoint | |||
| from src.bgcf import BGCF | |||
| from src.utils import BGCFLogger | |||
| from src.config import parser_args | |||
| from src.metrics import BGCFEvaluate | |||
| from src.callback import ForwardBGCF, TestBGCF | |||
| from src.dataset import TestGraphDataset, load_graph | |||
| def evaluation(): | |||
| """evaluation""" | |||
| num_user = train_graph.graph_info()["node_num"][0] | |||
| num_item = train_graph.graph_info()["node_num"][1] | |||
| eval_class = BGCFEvaluate(parser, train_graph, test_graph, parser.Ks) | |||
| for _epoch in range(parser.eval_interval, parser.num_epoch+1, parser.eval_interval): | |||
| bgcfnet_test = BGCF([parser.input_dim, num_user, num_item], | |||
| parser.embedded_dimension, | |||
| parser.activation, | |||
| [0.0, 0.0, 0.0], | |||
| num_user, | |||
| num_item, | |||
| parser.input_dim) | |||
| load_checkpoint(parser.ckptpath + "/bgcf_epoch{}.ckpt".format(_epoch), net=bgcfnet_test) | |||
| forward_net = ForwardBGCF(bgcfnet_test) | |||
| user_reps, item_reps = TestBGCF(forward_net, num_user, num_item, parser.input_dim, test_graph_dataset) | |||
| test_recall_bgcf, test_ndcg_bgcf, \ | |||
| test_sedp, test_nov = eval_class.eval_with_rep(user_reps, item_reps, parser) | |||
| if parser.log_name: | |||
| log.write( | |||
| 'epoch:%03d, recall_@10:%.5f, recall_@20:%.5f, ndcg_@10:%.5f, ndcg_@20:%.5f, ' | |||
| 'sedp_@10:%.5f, sedp_@20:%.5f, nov_@10:%.5f, nov_@20:%.5f\n' % (_epoch, | |||
| test_recall_bgcf[1], | |||
| test_recall_bgcf[2], | |||
| test_ndcg_bgcf[1], | |||
| test_ndcg_bgcf[2], | |||
| test_sedp[0], | |||
| test_sedp[1], | |||
| test_nov[1], | |||
| test_nov[2])) | |||
| else: | |||
| print('epoch:%03d, recall_@10:%.5f, recall_@20:%.5f, ndcg_@10:%.5f, ndcg_@20:%.5f, ' | |||
| 'sedp_@10:%.5f, sedp_@20:%.5f, nov_@10:%.5f, nov_@20:%.5f\n' % (_epoch, | |||
| test_recall_bgcf[1], | |||
| test_recall_bgcf[2], | |||
| test_ndcg_bgcf[1], | |||
| test_ndcg_bgcf[2], | |||
| test_sedp[0], | |||
| test_sedp[1], | |||
| test_nov[1], | |||
| test_nov[2])) | |||
| if __name__ == "__main__": | |||
| context.set_context(mode=context.GRAPH_MODE, | |||
| device_target="Ascend", | |||
| save_graphs=False) | |||
| parser = parser_args() | |||
| os.environ['DEVICE_ID'] = parser.device | |||
| train_graph, test_graph, sampled_graph_list = load_graph(parser.datapath) | |||
| test_graph_dataset = TestGraphDataset(train_graph, sampled_graph_list, num_samples=parser.raw_neighs, | |||
| num_bgcn_neigh=parser.gnew_neighs, | |||
| num_neg=parser.num_neg) | |||
| if parser.log_name: | |||
| now = datetime.datetime.now().strftime("%b_%d_%H_%M_%S") | |||
| name = "bgcf" + '-' + parser.log_name + '-' + parser.dataset | |||
| log_save_path = './log-files/' + name + '/' + now | |||
| log = BGCFLogger(logname=name, now=now, foldername='log-files', copy=False) | |||
| log.open(log_save_path + '/log.train.txt', mode='a') | |||
| for arg in vars(parser): | |||
| log.write(arg + '=' + str(getattr(parser, arg)) + '\n') | |||
| else: | |||
| for arg in vars(parser): | |||
| print(arg + '=' + str(getattr(parser, arg))) | |||
| evaluation() | |||
| @@ -0,0 +1,38 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| ulimit -u unlimited | |||
| export DEVICE_NUM=1 | |||
| export RANK_SIZE=$DEVICE_NUM | |||
| export DEVICE_ID=0 | |||
| export RANK_ID=0 | |||
| if [ -d "eval" ]; | |||
| then | |||
| rm -rf ./eval | |||
| fi | |||
| mkdir ./eval | |||
| cp ../*.py ./eval | |||
| cp *.sh ./eval | |||
| cp -r ../src ./eval | |||
| cd ./eval || exit | |||
| env > env.log | |||
| echo "start evaluation for device $DEVICE_ID" | |||
| python eval.py --datapath=../data_mr --ckptpath=../ckpts &> log & | |||
| cd .. | |||
| @@ -25,14 +25,20 @@ then | |||
| rm -rf ./train | |||
| fi | |||
| mkdir ./train | |||
| if [ -d "ckpts" ]; | |||
| then | |||
| rm -rf ./ckpts | |||
| fi | |||
| mkdir ./ckpts | |||
| cp ../*.py ./train | |||
| cp *.sh ./train | |||
| cp -r ../src ./train | |||
| cd ./train || exit | |||
| mkdir ./ckpts | |||
| env > env.log | |||
| echo "start training for device $DEVICE_ID" | |||
| python train.py --datapath=../data_mr &> log & | |||
| python train.py --datapath=../data_mr --ckptpath=../ckpts &> log & | |||
| cd .. | |||
| @@ -13,7 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| network config setting, will be used in train.py | |||
| network config setting | |||
| """ | |||
| import argparse | |||
| @@ -21,37 +21,38 @@ import argparse | |||
| def parser_args(): | |||
| """Config for BGCF""" | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument("-d", "--dataset", type=str, default="Beauty") | |||
| parser.add_argument("-dpath", "--datapath", type=str, default="./scripts/data_mr") | |||
| parser.add_argument("-de", "--device", type=str, default='0') | |||
| parser.add_argument('--seed', type=int, default=0) | |||
| parser.add_argument('--Ks', type=list, default=[5, 10, 20, 100]) | |||
| parser.add_argument('--test_ratio', type=float, default=0.2) | |||
| parser.add_argument('--val_ratio', type=float, default=None) | |||
| parser.add_argument('-w', '--workers', type=int, default=10) | |||
| parser.add_argument("-d", "--dataset", type=str, default="Beauty", help="choose which dataset") | |||
| parser.add_argument("-dpath", "--datapath", type=str, default="./scripts/data_mr", help="minddata path") | |||
| parser.add_argument("-de", "--device", type=str, default='0', help="device id") | |||
| parser.add_argument('--seed', type=int, default=0, help="random seed") | |||
| parser.add_argument('--Ks', type=list, default=[5, 10, 20, 100], help="top K") | |||
| parser.add_argument('--test_ratio', type=float, default=0.2, help="test ratio") | |||
| parser.add_argument('-w', '--workers', type=int, default=8, help="number of process") | |||
| parser.add_argument("-ckpt", "--ckptpath", type=str, default="./ckpts", help="checkpoint path") | |||
| parser.add_argument("-eps", "--epsilon", type=float, default=1e-8) | |||
| parser.add_argument("-lr", "--learning_rate", type=float, default=1e-3) | |||
| parser.add_argument("-l2", "--l2", type=float, default=0.03) | |||
| parser.add_argument("-wd", "--weight_decay", type=float, default=0.01) | |||
| parser.add_argument("-act", "--activation", type=str, default='tanh', choices=['relu', 'tanh']) | |||
| parser.add_argument("-ndrop", "--neighbor_dropout", type=list, default=[0.0, 0.2, 0.3]) | |||
| parser.add_argument("-log", "--log_name", type=str, default='test') | |||
| parser.add_argument("-eps", "--epsilon", type=float, default=1e-8, help="optimizer parameter") | |||
| parser.add_argument("-lr", "--learning_rate", type=float, default=1e-3, help="learning rate") | |||
| parser.add_argument("-l2", "--l2", type=float, default=0.03, help="l2 coefficient") | |||
| parser.add_argument("-wd", "--weight_decay", type=float, default=0.01, help="weight decay") | |||
| parser.add_argument("-act", "--activation", type=str, default='tanh', choices=['relu', 'tanh'], | |||
| help="activation function") | |||
| parser.add_argument("-ndrop", "--neighbor_dropout", type=list, default=[0.0, 0.2, 0.3], | |||
| help="dropout ratio for different aggregation layer") | |||
| parser.add_argument("-log", "--log_name", type=str, default='test', help="log name") | |||
| parser.add_argument("-e", "--num_epoch", type=int, default=600) | |||
| parser.add_argument('-input', '--input_dim', type=int, default=64, choices=[64, 128]) | |||
| parser.add_argument("-b", "--batch_pairs", type=int, default=5000) | |||
| parser.add_argument('--eval_interval', type=int, default=20) | |||
| parser.add_argument("-e", "--num_epoch", type=int, default=600, help="epoch sizes for training") | |||
| parser.add_argument('-input', '--input_dim', type=int, default=64, choices=[64, 128], | |||
| help="user and item embedding dimension") | |||
| parser.add_argument("-b", "--batch_pairs", type=int, default=5000, help="batch size") | |||
| parser.add_argument('--eval_interval', type=int, default=20, help="evaluation interval") | |||
| parser.add_argument("-neg", "--num_neg", type=int, default=10) | |||
| parser.add_argument('-max', '--max_degree', type=str, default='[128,128]') | |||
| parser.add_argument("-g1", "--raw_neighs", type=int, default=40) | |||
| parser.add_argument("-g2", "--gnew_neighs", type=int, default=20) | |||
| parser.add_argument("-emb", "--embedded_dimension", type=int, default=64) | |||
| parser.add_argument('-dist', '--distance', type=str, default='iou') | |||
| parser.add_argument('--dist_reg', type=float, default=0.003) | |||
| parser.add_argument("-neg", "--num_neg", type=int, default=10, help="negative sampling rate ") | |||
| parser.add_argument("-g1", "--raw_neighs", type=int, default=40, help="num of sampling neighbors in raw graph") | |||
| parser.add_argument("-g2", "--gnew_neighs", type=int, default=20, help="num of sampling neighbors in sample graph") | |||
| parser.add_argument("-emb", "--embedded_dimension", type=int, default=64, help="output embedding dim") | |||
| parser.add_argument('--dist_reg', type=float, default=0.003, help="distance loss coefficient") | |||
| parser.add_argument('-ng', '--num_graphs', type=int, default=5) | |||
| parser.add_argument('-geps', '--graph_epsilon', type=float, default=0.01) | |||
| parser.add_argument('-ng', '--num_graphs', type=int, default=5, help="num of sample graph") | |||
| parser.add_argument('-geps', '--graph_epsilon', type=float, default=0.01, help="node copy parameter") | |||
| return parser.parse_args() | |||
| @@ -175,8 +175,8 @@ def load_graph(data_path): | |||
| return train_graph, test_graph, sampled_graph_list | |||
| def create_dataset(train_graph, sampled_graph_list, batch_size=32, repeat_size=1, num_samples=40, num_bgcn_neigh=20, | |||
| num_neg=10): | |||
| def create_dataset(train_graph, sampled_graph_list, num_workers, batch_size=32, repeat_size=1, | |||
| num_samples=40, num_bgcn_neigh=20, num_neg=10): | |||
| """Data generator for training""" | |||
| edge_num = train_graph.graph_info()['edge_num'][0] | |||
| out_column_names = ["users", "items", "neg_item_id", "pos_users", "pos_items", "u_group_nodes", "u_neighs", | |||
| @@ -185,7 +185,7 @@ def create_dataset(train_graph, sampled_graph_list, batch_size=32, repeat_size=1 | |||
| train_graph_dataset = TrainGraphDataset( | |||
| train_graph, sampled_graph_list, batch_size, num_samples, num_bgcn_neigh, num_neg) | |||
| dataset = ds.GeneratorDataset(source=train_graph_dataset, column_names=out_column_names, | |||
| sampler=RandomBatchedSampler(edge_num, batch_size), num_parallel_workers=8) | |||
| sampler=RandomBatchedSampler(edge_num, batch_size), num_parallel_workers=num_workers) | |||
| dataset = dataset.repeat(repeat_size) | |||
| return dataset | |||
| @@ -17,23 +17,21 @@ BGCF training script. | |||
| """ | |||
| import os | |||
| import time | |||
| import datetime | |||
| from mindspore import Tensor | |||
| import mindspore.context as context | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.train.serialization import save_checkpoint, load_checkpoint | |||
| from mindspore.train.serialization import save_checkpoint | |||
| from src.bgcf import BGCF | |||
| from src.metrics import BGCFEvaluate | |||
| from src.config import parser_args | |||
| from src.utils import BGCFLogger, convert_item_id | |||
| from src.callback import ForwardBGCF, TrainBGCF, TestBGCF | |||
| from src.dataset import load_graph, create_dataset, TestGraphDataset | |||
| from src.utils import convert_item_id | |||
| from src.callback import TrainBGCF | |||
| from src.dataset import load_graph, create_dataset | |||
| def train_and_eval(): | |||
| """Train and eval""" | |||
| def train(): | |||
| """Train""" | |||
| num_user = train_graph.graph_info()["node_num"][0] | |||
| num_item = train_graph.graph_info()["node_num"][1] | |||
| num_pairs = train_graph.graph_info()['edge_num'][0] | |||
| @@ -50,8 +48,6 @@ def train_and_eval(): | |||
| parser.epsilon, parser.dist_reg) | |||
| train_net.set_train(True) | |||
| eval_class = BGCFEvaluate(parser, train_graph, test_graph, parser.Ks) | |||
| itr = train_ds.create_dict_iterator(parser.num_epoch, output_numpy=True) | |||
| num_iter = int(num_pairs / parser.batch_pairs) | |||
| @@ -102,49 +98,7 @@ def train_and_eval(): | |||
| iter_num += 1 | |||
| if _epoch % parser.eval_interval == 0: | |||
| if os.path.exists("ckpts/bgcf.ckpt"): | |||
| os.remove("ckpts/bgcf.ckpt") | |||
| save_checkpoint(bgcfnet, "ckpts/bgcf.ckpt") | |||
| bgcfnet_test = BGCF([parser.input_dim, num_user, num_item], | |||
| parser.embedded_dimension, | |||
| parser.activation, | |||
| [0.0, 0.0, 0.0], | |||
| num_user, | |||
| num_item, | |||
| parser.input_dim) | |||
| load_checkpoint("ckpts/bgcf.ckpt", net=bgcfnet_test) | |||
| forward_net = ForwardBGCF(bgcfnet_test) | |||
| user_reps, item_reps = TestBGCF(forward_net, num_user, num_item, parser.input_dim, test_graph_dataset) | |||
| test_recall_bgcf, test_ndcg_bgcf, \ | |||
| test_sedp, test_nov = eval_class.eval_with_rep(user_reps, item_reps, parser) | |||
| if parser.log_name: | |||
| log.write( | |||
| 'epoch:%03d, recall_@10:%.5f, recall_@20:%.5f, ndcg_@10:%.5f, ndcg_@20:%.5f, ' | |||
| 'sedp_@10:%.5f, sedp_@20:%.5f, nov_@10:%.5f, nov_@20:%.5f\n' % (_epoch, | |||
| test_recall_bgcf[1], | |||
| test_recall_bgcf[2], | |||
| test_ndcg_bgcf[1], | |||
| test_ndcg_bgcf[2], | |||
| test_sedp[0], | |||
| test_sedp[1], | |||
| test_nov[1], | |||
| test_nov[2])) | |||
| else: | |||
| print('epoch:%03d, recall_@10:%.5f, recall_@20:%.5f, ndcg_@10:%.5f, ndcg_@20:%.5f, ' | |||
| 'sedp_@10:%.5f, sedp_@20:%.5f, nov_@10:%.5f, nov_@20:%.5f\n' % (_epoch, | |||
| test_recall_bgcf[1], | |||
| test_recall_bgcf[2], | |||
| test_ndcg_bgcf[1], | |||
| test_ndcg_bgcf[2], | |||
| test_sedp[0], | |||
| test_sedp[1], | |||
| test_nov[1], | |||
| test_nov[2])) | |||
| save_checkpoint(bgcfnet, parser.ckptpath + "/bgcf_epoch{}.ckpt".format(_epoch)) | |||
| if __name__ == "__main__": | |||
| @@ -153,23 +107,9 @@ if __name__ == "__main__": | |||
| save_graphs=False) | |||
| parser = parser_args() | |||
| os.environ['DEVICE_ID'] = parser.device | |||
| train_graph, _, sampled_graph_list = load_graph(parser.datapath) | |||
| train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs, | |||
| num_samples=parser.raw_neighs, num_bgcn_neigh=parser.gnew_neighs, num_neg=parser.num_neg) | |||
| train_graph, test_graph, sampled_graph_list = load_graph(parser.datapath) | |||
| train_ds = create_dataset(train_graph, sampled_graph_list, batch_size=parser.batch_pairs) | |||
| test_graph_dataset = TestGraphDataset(train_graph, sampled_graph_list, num_samples=parser.raw_neighs, | |||
| num_bgcn_neigh=parser.gnew_neighs, | |||
| num_neg=parser.num_neg) | |||
| if parser.log_name: | |||
| now = datetime.datetime.now().strftime("%b_%d_%H_%M_%S") | |||
| name = "bgcf" + '-' + parser.log_name + '-' + parser.dataset | |||
| log_save_path = './log-files/' + name + '/' + now | |||
| log = BGCFLogger(logname=name, now=now, foldername='log-files', copy=False) | |||
| log.open(log_save_path + '/log.train.txt', mode='a') | |||
| for arg in vars(parser): | |||
| log.write(arg + '=' + str(getattr(parser, arg)) + '\n') | |||
| else: | |||
| for arg in vars(parser): | |||
| print(arg + '=' + str(getattr(parser, arg))) | |||
| train_and_eval() | |||
| train() | |||
| @@ -258,7 +258,7 @@ def trans(src_path, data_name, out_path): | |||
| test_ratio = 0.2 | |||
| train_set, test_set = split_data_randomly( | |||
| inner_data_records, test_ratio=test_ratio, seed=0) | |||
| inner_data_records, test_ratio, seed=0) | |||
| train_matrix = generate_rating_matrix( | |||
| train_set, len(user_mapping), len(item_mapping)) | |||
| u_adj_list, v_adj_list = create_adj_matrix(train_matrix) | |||
| @@ -329,7 +329,7 @@ def trans(src_path, data_name, out_path): | |||
| for i in range(num_graphs): | |||
| print('=== info: sampling graph {} / {}'.format(i + 1, num_graphs)) | |||
| sampled_user_graph = sample_graph_copying(node_neighbors_dict=u_adj_list, | |||
| distances=user_distances) | |||
| distances=user_distances, epsilon=0.01) | |||
| print('avg. sampled user-item graph degree: ', | |||
| np.mean([len(x) for x in [*sampled_user_graph.values()]])) | |||