|
- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
- import contextlib
- import io
- import itertools
- import json
- import logging
- import os
- import tempfile
- from collections import OrderedDict
- from fvcore.common.file_io import PathManager
- from PIL import Image
- from tabulate import tabulate
-
- from detectron2.data import MetadataCatalog
- from detectron2.utils import comm
-
- from .evaluator import DatasetEvaluator
-
- logger = logging.getLogger(__name__)
-
-
class COCOPanopticEvaluator(DatasetEvaluator):
    """
    Evaluate Panoptic Quality metrics on COCO using PanopticAPI.
    It saves panoptic segmentation prediction in `output_dir`

    It contains a synchronize call and has to be called from all workers.
    """

    def __init__(self, dataset_name, output_dir):
        """
        Args:
            dataset_name (str): name of the dataset
            output_dir (str): output directory to save results for evaluation
        """
        self._metadata = MetadataCatalog.get(dataset_name)
        # The metadata maps dataset ids -> contiguous ids; evaluation needs the
        # inverse direction, so build the reversed lookup tables once here.
        self._thing_contiguous_id_to_dataset_id = {
            v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
        }
        self._stuff_contiguous_id_to_dataset_id = {
            v: k for k, v in self._metadata.stuff_dataset_id_to_contiguous_id.items()
        }

        self._predictions_json = os.path.join(output_dir, "predictions.json")
        # Start with an empty buffer so process()/evaluate() are usable even
        # when the caller did not invoke reset() first.
        self._predictions = []

    def reset(self):
        """Discard all predictions accumulated so far."""
        self._predictions = []

    def _convert_category_id(self, segment_info):
        """
        Map the contiguous category id of one segment back to its dataset id.

        Args:
            segment_info (dict): one segment description. If it has an
                "isthing" key, its "category_id" is looked up in the thing
                table (when the value is exactly ``True``) or in the stuff
                table (any other non-None value).

        Returns:
            dict: a new dict without the "isthing" key and with "category_id"
            converted when "isthing" was present. The input dict is left
            untouched.
        """
        # Shallow copy: the dict belongs to the model output, which the caller
        # (or another evaluator) may still read after this evaluator ran.
        segment_info = dict(segment_info)
        isthing = segment_info.pop("isthing", None)
        if isthing is None:
            # the model produces panoptic category id directly. No more conversion needed
            return segment_info
        # Identity comparison is kept on purpose: only a real ``True`` selects
        # the thing table, everything else falls through to the stuff table.
        if isthing is True:
            id_map = self._thing_contiguous_id_to_dataset_id
        else:
            id_map = self._stuff_contiguous_id_to_dataset_id
        segment_info["category_id"] = id_map[segment_info["category_id"]]
        return segment_info

    def process(self, inputs, outputs):
        """
        Encode the panoptic prediction of every image as PNG bytes and buffer it.

        Args:
            inputs: iterable of per-image dicts; "file_name" and "image_id"
                are read from each.
            outputs: iterable of per-image dicts; "panoptic_seg" must unpack
                into ``(panoptic_img, segments_info)``. ``panoptic_img`` must
                support ``.cpu().numpy()`` (presumably a torch tensor).
        """
        from panopticapi.utils import id2rgb

        for image_input, output in zip(inputs, outputs):
            panoptic_img, segments_info = output["panoptic_seg"]
            panoptic_img = panoptic_img.cpu().numpy()

            file_name = os.path.basename(image_input["file_name"])
            file_name_png = os.path.splitext(file_name)[0] + ".png"
            # Keep the PNG in memory; it is only written to disk in evaluate().
            with io.BytesIO() as out:
                Image.fromarray(id2rgb(panoptic_img)).save(out, format="PNG")
                png_bytes = out.getvalue()

            self._predictions.append(
                {
                    "image_id": image_input["image_id"],
                    "file_name": file_name_png,
                    "png_string": png_bytes,
                    "segments_info": [self._convert_category_id(x) for x in segments_info],
                }
            )

    def evaluate(self):
        """
        Gather predictions from all workers and compute PQ/SQ/RQ.

        Must be called on every worker because it synchronizes and gathers.

        Returns:
            OrderedDict or None: on the main process,
            ``{"panoptic_seg": {...}}`` with PQ/SQ/RQ (overall, ``_th`` for
            things, ``_st`` for stuff) scaled by 100; ``None`` on all other
            processes.
        """
        comm.synchronize()

        self._predictions = comm.gather(self._predictions)
        self._predictions = list(itertools.chain.from_iterable(self._predictions))
        if not comm.is_main_process():
            return

        gt_json = PathManager.get_local_path(self._metadata.panoptic_json)
        gt_folder = self._metadata.panoptic_root

        # Ensure the output directory exists before predictions.json is
        # written; otherwise a fresh output_dir makes the write below fail.
        predictions_dir = os.path.dirname(self._predictions_json)
        if predictions_dir:
            PathManager.mkdirs(predictions_dir)

        with tempfile.TemporaryDirectory(prefix="panoptic_eval") as pred_dir:
            logger.info("Writing all panoptic predictions to %s ...", pred_dir)
            for p in self._predictions:
                with open(os.path.join(pred_dir, p["file_name"]), "wb") as f:
                    # pop() also strips the raw bytes so that the remaining
                    # dict can be serialized to JSON below.
                    f.write(p.pop("png_string"))

            # Reuse the ground-truth JSON as a template and swap in our
            # predictions as its "annotations".
            with open(gt_json, "r") as f:
                json_data = json.load(f)
            json_data["annotations"] = self._predictions
            with PathManager.open(self._predictions_json, "w") as f:
                f.write(json.dumps(json_data))

            from panopticapi.evaluation import pq_compute

            # Silence whatever pq_compute prints; results are logged afterwards.
            with contextlib.redirect_stdout(io.StringIO()):
                pq_res = pq_compute(
                    gt_json,
                    PathManager.get_local_path(self._predictions_json),
                    gt_folder=gt_folder,
                    pred_folder=pred_dir,
                )

        res = {}
        for suffix, group in (("", "All"), ("_th", "Things"), ("_st", "Stuff")):
            for metric in ("PQ", "SQ", "RQ"):
                res[metric + suffix] = 100 * pq_res[group][metric.lower()]

        results = OrderedDict({"panoptic_seg": res})
        _print_panoptic_results(pq_res)

        return results
