Return the intermediate value every epoch
parent bb163578d6
commit 5d3b268958
@@ -356,6 +356,7 @@ def train(hyp, opt, device, data_list,id,getsomething,callbacks): # hyp is path
         if RANK in {-1, 0}:
             pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')  # progress bar
         optimizer.zero_grad()
+        tempLoss = 0
         for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
             #callbacks.run('on_train_batch_start')
             print("start get global_var")
@@ -406,7 +407,8 @@ def train(hyp, opt, device, data_list,id,getsomething,callbacks): # hyp is path
 
             # Backward
             scaler.scale(loss).backward()
-            report_cellback(epoch, epochs, float(loss))
+            tempLoss = float(loss)
+
 
             # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
             if ni - last_opt_step >= accumulate:
@@ -429,7 +431,7 @@ def train(hyp, opt, device, data_list,id,getsomething,callbacks): # hyp is path
                 if callbacks.stop_training:
                     return
         # end batch ------------------------------------------------------------------------------------------------
-
+        report_cellback(epoch, epochs, tempLoss)
         # Scheduler
         lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
         scheduler.step()
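For context, a minimal sketch of the reporting pattern this commit moves to: the latest batch loss is stashed in tempLoss inside the batch loop, and report_cellback fires once per epoch after the loop instead of once per batch. Only the (epoch, epochs, loss) call shape comes from the diff; the report_cellback body and the dummy loss below are placeholders for illustration, not the project's actual code.

def report_cellback(epoch, epochs, loss_value):
    # Placeholder body: in the real script this hands the intermediate value
    # back to the caller; only the (epoch, epochs, loss) signature is from the diff.
    print(f"epoch {epoch + 1}/{epochs}  loss={loss_value:.4f}")

def run(epochs=3, batches_per_epoch=4):
    for epoch in range(epochs):
        tempLoss = 0  # holds the most recent batch loss of this epoch
        for i in range(batches_per_epoch):
            loss = 1.0 / (epoch * batches_per_epoch + i + 1)  # dummy loss standing in for the model's loss
            tempLoss = float(loss)  # keep only the latest value; no per-batch report any more
        report_cellback(epoch, epochs, tempLoss)  # one report per epoch, not per batch

run()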