init - 初始化项目
This commit is contained in:
@@ -0,0 +1,60 @@
|
||||
from .configs.test_config import TestClsConfig, TestClsModuleConfig
|
||||
from .model_test_pipeline import ModelTestPipeline
|
||||
from ..evaluation.classification.cls_accuracy_evaluator import ClsAccEvaluation
|
||||
from ..utils import get_test_module
|
||||
|
||||
|
||||
class ClsModelTestPipeline(ModelTestPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
network_model,
|
||||
model_processor,
|
||||
dnn_model_processor,
|
||||
data_fetcher,
|
||||
img_processor=None,
|
||||
cls_args_parser=None,
|
||||
default_input_blob_preproc=None
|
||||
):
|
||||
super(ClsModelTestPipeline, self).__init__(
|
||||
network_model,
|
||||
model_processor,
|
||||
dnn_model_processor
|
||||
)
|
||||
|
||||
if cls_args_parser:
|
||||
self._parser = cls_args_parser
|
||||
|
||||
self.test_config = TestClsConfig()
|
||||
|
||||
parser_args = self._parser.parse_args()
|
||||
|
||||
if parser_args.test:
|
||||
self._test_module_config = TestClsModuleConfig()
|
||||
self._test_module = get_test_module(
|
||||
self._test_module_config.test_module_name,
|
||||
self._test_module_config.test_module_path
|
||||
)
|
||||
|
||||
if parser_args.default_img_preprocess:
|
||||
self._default_input_blob_preproc = default_input_blob_preproc
|
||||
if parser_args.evaluate:
|
||||
self._data_fetcher = data_fetcher(self.test_config, img_processor)
|
||||
|
||||
def _configure_test_module_params(self):
|
||||
self._test_module_param_list.extend((
|
||||
'--crop', self._test_module_config.crop,
|
||||
'--std', *self._test_module_config.std
|
||||
))
|
||||
|
||||
if self._test_module_config.rsz_height and self._test_module_config.rsz_width:
|
||||
self._test_module_param_list.extend((
|
||||
'--initial_height', self._test_module_config.rsz_height,
|
||||
'--initial_width', self._test_module_config.rsz_width,
|
||||
))
|
||||
|
||||
def _configure_acc_eval(self, log_path):
|
||||
self._accuracy_evaluator = ClsAccEvaluation(
|
||||
log_path,
|
||||
self.test_config.img_cls_file,
|
||||
self.test_config.batch_size
|
||||
)
|
||||
@@ -0,0 +1,37 @@
|
||||
BASE_IMG_SCALE_FACTOR = 1 / 255.0
|
||||
PYTORCH_RSZ_HEIGHT = 256
|
||||
PYTORCH_RSZ_WIDTH = 256
|
||||
|
||||
pytorch_resize_input_blob = {
|
||||
"mean": ["123.675", "116.28", "103.53"],
|
||||
"scale": str(BASE_IMG_SCALE_FACTOR),
|
||||
"std": ["0.229", "0.224", "0.225"],
|
||||
"crop": "True",
|
||||
"rgb": True,
|
||||
"rsz_height": str(PYTORCH_RSZ_HEIGHT),
|
||||
"rsz_width": str(PYTORCH_RSZ_WIDTH)
|
||||
}
|
||||
|
||||
pytorch_input_blob = {
|
||||
"mean": ["123.675", "116.28", "103.53"],
|
||||
"scale": str(BASE_IMG_SCALE_FACTOR),
|
||||
"std": ["0.229", "0.224", "0.225"],
|
||||
"crop": "True",
|
||||
"rgb": True
|
||||
}
|
||||
|
||||
tf_input_blob = {
|
||||
"scale": str(1 / 127.5),
|
||||
"mean": ["127.5", "127.5", "127.5"],
|
||||
"std": [],
|
||||
"crop": "True",
|
||||
"rgb": True
|
||||
}
|
||||
|
||||
tf_model_blob_caffe_mode = {
|
||||
"mean": ["103.939", "116.779", "123.68"],
|
||||
"scale": "1.0",
|
||||
"std": [],
|
||||
"crop": "True",
|
||||
"rgb": False
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommonConfig:
|
||||
output_data_root_dir: str = "dnn_model_runner/dnn_conversion"
|
||||
logs_dir: str = os.path.join(output_data_root_dir, "logs")
|
||||
log_file_path: str = os.path.join(logs_dir, "{}_log.txt")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestClsConfig:
|
||||
batch_size: int = 1
|
||||
frame_size: int = 224
|
||||
img_root_dir: str = "./ILSVRC2012_img_val"
|
||||
# location of image-class matching
|
||||
img_cls_file: str = "./val.txt"
|
||||
bgr_to_rgb: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestClsModuleConfig:
|
||||
cls_test_data_dir: str = "../data"
|
||||
test_module_name: str = "classification"
|
||||
test_module_path: str = "classification.py"
|
||||
input_img: str = os.path.join(cls_test_data_dir, "squirrel_cls.jpg")
|
||||
model: str = ""
|
||||
|
||||
frame_height: str = str(TestClsConfig.frame_size)
|
||||
frame_width: str = str(TestClsConfig.frame_size)
|
||||
scale: str = "1.0"
|
||||
mean: List[str] = field(default_factory=lambda: ["0.0", "0.0", "0.0"])
|
||||
std: List[str] = field(default_factory=list)
|
||||
crop: str = "False"
|
||||
rgb: str = "True"
|
||||
rsz_height: str = ""
|
||||
rsz_width: str = ""
|
||||
classes: str = os.path.join(cls_test_data_dir, "dnn", "classification_classes_ILSVRC2012.txt")
|
||||
@@ -0,0 +1,126 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .configs.test_config import CommonConfig
|
||||
from ..utils import create_parser, plot_acc
|
||||
|
||||
|
||||
class ModelTestPipeline:
|
||||
def __init__(
|
||||
self,
|
||||
network_model,
|
||||
model_processor,
|
||||
dnn_model_processor
|
||||
):
|
||||
self._net_model = network_model
|
||||
self._model_processor = model_processor
|
||||
self._dnn_model_processor = dnn_model_processor
|
||||
|
||||
self._parser = create_parser()
|
||||
|
||||
self._test_module = None
|
||||
self._test_module_config = None
|
||||
self._test_module_param_list = None
|
||||
|
||||
self.test_config = None
|
||||
self._data_fetcher = None
|
||||
|
||||
self._default_input_blob_preproc = None
|
||||
self._accuracy_evaluator = None
|
||||
|
||||
def init_test_pipeline(self):
|
||||
cmd_args = self._parser.parse_args()
|
||||
model_dict = self._net_model.get_prepared_models()
|
||||
|
||||
model_names = list(model_dict.keys())
|
||||
print(
|
||||
"The model {} was successfully obtained and converted to OpenCV {}".format(model_names[0], model_names[1])
|
||||
)
|
||||
|
||||
if cmd_args.test:
|
||||
if not self._test_module_config.model:
|
||||
self._test_module_config.model = self._net_model.model_path["full_path"]
|
||||
|
||||
if cmd_args.default_img_preprocess:
|
||||
self._test_module_config.scale = self._default_input_blob_preproc["scale"]
|
||||
self._test_module_config.mean = self._default_input_blob_preproc["mean"]
|
||||
self._test_module_config.std = self._default_input_blob_preproc["std"]
|
||||
self._test_module_config.crop = self._default_input_blob_preproc["crop"]
|
||||
|
||||
if "rsz_height" in self._default_input_blob_preproc and "rsz_width" in self._default_input_blob_preproc:
|
||||
self._test_module_config.rsz_height = self._default_input_blob_preproc["rsz_height"]
|
||||
self._test_module_config.rsz_width = self._default_input_blob_preproc["rsz_width"]
|
||||
|
||||
self._test_module_param_list = [
|
||||
'--model', self._test_module_config.model,
|
||||
'--input', self._test_module_config.input_img,
|
||||
'--width', self._test_module_config.frame_width,
|
||||
'--height', self._test_module_config.frame_height,
|
||||
'--scale', self._test_module_config.scale,
|
||||
'--mean', *self._test_module_config.mean,
|
||||
'--std', *self._test_module_config.std,
|
||||
'--classes', self._test_module_config.classes,
|
||||
]
|
||||
|
||||
if self._default_input_blob_preproc["rgb"]:
|
||||
self._test_module_param_list.append('--rgb')
|
||||
|
||||
self._configure_test_module_params()
|
||||
|
||||
self._test_module.main(
|
||||
self._test_module_param_list
|
||||
)
|
||||
|
||||
if cmd_args.evaluate:
|
||||
original_model_name = model_names[0]
|
||||
dnn_model_name = model_names[1]
|
||||
|
||||
self.run_test_pipeline(
|
||||
[
|
||||
self._model_processor(model_dict[original_model_name], original_model_name),
|
||||
self._dnn_model_processor(model_dict[dnn_model_name], dnn_model_name)
|
||||
],
|
||||
original_model_name.replace(" ", "_")
|
||||
)
|
||||
|
||||
def run_test_pipeline(
|
||||
self,
|
||||
models_list,
|
||||
formatted_exp_name,
|
||||
is_plot_acc=True
|
||||
):
|
||||
log_path, logs_dir = self._configure_eval_log(formatted_exp_name)
|
||||
|
||||
print(
|
||||
"===== Running evaluation of the model with the following params:\n"
|
||||
"\t* val data location: {}\n"
|
||||
"\t* log file location: {}\n".format(
|
||||
self.test_config.img_root_dir,
|
||||
log_path
|
||||
)
|
||||
)
|
||||
|
||||
os.makedirs(logs_dir, exist_ok=True)
|
||||
|
||||
self._configure_acc_eval(log_path)
|
||||
self._accuracy_evaluator.process(models_list, self._data_fetcher)
|
||||
|
||||
if is_plot_acc:
|
||||
plot_acc(
|
||||
np.array(self._accuracy_evaluator.general_inference_time),
|
||||
formatted_exp_name
|
||||
)
|
||||
|
||||
print("===== End of the evaluation pipeline =====")
|
||||
|
||||
def _configure_acc_eval(self, log_path):
|
||||
pass
|
||||
|
||||
def _configure_test_module_params(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _configure_eval_log(formatted_exp_name):
|
||||
common_test_config = CommonConfig()
|
||||
return common_test_config.log_file_path.format(formatted_exp_name), common_test_config.logs_dir
|
||||
Reference in New Issue
Block a user