from SetParams.DataType.TypeDef import *
from SetParams.DataType.BaseParam import *


class TrainParams(BaseParam):
    """Example set of training parameters, registered through BaseParam.

    Adapt this class to the actual training setup before use. Any other
    parameter group (for example, the arguments of an inference function)
    should be defined as its own class that also inherits from BaseParam.

    Example -- if training needed the following parameters::

        gpu_num = IntType("gpu_num", 2)
        support_cpu = BoolType("support_cpu", True)
        labels = EnumType("labels", 0, ["dog", "cat"])
        labels.default = 0
        self.add_param(gpu_num)
        self.add_param(support_cpu)
        self.add_param(labels)
    """

    def __init__(self):
        super().__init__()
        # All parameter objects are built first, then registered one by one
        # in this exact order (the registration order is preserved on purpose;
        # it may determine the order in which save_to_file writes them).
        declared = (
            IntType("num_classes", default=9),
            FloatType("lr", default=0.005),
            # presumably learning-rate schedule milestones (epochs) -- confirm
            ListType("lr_schedulerList", default=[30, 60]),
            StringType("device", default="cpu"),
            StringType("DatasetDir", default="./datasets/M006B_duanmian"),
            # NOTE: despite the "Dir" name, the default is a .pt file path
            StringType("saveModDir", default="./saved_model/M006B_duanmian.pt"),
            # empty string by default; presumably means "do not resume"
            StringType("resumeModPath", default=""),
            IntType("epochnum", default=100),
            IntType("saveEpoch", default=1),
        )
        for param in declared:
            self.add_param(param)


if __name__ == "__main__":
    # Dump the default training parameters to JSON (path is relative to the
    # current working directory).
    params = TrainParams()
    params.save_to_file('./SetParams/TrainParams.json')