| @@ -22,6 +22,7 @@ from __future__ import division | |||||
| from __future__ import print_function | from __future__ import print_function | ||||
| import os | import os | ||||
| import shutil | |||||
| import tempfile | import tempfile | ||||
| import zipfile | import zipfile | ||||
| import argparse | import argparse | ||||
| @@ -32,7 +33,6 @@ from six.moves import urllib | |||||
| import numpy as np | import numpy as np | ||||
| import pandas as pd | import pandas as pd | ||||
| from absl import logging | from absl import logging | ||||
| import tensorflow as tf | |||||
| ML_1M = "ml-1m" | ML_1M = "ml-1m" | ||||
| ML_20M = "ml-20m" | ML_20M = "ml-20m" | ||||
| @@ -100,9 +100,9 @@ def _download_and_clean(dataset, data_dir): | |||||
| expected_files = ["{}.zip".format(dataset), RATINGS_FILE, MOVIES_FILE] | expected_files = ["{}.zip".format(dataset), RATINGS_FILE, MOVIES_FILE] | ||||
| tf.io.gfile.makedirs(data_subdir) | |||||
| os.makedirs(data_subdir, exist_ok=True) | |||||
| if set(expected_files).intersection( | if set(expected_files).intersection( | ||||
| tf.io.gfile.listdir(data_subdir)) == set(expected_files): | |||||
| os.listdir(data_subdir)) == set(expected_files): | |||||
| logging.info("Dataset {} has already been downloaded".format(dataset)) | logging.info("Dataset {} has already been downloaded".format(dataset)) | ||||
| return | return | ||||
| @@ -127,16 +127,16 @@ def _download_and_clean(dataset, data_dir): | |||||
| else: | else: | ||||
| _regularize_20m_dataset(temp_dir) | _regularize_20m_dataset(temp_dir) | ||||
| for fname in tf.io.gfile.listdir(temp_dir): | |||||
| if not tf.io.gfile.exists(os.path.join(data_subdir, fname)): | |||||
| tf.io.gfile.copy(os.path.join(temp_dir, fname), | |||||
| os.path.join(data_subdir, fname)) | |||||
| for fname in os.listdir(temp_dir): | |||||
| if not os.path.exists(os.path.join(data_subdir, fname)): | |||||
| shutil.copy(os.path.join(temp_dir, fname), | |||||
| os.path.join(data_subdir, fname)) | |||||
| else: | else: | ||||
| logging.info("Skipping copy of {}, as it already exists in the " | logging.info("Skipping copy of {}, as it already exists in the " | ||||
| "destination folder.".format(fname)) | "destination folder.".format(fname)) | ||||
| finally: | finally: | ||||
| tf.io.gfile.rmtree(temp_dir) | |||||
| shutil.rmtree(temp_dir) | |||||
| def _transform_csv(input_path, output_path, names, skip_first, separator=","): | def _transform_csv(input_path, output_path, names, skip_first, separator=","): | ||||
| @@ -152,8 +152,8 @@ def _transform_csv(input_path, output_path, names, skip_first, separator=","): | |||||
| if six.PY2: | if six.PY2: | ||||
| names = [six.ensure_text(n, "utf-8") for n in names] | names = [six.ensure_text(n, "utf-8") for n in names] | ||||
| with tf.io.gfile.GFile(output_path, "wb") as f_out, \ | |||||
| tf.io.gfile.GFile(input_path, "rb") as f_in: | |||||
| with open(output_path, "wb") as f_out, \ | |||||
| open(input_path, "rb") as f_in: | |||||
| # Write column names to the csv. | # Write column names to the csv. | ||||
| f_out.write(",".join(names).encode("utf-8")) | f_out.write(",".join(names).encode("utf-8")) | ||||
| @@ -199,7 +199,7 @@ def _regularize_1m_dataset(temp_dir): | |||||
| output_path=os.path.join(temp_dir, MOVIES_FILE), | output_path=os.path.join(temp_dir, MOVIES_FILE), | ||||
| names=MOVIE_COLUMNS, skip_first=False, separator="::") | names=MOVIE_COLUMNS, skip_first=False, separator="::") | ||||
| tf.io.gfile.rmtree(working_dir) | |||||
| shutil.rmtree(working_dir) | |||||
| def _regularize_20m_dataset(temp_dir): | def _regularize_20m_dataset(temp_dir): | ||||
| @@ -233,7 +233,7 @@ def _regularize_20m_dataset(temp_dir): | |||||
| output_path=os.path.join(temp_dir, MOVIES_FILE), | output_path=os.path.join(temp_dir, MOVIES_FILE), | ||||
| names=MOVIE_COLUMNS, skip_first=True, separator=",") | names=MOVIE_COLUMNS, skip_first=True, separator=",") | ||||
| tf.io.gfile.rmtree(working_dir) | |||||
| shutil.rmtree(working_dir) | |||||
| def download(dataset, data_dir): | def download(dataset, data_dir): | ||||
| @@ -244,14 +244,14 @@ def download(dataset, data_dir): | |||||
def ratings_csv_to_dataframe(data_dir, dataset):
  """Read the ratings CSV for `dataset` under `data_dir` into a DataFrame.

  Args:
    data_dir: Root directory containing one subdirectory per dataset.
    dataset: Dataset name (e.g. "ml-1m" or "ml-20m"); used as the
      subdirectory holding the ratings file.

  Returns:
    A pandas DataFrame parsed from the UTF-8 encoded ratings CSV.
  """
  ratings_path = os.path.join(data_dir, dataset, RATINGS_FILE)
  with open(ratings_path) as ratings_file:
    return pd.read_csv(ratings_file, encoding="utf-8")
| def csv_to_joint_dataframe(data_dir, dataset): | def csv_to_joint_dataframe(data_dir, dataset): | ||||
| ratings = ratings_csv_to_dataframe(data_dir, dataset) | ratings = ratings_csv_to_dataframe(data_dir, dataset) | ||||
| with tf.io.gfile.GFile(os.path.join(data_dir, dataset, MOVIES_FILE)) as f: | |||||
| with open(os.path.join(data_dir, dataset, MOVIES_FILE)) as f: | |||||
| movies = pd.read_csv(f, encoding="utf-8") | movies = pd.read_csv(f, encoding="utf-8") | ||||
| df = ratings.merge(movies, on=ITEM_COLUMN) | df = ratings.merge(movies, on=ITEM_COLUMN) | ||||