init - 初始化项目
This commit is contained in:
@@ -0,0 +1,41 @@
|
||||
import numpy as np
|
||||
|
||||
from ..accuracy_eval import SemSegmEvaluation
|
||||
from ..utils import plot_acc
|
||||
|
||||
|
||||
def test_segm_models(models_list, data_fetcher, eval_params, experiment_name, is_print_eval_params=True,
|
||||
is_plot_acc=True):
|
||||
if is_print_eval_params:
|
||||
print(
|
||||
"===== Running evaluation of the classification models with the following params:\n"
|
||||
"\t0. val data location: {}\n"
|
||||
"\t1. val data labels: {}\n"
|
||||
"\t2. frame size: {}\n"
|
||||
"\t3. batch size: {}\n"
|
||||
"\t4. transform to RGB: {}\n"
|
||||
"\t5. log file location: {}\n".format(
|
||||
eval_params.imgs_segm_dir,
|
||||
eval_params.img_cls_file,
|
||||
eval_params.frame_size,
|
||||
eval_params.batch_size,
|
||||
eval_params.bgr_to_rgb,
|
||||
eval_params.log
|
||||
)
|
||||
)
|
||||
|
||||
accuracy_evaluator = SemSegmEvaluation(eval_params.log, eval_params.img_cls_file, eval_params.batch_size)
|
||||
accuracy_evaluator.process(models_list, data_fetcher)
|
||||
accuracy_array = np.array(accuracy_evaluator.general_fw_accuracy)
|
||||
|
||||
print(
|
||||
"===== End of processing. Accuracy results:\n"
|
||||
"\t1. max accuracy (top-5) for the original model: {}\n"
|
||||
"\t2. max accuracy (top-5) for the DNN model: {}\n".format(
|
||||
max(accuracy_array[:, 0]),
|
||||
max(accuracy_array[:, 1]),
|
||||
)
|
||||
)
|
||||
|
||||
if is_plot_acc:
|
||||
plot_acc(accuracy_array, experiment_name)
|
||||
@@ -0,0 +1,59 @@
|
||||
from torchvision import models
|
||||
|
||||
from ..pytorch_model import (
|
||||
PyTorchModelPreparer,
|
||||
PyTorchModelProcessor,
|
||||
PyTorchDnnModelProcessor
|
||||
)
|
||||
from ...common.utils import set_pytorch_env, create_parser
|
||||
|
||||
|
||||
class PyTorchFcnResNet50(PyTorchModelPreparer):
|
||||
def __init__(self, model_name, original_model):
|
||||
super(PyTorchFcnResNet50, self).__init__(model_name, original_model)
|
||||
|
||||
|
||||
def main():
|
||||
parser = create_parser()
|
||||
cmd_args = parser.parse_args()
|
||||
set_pytorch_env()
|
||||
|
||||
# Test the base process of model retrieval
|
||||
resnets = PyTorchFcnResNet50(
|
||||
model_name="resnet50",
|
||||
original_model=models.segmentation.fcn_resnet50(pretrained=True)
|
||||
)
|
||||
model_dict = resnets.get_prepared_models()
|
||||
|
||||
if cmd_args.is_evaluate:
|
||||
from ...common.test_config import TestConfig
|
||||
from ...common.accuracy_eval import PASCALDataFetch
|
||||
from ...common.test.voc_segm_test import test_segm_models
|
||||
|
||||
eval_params = TestConfig()
|
||||
|
||||
model_names = list(model_dict.keys())
|
||||
original_model_name = model_names[0]
|
||||
dnn_model_name = model_names[1]
|
||||
|
||||
#img_dir, segm_dir, names_file, segm_cls_colors_file, preproc)
|
||||
data_fetcher = PASCALDataFetch(
|
||||
imgs_dir=eval_params.imgs_segm_dir,
|
||||
frame_size=eval_params.frame_size,
|
||||
bgr_to_rgb=eval_params.bgr_to_rgb,
|
||||
|
||||
)
|
||||
|
||||
test_segm_models(
|
||||
[
|
||||
PyTorchModelProcessor(model_dict[original_model_name], original_model_name),
|
||||
PyTorchDnnModelProcessor(model_dict[dnn_model_name], dnn_model_name)
|
||||
],
|
||||
data_fetcher,
|
||||
eval_params,
|
||||
original_model_name
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user