RODY/setparams.py
552068321@qq.com 6f7de660aa first commit
2022-11-04 17:37:08 +08:00

47 lines
1.6 KiB
Python

from SetParams.DataType.TypeDef import *
from SetParams.DataType.BaseParam import *
class TrainParams(BaseParam):
    """Example set of training parameters.

    Adapt the parameters below to the actual training setup. If other
    parameter groups are needed (for example, arguments for an inference
    function), define them in a separate class that also inherits from
    BaseParam.

    Example -- to expose the following parameters during training::

        gpu_num = IntType("gpu_num", 2)
        support_cpu = BoolType("support_cpu", True)
        labels = EnumType("labels", 0, ["dog", "cat"])
        labels.default = 0
        self.add_param(gpu_num)
        self.add_param(support_cpu)
        self.add_param(labels)
    """

    def __init__(self):
        super().__init__()
        # Construct every parameter first, then register them one by one.
        # The tuple order is the registration order passed to add_param.
        declared = (
            IntType("num_classes", default=9),
            FloatType("lr", default=0.005),
            # NOTE(review): presumably LR-scheduler epoch milestones -- confirm.
            ListType("lr_schedulerList", default=[30, 60]),
            StringType("device", default="cpu"),
            StringType("DatasetDir", default="./datasets/M006B_duanmian"),
            StringType("saveModDir", default="./saved_model/M006B_duanmian.pt"),
            # Empty string is the default value; how it is interpreted is
            # up to the consumer of these parameters.
            StringType("resumeModPath", default=""),
            IntType("epochnum", default=100),
            IntType("saveEpoch", default=1),
        )
        for entry in declared:
            self.add_param(entry)
if __name__ == "__main__":
    # Build the parameter set with its default values and save it to the
    # path below (path is relative to the current working directory).
    TrainParams().save_to_file('./SetParams/TrainParams.json')