目录
--tools--train_net.py PS:这个是训练的主程序
参数输入部分
参数修改部分
获取数据和训练网络
def parse_args():
    """Parse input arguments.

    Returns:
        The namespace produced by ``parser.parse_args()``. Of the options
        visible in this excerpt, it carries ``gpu_id`` (default 0).

    Exits:
        Prints the usage help and exits with status 1 when the script is
        started without any command-line arguments.
    """
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    # This is the user-facing interface: '--gpu' is the keyword typed on the
    # command line, 'gpu_id' (dest) is the keyword used inside the program.
    # NOTE(review): the original excerpt is cut off right after `type=`;
    # `int` is assumed from default=0 -- confirm against the full script.
    parser.add_argument('--gpu', dest='gpu_id',
                        help='GPU device id to use [0]',
                        default=0, type=int)
    # ... the remaining parser.add_argument(...) calls are omitted in the
    # original excerpt ...

    # No arguments at all: show the help text and stop.
    # NOTE(review): the excerpt had a garbled `it(1)` here; restored as
    # sys.exit(1) -- confirm against the full script.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    return args
# Two ways to change the configuration: import a prepared cfg_file, or
# assign individual values directly.
cfg_path = args.cfg_file
if cfg_path is not None:
    cfg_from_file(cfg_path)

cfg_overrides = args.set_cfgs
if cfg_overrides is not None:
    cfg_from_list(cfg_overrides)

cfg.GPU_ID = args.gpu_id
cfg_file文件:
EXP_DIR: faster_rcnn_end2end
TRAIN:
  HAS_RPN: True
  IMS_PER_BATCH: 1
  BBOX_NORMALIZE_TARGETS_PRECOMPUTED: True
  RPN_POSITIVE_OVERLAP: 0.7
  RPN_BATCHSIZE: 256
  PROPOSAL_METHOD: gt
  BG_THRESH_LO: 0.0
TEST:
  HAS_RPN: True
文件导入程序cfg_from_file():
def _merge_a_into_b(a, b):
    """Merge config dictionary a into config dictionary b, clobbering the
    options in b whenever they are also specified in a.

    Args:
        a: Dictionary of overrides. If it is not exactly an ``edict`` the
            function returns without doing anything.
        b: Dictionary that is updated in place.

    Raises:
        KeyError: If ``a`` contains a key that ``b`` does not have.
        ValueError: If a value in ``a`` has a different type than the value
            in ``b`` and the value in ``b`` is not a numpy array.
    """
    if type(a) is not edict:
        return

    for k, v in a.items():
        # a must specify keys that are in b
        if k not in b:
            raise KeyError('{} is not a valid config key'.format(k))

        # the types must match, too
        old_type = type(b[k])
        if old_type is not type(v):
            if isinstance(b[k], np.ndarray):
                # Array-valued options may be given as plain sequences;
                # coerce them to the dtype of the existing option.
                v = np.array(v, dtype=b[k].dtype)
            else:
                raise ValueError(('Type mismatch ({} vs. {}) '
                                  'for config key: {}').format(type(b[k]),
                                                               type(v), k))

        # recursively merge dicts
        if type(v) is edict:
            try:
                _merge_a_into_b(a[k], b[k])
            except Exception:
                # Report which key the failure happened under, then let the
                # original error propagate unchanged.
                print('Error under config key: {}'.format(k))
                raise
        else:
            b[k] = v


def cfg_from_file(filename):
    """Load a config file and merge it into the default options.

    Args:
        filename: Path of the YAML config file to read.
    """
    import yaml
    with open(filename, 'r') as f:
        # NOTE(review): this was `yaml.load(f)`. Without an explicit Loader
        # that call can construct arbitrary Python objects and is rejected
        # by recent PyYAML releases; safe_load restricts parsing to plain
        # YAML types. Confirm no config file relies on Python-specific tags.
        yaml_cfg = edict(yaml.safe_load(f))

    _merge_a_into_b(yaml_cfg, __C)
# Fetch the images and ground-truth boxes.
imdb, roidb = combined_roidb(args.imdb_name)
print('{:d} roidb entries'.format(len(roidb)))

# Get the output path.
output_dir = get_output_dir(imdb)
print('Output will be saved to `{:s}`'.format(output_dir))

# Train the network.
train_net(args.solver, roidb, output_dir,
          pretrained_model=args.pretrained_model,
          max_iters=args.max_iters)
def combined_roidb(imdb_names):
    """Build one roidb from one or more datasets.

    Args:
        imdb_names: Dataset name, or several names joined with '+', e.g.
            'voc_2007_trainval+voc_2012_trainval'.

    Returns:
        A tuple ``(imdb, roidb)``. ``roidb`` is the roidb of the first
        dataset, extended in place with the roidbs of any further datasets.
        ``imdb`` comes from ``get_imdb`` for a single name and from
        ``datasets.imdb.imdb(imdb_names)`` when several names are given.
    """
    def get_roidb(imdb_name):
        """Load one dataset and return its training roidb."""
        imdb = get_imdb(imdb_name)
        print('Loaded dataset `{:s}` for training'.format(imdb.name))
        imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
        print('Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD))
        roidb = get_training_roidb(imdb)
        return roidb

    roidbs = [get_roidb(s) for s in imdb_names.split('+')]
    roidb = roidbs[0]
    if len(roidbs) > 1:
        # NOTE(review): the excerpt had a garbled `d(r)` here; restored as
        # roidb.extend(r), which matches how `roidb` is returned below as
        # the combined result -- confirm against the full script.
        for r in roidbs[1:]:
            roidb.extend(r)
        imdb = datasets.imdb.imdb(imdb_names)
    else:
        imdb = get_imdb(imdb_names)
    return imdb, roidb
这里主要涉及图片的读写和接口的写法,可供参考。
本文发布于:2024-01-28 08:00:32,感谢您对本站的认可!
本文链接:https://www.4u4v.net/it/17064000345961.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
留言与评论(共有 0 条评论) |