补全yolov5的代码

This commit is contained in:
2025-03-13 17:48:59 +08:00
parent 6daea23f0a
commit 9d369b9898
42 changed files with 8571 additions and 171 deletions

View File

@ -17,7 +17,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from utils.general import LOGGER, check_version, colorstr, file_date, git_describe
from yolov5.utils.general import LOGGER, check_version, colorstr, file_date, git_describe
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv("RANK", -1))
@ -68,7 +68,7 @@ def smart_DDP(model):
def reshape_classifier_output(model, n=1000):
"""Reshapes last layer of model to match class count 'n', supporting Classify, Linear, Sequential types."""
from models.common import Classify
from yolov5.models.common import Classify
name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module
if isinstance(m, Classify): # YOLOv5 Classify() head