补全yolov5的代码
This commit is contained in:
@ -17,7 +17,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from utils.general import LOGGER, check_version, colorstr, file_date, git_describe
|
||||
from yolov5.utils.general import LOGGER, check_version, colorstr, file_date, git_describe
|
||||
|
||||
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||
RANK = int(os.getenv("RANK", -1))
|
||||
@ -68,7 +68,7 @@ def smart_DDP(model):
|
||||
|
||||
def reshape_classifier_output(model, n=1000):
|
||||
"""Reshapes last layer of model to match class count 'n', supporting Classify, Linear, Sequential types."""
|
||||
from models.common import Classify
|
||||
from yolov5.models.common import Classify
|
||||
|
||||
name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module
|
||||
if isinstance(m, Classify): # YOLOv5 Classify() head
|
||||
|
Reference in New Issue
Block a user