Browse Source

Replace the TensorFlow I/O API with the Python standard library API or APIs from popular third-party libraries

pull/13258/head
zhouneng 4 years ago
parent
commit
ae101dec11
1 changed file with 14 additions and 14 deletions
  1. +14
    -14
      model_zoo/official/recommend/ncf/src/movielens.py

+ 14
- 14
model_zoo/official/recommend/ncf/src/movielens.py View File

@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function


import os import os
import shutil
import tempfile import tempfile
import zipfile import zipfile
import argparse import argparse
@@ -32,7 +33,6 @@ from six.moves import urllib
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from absl import logging from absl import logging
import tensorflow as tf


ML_1M = "ml-1m" ML_1M = "ml-1m"
ML_20M = "ml-20m" ML_20M = "ml-20m"
@@ -100,9 +100,9 @@ def _download_and_clean(dataset, data_dir):


expected_files = ["{}.zip".format(dataset), RATINGS_FILE, MOVIES_FILE] expected_files = ["{}.zip".format(dataset), RATINGS_FILE, MOVIES_FILE]


tf.io.gfile.makedirs(data_subdir)
os.makedirs(data_subdir, exist_ok=True)
if set(expected_files).intersection( if set(expected_files).intersection(
tf.io.gfile.listdir(data_subdir)) == set(expected_files):
os.listdir(data_subdir)) == set(expected_files):
logging.info("Dataset {} has already been downloaded".format(dataset)) logging.info("Dataset {} has already been downloaded".format(dataset))
return return


@@ -127,16 +127,16 @@ def _download_and_clean(dataset, data_dir):
else: else:
_regularize_20m_dataset(temp_dir) _regularize_20m_dataset(temp_dir)


for fname in tf.io.gfile.listdir(temp_dir):
if not tf.io.gfile.exists(os.path.join(data_subdir, fname)):
tf.io.gfile.copy(os.path.join(temp_dir, fname),
os.path.join(data_subdir, fname))
for fname in os.listdir(temp_dir):
if not os.path.exists(os.path.join(data_subdir, fname)):
shutil.copy(os.path.join(temp_dir, fname),
os.path.join(data_subdir, fname))
else: else:
logging.info("Skipping copy of {}, as it already exists in the " logging.info("Skipping copy of {}, as it already exists in the "
"destination folder.".format(fname)) "destination folder.".format(fname))


finally: finally:
tf.io.gfile.rmtree(temp_dir)
shutil.rmtree(temp_dir)




def _transform_csv(input_path, output_path, names, skip_first, separator=","): def _transform_csv(input_path, output_path, names, skip_first, separator=","):
@@ -152,8 +152,8 @@ def _transform_csv(input_path, output_path, names, skip_first, separator=","):
if six.PY2: if six.PY2:
names = [six.ensure_text(n, "utf-8") for n in names] names = [six.ensure_text(n, "utf-8") for n in names]


with tf.io.gfile.GFile(output_path, "wb") as f_out, \
tf.io.gfile.GFile(input_path, "rb") as f_in:
with open(output_path, "wb") as f_out, \
open(input_path, "rb") as f_in:


# Write column names to the csv. # Write column names to the csv.
f_out.write(",".join(names).encode("utf-8")) f_out.write(",".join(names).encode("utf-8"))
@@ -199,7 +199,7 @@ def _regularize_1m_dataset(temp_dir):
output_path=os.path.join(temp_dir, MOVIES_FILE), output_path=os.path.join(temp_dir, MOVIES_FILE),
names=MOVIE_COLUMNS, skip_first=False, separator="::") names=MOVIE_COLUMNS, skip_first=False, separator="::")


tf.io.gfile.rmtree(working_dir)
shutil.rmtree(working_dir)




def _regularize_20m_dataset(temp_dir): def _regularize_20m_dataset(temp_dir):
@@ -233,7 +233,7 @@ def _regularize_20m_dataset(temp_dir):
output_path=os.path.join(temp_dir, MOVIES_FILE), output_path=os.path.join(temp_dir, MOVIES_FILE),
names=MOVIE_COLUMNS, skip_first=True, separator=",") names=MOVIE_COLUMNS, skip_first=True, separator=",")


tf.io.gfile.rmtree(working_dir)
shutil.rmtree(working_dir)




def download(dataset, data_dir): def download(dataset, data_dir):
@@ -244,14 +244,14 @@ def download(dataset, data_dir):




def ratings_csv_to_dataframe(data_dir, dataset): def ratings_csv_to_dataframe(data_dir, dataset):
with tf.io.gfile.GFile(os.path.join(data_dir, dataset, RATINGS_FILE)) as f:
with open(os.path.join(data_dir, dataset, RATINGS_FILE)) as f:
return pd.read_csv(f, encoding="utf-8") return pd.read_csv(f, encoding="utf-8")




def csv_to_joint_dataframe(data_dir, dataset): def csv_to_joint_dataframe(data_dir, dataset):
ratings = ratings_csv_to_dataframe(data_dir, dataset) ratings = ratings_csv_to_dataframe(data_dir, dataset)


with tf.io.gfile.GFile(os.path.join(data_dir, dataset, MOVIES_FILE)) as f:
with open(os.path.join(data_dir, dataset, MOVIES_FILE)) as f:
movies = pd.read_csv(f, encoding="utf-8") movies = pd.read_csv(f, encoding="utf-8")


df = ratings.merge(movies, on=ITEM_COLUMN) df = ratings.merge(movies, on=ITEM_COLUMN)


Loading…
Cancel
Save