This commit is contained in:
2022-11-18 14:31:42 +08:00
6 changed files with 136 additions and 21 deletions

View File

@ -566,11 +566,12 @@ def run(
return f # return list of exported files/dirs
def parse_opt(weights,device):
def parse_opt(weights,device,imgsz):
imgsz = [imgsz,imgsz]
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
parser.add_argument('--weights', nargs='+', type=str, default=weights, help='model.pt path(s)')
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=imgsz, help='image (h, w)') #default=[640, 640]
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--device', default=device, help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
@ -604,13 +605,13 @@ def main(opt):
f = run(**vars(opt))
return f
def Start_Model_Export(weights,device):
def Start_Model_Export(weights,device,imgsz):
# 判断cpu or gpu
if device == 'gpu':
device = '0'
else:
device = 'cpu'
opt = parse_opt(weights,device)
opt = parse_opt(weights,device,imgsz)
f = main(opt)
return f

View File

@ -61,6 +61,8 @@ from app.yolov5.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel,
smart_resume, torch_distributed_zero_first)
from app.schemas.TrainResult import Report, ProcessValueList
from app.controller.AlgorithmController import algorithm_process_value_websocket
from app.configs import global_var
from app.utils.websocket_tool import manager
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
@ -72,7 +74,7 @@ def yaml_rewrite(file='data.yaml',data_list=[]):
with open(file, errors='ignore') as f:
coco_dict = yaml.safe_load(f)
#读取img_label_type.json
with open(data_list[3], 'r') as f:
with open(data_list[3], 'r',encoding='UTF-8') as f:
class_dict = json.load(f)
f.close()
classes = class_dict["classes"]
@ -302,7 +304,17 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
report = Report(rate_of_progess=0, precision=[process_value_list],
id=id, sum=epochs, progress=0,
num_train_img=train_num,
train_mod_savepath=best)
train_mod_savepath=best,
alg_code="R-ODY")
def kill_return():
"""
算法中断,返回
"""
id = report.id
data = report.dict()
data_res = {'code': 1, "type": 'kill', 'msg': 'fail', 'data': data}
manager.send_message_proj_json(message=data_res, id=id)
@algorithm_process_value_websocket()
def report_cellback(i, num_epochs, reportAccu):
@ -314,6 +326,12 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
###################结束#######################
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
#callbacks.run('on_train_epoch_start')
print("start get global_var")
ifkill = global_var.get_value(report.id)
print("get global_var down:",ifkill)
if ifkill:
kill_return()
break
model.train()
# Update image weights (optional, single-GPU only)
@ -334,6 +352,12 @@ def train(hyp, opt, device, data_list,id,callbacks): # hyp is path/to/hyp.yaml
optimizer.zero_grad()
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
#callbacks.run('on_train_batch_start')
print("start get global_var")
ifkill = global_var.get_value(report.id)
print("get global_var down:",ifkill)
if ifkill:
kill_return()
break
if targets.shape[0] == 0:
targets = [[0.00000, 5.00000, 0.97002, 0.24679, 0.05995, 0.05553],
[0.00000, 7.00000, 0.95097, 0.32007, 0.04188, 0.02549],

View File

@ -430,9 +430,12 @@ class LoadImagesAndLabels(Dataset):
self.label_files = img2label_paths(self.im_files) # labels
cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
try:
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
assert cache['version'] == self.cache_version # matches current version
assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical has
# cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
# assert cache['version'] == self.cache_version # matches current version
# assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical has
if os.path.exists(cache_path):
os.remove(cache_path)
cache, exists = self.cache_labels(cache_path, prefix), False # run cache ops
except Exception:
cache, exists = self.cache_labels(cache_path, prefix), False # run cache ops