init - 初始化项目

This commit is contained in:
Lee Nony
2022-05-06 01:58:53 +08:00
commit 90a5cc7cb6
6772 changed files with 2837787 additions and 0 deletions

View File

@@ -0,0 +1,41 @@
import numpy as np
from ..accuracy_eval import SemSegmEvaluation
from ..utils import plot_acc
def test_segm_models(models_list, data_fetcher, eval_params, experiment_name, is_print_eval_params=True,
is_plot_acc=True):
if is_print_eval_params:
print(
"===== Running evaluation of the classification models with the following params:\n"
"\t0. val data location: {}\n"
"\t1. val data labels: {}\n"
"\t2. frame size: {}\n"
"\t3. batch size: {}\n"
"\t4. transform to RGB: {}\n"
"\t5. log file location: {}\n".format(
eval_params.imgs_segm_dir,
eval_params.img_cls_file,
eval_params.frame_size,
eval_params.batch_size,
eval_params.bgr_to_rgb,
eval_params.log
)
)
accuracy_evaluator = SemSegmEvaluation(eval_params.log, eval_params.img_cls_file, eval_params.batch_size)
accuracy_evaluator.process(models_list, data_fetcher)
accuracy_array = np.array(accuracy_evaluator.general_fw_accuracy)
print(
"===== End of processing. Accuracy results:\n"
"\t1. max accuracy (top-5) for the original model: {}\n"
"\t2. max accuracy (top-5) for the DNN model: {}\n".format(
max(accuracy_array[:, 0]),
max(accuracy_array[:, 1]),
)
)
if is_plot_acc:
plot_acc(accuracy_array, experiment_name)

View File

@@ -0,0 +1,59 @@
from torchvision import models
from ..pytorch_model import (
PyTorchModelPreparer,
PyTorchModelProcessor,
PyTorchDnnModelProcessor
)
from ...common.utils import set_pytorch_env, create_parser
class PyTorchFcnResNet50(PyTorchModelPreparer):
def __init__(self, model_name, original_model):
super(PyTorchFcnResNet50, self).__init__(model_name, original_model)
def main():
parser = create_parser()
cmd_args = parser.parse_args()
set_pytorch_env()
# Test the base process of model retrieval
resnets = PyTorchFcnResNet50(
model_name="resnet50",
original_model=models.segmentation.fcn_resnet50(pretrained=True)
)
model_dict = resnets.get_prepared_models()
if cmd_args.is_evaluate:
from ...common.test_config import TestConfig
from ...common.accuracy_eval import PASCALDataFetch
from ...common.test.voc_segm_test import test_segm_models
eval_params = TestConfig()
model_names = list(model_dict.keys())
original_model_name = model_names[0]
dnn_model_name = model_names[1]
#img_dir, segm_dir, names_file, segm_cls_colors_file, preproc)
data_fetcher = PASCALDataFetch(
imgs_dir=eval_params.imgs_segm_dir,
frame_size=eval_params.frame_size,
bgr_to_rgb=eval_params.bgr_to_rgb,
)
test_segm_models(
[
PyTorchModelProcessor(model_dict[original_model_name], original_model_name),
PyTorchDnnModelProcessor(model_dict[dnn_model_name], dnn_model_name)
],
data_fetcher,
eval_params,
original_model_name
)
if __name__ == "__main__":
main()