-
-
def _print_panoptic_results(pq_res):
    """
    Log a table of PQ/SQ/RQ (scaled by 100) and the category count.

    Args:
        pq_res (dict): must contain the keys "All", "Things" and "Stuff",
            each mapping to a dict with "pq", "sq", "rq" and "n".
    """
    metric_keys = ("pq", "sq", "rq")
    rows = []
    for group in ("All", "Things", "Stuff"):
        stats = pq_res[group]
        percentages = [stats[key] * 100 for key in metric_keys]
        rows.append([group, *percentages, stats["n"]])

    table = tabulate(
        rows,
        headers=["", "PQ", "SQ", "RQ", "#categories"],
        tablefmt="pipe",
        floatfmt=".3f",
        stralign="center",
        numalign="center",
    )
    message = "Panoptic Evaluation Results:\n" + table
    logger.info(message)
-
-
if __name__ == "__main__":
    # Standalone mode: score prediction files that already exist on disk and
    # log the resulting table.
    from detectron2.utils.logger import setup_logger

    # Rebinds the module-level logger that _print_panoptic_results writes to.
    logger = setup_logger()
    import argparse

    parser = argparse.ArgumentParser()
    for option in ("--gt-json", "--gt-dir", "--pred-json", "--pred-dir"):
        parser.add_argument(option)
    args = parser.parse_args()

    # Imported only after argument parsing, matching the original ordering.
    from panopticapi.evaluation import pq_compute

    # Discard anything pq_compute prints to stdout.
    discarded_stdout = io.StringIO()
    with contextlib.redirect_stdout(discarded_stdout):
        pq_res = pq_compute(
            args.gt_json,
            args.pred_json,
            gt_folder=args.gt_dir,
            pred_folder=args.pred_dir,
        )
    _print_panoptic_results(pq_res)
|