# RODY/SetParams_Demo.py
# (66 lines, 2.8 KiB, Python; last changed 2022-11-04 17:37:08 +08:00)
import sys
import json
from setparams import TrainParams
def ppx_train(
    params,
    id,
    model=None,
    train_dataset=None,
    eval_dataset=None,
    default_pretrain_weights="/mnt/sdc/algorithm/AICheck-MaskRCNN/app/maskrcnn_ppx/pretrain/mask_rcnn_r50_fpn_2x_coco.pdparams",
):
    """Read training hyper-parameters from ``params`` and call ``model.train``.

    Args:
        params: Parameter container exposing ``get(name).value`` (a
            ``setparams.TrainParams`` in this file). Keys read: num_classes,
            epochnum, saveEpoch, device, DatasetDir, saveModDir, lr,
            lr_schedulerList, resumeModPath.
        id: Run identifier. Kept for interface compatibility (callers pass
            ``id='1'``) but not used. It shadows the builtin ``id``.
        model: Object whose ``train(**kwargs)`` method runs training. If None,
            a module-level ``model`` is used, as the original template expected.
        train_dataset: Passed through to ``model.train``. If None, a
            module-level ``train_dataset`` is used.
        eval_dataset: Passed through to ``model.train``. If None, a
            module-level ``eval_dataset`` is used.
        default_pretrain_weights: Weights path used when ``resumeModPath`` is
            the empty string. Defaults to the previously hard-coded COCO
            Mask R-CNN weights file.

    Raises:
        NameError: If ``model``, ``train_dataset`` or ``eval_dataset`` is
            neither passed in nor defined at module level.
    """
    module_namespace = globals()

    def _resolve(name, given):
        # An explicit argument wins. Otherwise fall back to the module-level
        # object of the same name, so callers that only pass (params, id) and
        # define these globals elsewhere keep working unchanged.
        if given is not None:
            return given
        if name in module_namespace:
            return module_namespace[name]
        raise NameError(
            f"ppx_train: {name!r} was not passed in and no module-level "
            f"{name!r} is defined"
        )

    # NOTE(review): num_classes, device and DatasetDir are read but not used
    # below; presumably they are meant for building the model/datasets. TODO
    # confirm. They are still read so a missing key fails here as before.
    ppx_num_classes = params.get('num_classes').value
    ppx_epoch = params.get('epochnum').value
    ppx_saveEpoch = params.get('saveEpoch').value
    ppx_device = params.get('device').value
    ppx_DatasetDir = params.get('DatasetDir').value
    ppx_saveModDir = params.get('saveModDir').value
    ppx_lr = params.get('lr').value
    ppx_lr_schedulerList = params.get('lr_schedulerList').value
    ppx_resumeModPath = params.get('resumeModPath').value

    # Only the empty string means "no checkpoint configured". Any other value,
    # including None, is passed through unchanged.
    if ppx_resumeModPath == '':
        ppx_resumeModPath = default_pretrain_weights  # 'COCO'

    trainer = _resolve('model', model)
    train_data = _resolve('train_dataset', train_dataset)
    eval_data = _resolve('eval_dataset', eval_dataset)

    # Arguments marked #*** are driven by params; the others are fixed.
    trainer.train(
        num_epochs=ppx_epoch,  #***
        save_interval_epochs=ppx_saveEpoch,  #***
        train_dataset=train_data,  #***
        train_batch_size=2,
        eval_dataset=eval_data,  #***
        pretrain_weights=ppx_resumeModPath,  #***
        learning_rate=ppx_lr,  #***
        lr_decay_epochs=ppx_lr_schedulerList,  #***
        warmup_steps=10,
        warmup_start_lr=0.0,
        save_dir=ppx_saveModDir,  #***
        use_vdl=True)
#@start_train_algorithm
def main(params_str):
    """Deserialize training parameters from ``params_str`` and start training.

    Args:
        params_str: Serialized parameter list handed to
            ``TrainParams.read_from_str`` (the ``__main__`` demo passes a JSON
            array of parameter records).
    """
    loaded_params = TrainParams()
    loaded_params.read_from_str(params_str)
    # This entry point always runs with the fixed run id '1'.
    ppx_train(loaded_params, id='1')
if __name__ == "__main__":
    # Demo: describe each training parameter as a record, serialize the list to
    # JSON, echo it, and feed it through main() as an external caller would.
    # Every row lists its fields in the order of `field_names`; dict(zip(...))
    # keeps that order, so json.dumps emits the keys in exactly this order.
    field_names = ("index", "name", "value", "description", "default", "type", "show")
    rows = [
        # number of classes (including background)
        (0, "num_classes", 9, '类别数(加背景)', 9, "I", True),
        # learning rate (demo value differs from the default)
        (1, "lr", 0.0003, '学习率', 0.0001, "F", True),
        # epochs at which the learning rate decays
        (2, "lr_schedulerList", [30, 60], '学习率衰减轮次', [30, 60], "L", True),
        # training device
        (3, "device", "cpu", '训练核心', "cpu", "S", True),
        # dataset path
        (4, "DatasetDir", "/mnt/sdc/algorithm/PaddleX/datasets/DDX_nb", '数据集路径',
         "/mnt/sdc/algorithm/PaddleX/datasets/DDX_nb", "S", False),
        # model save path
        (5, "saveModDir", "/mnt/sdc/algorithm/PaddleX/output", '保存模型路径',
         "/mnt/sdc/algorithm/PaddleX/output", "S", False),
        # path to resume training from ('' selects ppx_train's default weights)
        (6, "resumeModPath", '', '继续训练路径', '', "S", False),
        # number of training epochs
        (7, "epochnum", 100, '训练轮次', 100, "I", True),
        # epoch interval between model saves
        (8, "saveEpoch", 2, '保存模型轮次', 2, "I", True),
    ]
    params_list = [dict(zip(field_names, row)) for row in rows]
    params_str = json.dumps(params_list)
    print(params_str)
    main(params_str)