Complete the migration of the inference module

This commit is contained in:
parent 74e8f0d415
commit b0379e64c9
@@ -12,6 +12,7 @@ from apps.vadmin.record.views import app as vadmin_record_app
 from apps.vadmin.help.views import app as vadmin_help_app
 from apps.business.project.views import app as project_app
 from apps.business.train.views import app as train_app
+from apps.business.detect.views import app as detect_app


 # Register the routes of each application
@@ -23,4 +24,5 @@ urlpatterns = [
     {"ApiRouter": vadmin_help_app, "prefix": "/vadmin/help", "tags": ["Help center management"]},
     {"ApiRouter": project_app, "prefix": "/business/project", "tags": ["Project management"]},
     {"ApiRouter": train_app, "prefix": "/business/train", "tags": ["Training management"]},
+    {"ApiRouter": detect_app, "prefix": "/business/detect", "tags": ["Inference management"]},
 ]
@@ -5,10 +5,18 @@
 # @File : crud.py
 # @IDE : PyCharm
 # @desc : data access layer

 from sqlalchemy.ext.asyncio import AsyncSession
 from core.crud import DalBase
 from . import schemas, models
+from utils.random_utils import random_str
+from utils import os_utils as os
+from application.settings import detect_url
+from utils.huawei_obs import ObsClient
+from utils import status
+from core.exception import CustomException
+
+from fastapi import UploadFile


 class ProjectDetectDal(DalBase):
@@ -17,17 +25,86 @@ class ProjectDetectDal(DalBase):
         super(ProjectDetectDal, self).__init__()
         self.db = db
         self.model = models.ProjectDetect
-        self.schema = schemas.ProjectDetectSimpleOut
+        self.schema = schemas.ProjectDetectOut
+
+    async def check_name(self, name: str, project_id: int) -> bool:
+        """
+        Check whether an inference set with this name already exists in the project.
+        """
+        count = await self.get_count(
+            v_where=[models.ProjectDetect.project_id == project_id, models.ProjectDetect.detect_name == name])
+        return count > 0
+
+    async def add_detect(self, data: schemas.ProjectDetectIn):
+        """
+        Create an inference set.
+        """
+        detect = models.ProjectDetect(**data.model_dump())
+        detect.detect_no = random_str(6)
+        detect.detect_version = 0
+        detect.detect_status = '0'
+        url = os.create_folder(detect_url, detect.detect_no, 'images')
+        detect.folder_url = url
+        await self.create_data(detect)
+        return detect
+
+    async def delete_detects(self, ids: list[int]):
+        """
+        Delete the set records, their folders, and the files of every inference log.
+        """
+        for id_ in ids:
+            detect_info = await self.get_data(data_id=id_)
+            if detect_info.file_type != 'rtsp':
+                os.delete_paths(detect_info.folder_url)
+            logs = await ProjectDetectLogDal(self.db).get_datas(v_where=[models.ProjectDetectLog.detect_id == id_])
+            for log in logs:
+                os.delete_paths(log.folder_url)
+        await self.delete_datas(ids=ids, v_soft=False)


-class ProjectDetectImgDal(DalBase):
+class ProjectDetectFileDal(DalBase):

     def __init__(self, db: AsyncSession):
-        super(ProjectDetectImgDal, self).__init__()
+        super(ProjectDetectFileDal, self).__init__()
         self.db = db
-        self.model = models.ProjectDetectImg
-        self.schema = schemas.ProjectDetectImgSimpleOut
+        self.model = models.ProjectDetectFile
+        self.schema = schemas.ProjectDetectFileOut
+
+    async def file_count(self, detect_id: int) -> int:
+        count = await self.get_count(v_where=[models.ProjectDetectFile.detect_id == detect_id])
+        return count
+
+    async def add_file(self, detect: models.ProjectDetect, files: list[UploadFile]):
+        images = []
+        for file in files:
+            image = models.ProjectDetectFile()
+            image.detect_id = detect.id
+            image.file_name = file.filename
+            # Save the original file locally
+            path = os.save_images(detect.folder_url, file=file)
+            image.file_url = path
+            # Upload to OBS
+            object_key = detect.detect_no + '/' + file.filename
+            success, key, url = ObsClient.put_file(object_key=object_key, file_path=path)
+            if success:
+                image.object_key = object_key
+                image.thumb_file_url = url
+            else:
+                raise CustomException("OBS upload failed", code=status.HTTP_ERROR)
+            images.append(image)
+        await self.create_datas(images)
+
+    async def delete_files(self, ids: list[int]):
+        file_urls = []
+        object_keys = []
+        for id_ in ids:
+            file = await self.get_data(data_id=id_)
+            if file:
+                file_urls.append(file.file_url)
+                object_keys.append(file.object_key)
+        os.delete_paths(file_urls)
+        ObsClient.del_objects(object_keys)
+        await self.delete_datas(ids, v_soft=False)


 class ProjectDetectLogDal(DalBase):

@@ -38,10 +115,10 @@ class ProjectDetectLogDal(DalBase):
         self.schema = schemas.ProjectDetectLogSimpleOut


-class ProjectDetectLogImgDal(DalBase):
+class ProjectDetectLogFileDal(DalBase):

     def __init__(self, db: AsyncSession):
-        super(ProjectDetectLogImgDal, self).__init__()
+        super(ProjectDetectLogFileDal, self).__init__()
         self.db = db
-        self.model = models.ProjectDetectLogImg
-        self.schema = schemas.ProjectDetectLogImgSimpleOut
+        self.model = models.ProjectDetectLogFile
+        self.schema = schemas.ProjectDetectLogFileOut
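Note: a minimal usage sketch of the new DAL methods (illustration only, not part of the commit; it assumes DalBase exposes the async get_count/get_datas/create_data helpers used above):

    from sqlalchemy.ext.asyncio import AsyncSession

    async def create_detect_checked(db: AsyncSession, data):
        dal = ProjectDetectDal(db)
        # check_name() returns True when the (project_id, detect_name) pair already exists
        if await dal.check_name(data.detect_name, data.project_id):
            raise ValueError("detect_name already used in this project")
        return await dal.add_detect(data)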
@@ -0,0 +1 @@
from .detect import ProjectDetect, ProjectDetectFile, ProjectDetectLog, ProjectDetectLogFile
@@ -22,17 +22,18 @@ class ProjectDetect(BaseModel):
     user_id: Mapped[int] = mapped_column(Integer, nullable=False)


-class ProjectDetectImg(BaseModel):
+class ProjectDetectFile(BaseModel):
     """
     Images awaiting inference
     """
-    __tablename__ = "project_detect_img"
-    __table_args__ = ({'comment': 'Images awaiting inference'})
+    __tablename__ = "project_detect_file"
+    __table_args__ = ({'comment': 'Files awaiting inference'})

     detect_id: Mapped[int] = mapped_column(Integer, nullable=False)
     file_name: Mapped[str] = mapped_column(String(64), nullable=False)
-    image_url: Mapped[str] = mapped_column(String(255), nullable=False)
-    thumb_image_url: Mapped[str] = mapped_column(String(255), nullable=False)
+    file_url: Mapped[str] = mapped_column(String(255), nullable=False)
+    object_key: Mapped[str] = mapped_column(String(255), nullable=False)
+    thumb_file_url: Mapped[str] = mapped_column(String(255), nullable=False)
     user_id: Mapped[int] = mapped_column(Integer, nullable=False)


@@ -55,13 +56,13 @@ class ProjectDetectLog(BaseModel):
     user_id: Mapped[int] = mapped_column(Integer, nullable=False)


-class ProjectDetectLogImg(BaseModel):
+class ProjectDetectLogFile(BaseModel):
     """
     Images produced by a completed inference run
     """
-    __tablename__ = "project_detect_log_img"
+    __tablename__ = "project_detect_log_file"
     __table_args__ = ({'comment': 'Project training version info table'})

     log_id: Mapped[int] = mapped_column(Integer, nullable=False)
     file_name: Mapped[str] = mapped_column(String(64), nullable=False)
-    image_url: Mapped[str] = mapped_column(String(255), nullable=False)
+    file_url: Mapped[str] = mapped_column(String(255), nullable=False)
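For reference, a minimal sketch of persisting a row with the renamed columns (illustration only; paths and keys are made up, and the project's own session factory is assumed):

    from sqlalchemy.ext.asyncio import AsyncSession

    async def save_detect_file(db: AsyncSession, detect_id: int, user_id: int):
        row = ProjectDetectFile(
            detect_id=detect_id,
            file_name="demo.jpg",
            file_url="/data/detect/abc123/images/demo.jpg",              # was image_url
            object_key="abc123/demo.jpg",                                # new OBS key column
            thumb_file_url="https://obs.example.com/abc123/demo.jpg",    # was thumb_image_url
            user_id=user_id,
        )
        db.add(row)
        await db.commit()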
@@ -1,4 +1,4 @@
 from .project_detect import ProjectDetectParams
-from .project_detect_img import ProjectDetectImgParams
+from .project_detect_file import ProjectDetectFileParams
 from .project_detect_log import ProjectDetectLogParams
-from .project_detect_log_img import ProjectDetectLogImgParams
+from .project_detect_log_file import ProjectDetectLogFileParams
@@ -6,10 +6,14 @@
 # @IDE : PyCharm
 # @desc : project inference set info

-from fastapi import Depends
+from fastapi import Depends, Query
 from core.dependencies import Paging, QueryParams


 class ProjectDetectParams(QueryParams):
-    def __init__(self, params: Paging = Depends()):
+    def __init__(
+            self,
+            project_id: int | None = Query(None, title="Project id"),
+            params: Paging = Depends()):
         super().__init__(params)
+        self.project_id = project_id
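These parameter classes are ordinary FastAPI dependencies; a sketch of how ProjectDetectParams surfaces in a route (the route path here is illustrative, and QueryParams is assumed to provide the dict() used in views.py):

    from fastapi import APIRouter, Depends

    router = APIRouter()

    @router.get("/example")  # illustrative route
    async def example(p: ProjectDetectParams = Depends()):
        # p.project_id is bound from the ?project_id= query string;
        # paging fields are inherited from Paging via QueryParams.
        return p.dict()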
19  apps/business/detect/params/project_detect_file.py  Normal file
@@ -0,0 +1,19 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:30
# @File : project_detect_file.py
# @IDE : PyCharm
# @desc : project inference set file info

from fastapi import Depends, Query
from core.dependencies import Paging, QueryParams


class ProjectDetectFileParams(QueryParams):
    def __init__(
            self,
            detect_id: int | None = Query(0, title="Inference set id"),
            params: Paging = Depends()):
        super().__init__(params)
        self.detect_id = detect_id
@@ -1,15 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:30
# @File : project_detect_img.py
# @IDE : PyCharm
# @desc : project inference set image info

from fastapi import Depends
from core.dependencies import Paging, QueryParams


class ProjectDetectImgParams(QueryParams):
    def __init__(self, params: Paging = Depends()):
        super().__init__(params)
@@ -6,10 +6,14 @@
 # @IDE : PyCharm
 # @desc : project inference log info

-from fastapi import Depends
+from fastapi import Depends, Query
 from core.dependencies import Paging, QueryParams


 class ProjectDetectLogParams(QueryParams):
-    def __init__(self, params: Paging = Depends()):
+    def __init__(
+            self,
+            detect_id: int | None = Query(0, title="Inference set id"),
+            params: Paging = Depends()):
         super().__init__(params)
+        self.detect_id = detect_id
19  apps/business/detect/params/project_detect_log_file.py  Normal file
@@ -0,0 +1,19 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:31
# @File : project_detect_log_file.py
# @IDE : PyCharm
# @desc : project inference log file info

from fastapi import Depends, Query
from core.dependencies import Paging, QueryParams


class ProjectDetectLogFileParams(QueryParams):
    def __init__(
            self,
            log_id: int | None = Query(0, title="Inference log id"),
            params: Paging = Depends()):
        super().__init__(params)
        self.log_id = log_id
@@ -1,15 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:31
# @File : project_detect_log_img.py
# @IDE : PyCharm
# @desc : project inference log image info

from fastapi import Depends
from core.dependencies import Paging, QueryParams


class ProjectDetectLogImgParams(QueryParams):
    def __init__(self, params: Paging = Depends()):
        super().__init__(params)
@@ -1,4 +1,4 @@
-from .project_detect import ProjectDetect, ProjectDetectSimpleOut
-from .project_detect_img import ProjectDetectImg, ProjectDetectImgSimpleOut
-from .project_detect_log import ProjectDetectLog, ProjectDetectLogSimpleOut
-from .project_detect_log_img import ProjectDetectLogImg, ProjectDetectLogImgSimpleOut
+from .project_detect import ProjectDetectIn, ProjectDetectPager, ProjectDetectOut, ProjectDetectList
+from .project_detect_file import ProjectDetectFilePager, ProjectDetectFileOut
+from .project_detect_log import ProjectDetectLogIn, ProjectDetectLogOut
+from .project_detect_log_file import ProjectDetectLogFileOut
@@ -7,24 +7,43 @@
 # @desc : pydantic models for database serialization

 from pydantic import BaseModel, Field, ConfigDict
-from core.data_types import DatetimeStr
+from typing import Optional
+from datetime import datetime


-class ProjectDetect(BaseModel):
-    project_id: int = Field(..., title="None")
-    detect_name: str = Field(..., title="None")
-    detect_version: int = Field(..., title="None")
-    detect_no: str = Field(..., title="None")
-    detect_status: int = Field(..., title="None")
-    file_type: str = Field(..., title="None")
-    folder_url: str = Field(..., title="None")
-    rtsp_url: str = Field(..., title="None")
-    user_id: int = Field(..., title="None")
+class ProjectDetectIn(BaseModel):
+    project_id: Optional[int] = Field(..., description="Project id")
+    file_type: Optional[str] = Field('img', description="File type of the inference set")
+    detect_name: Optional[str] = Field(..., description="Inference set name")
+    rtsp_url: Optional[str] = Field(None, description="RTSP stream URL")


-class ProjectDetectSimpleOut(ProjectDetect):
-    model_config = ConfigDict(from_attributes=True)
-
-    id: int = Field(..., title="Id")
-    create_datetime: DatetimeStr = Field(..., title="Create time")
-    update_datetime: DatetimeStr = Field(..., title="Update time")
+class ProjectDetectPager(BaseModel):
+    project_id: Optional[int] = Field(..., description="Project id")
+    detect_name: Optional[str] = Field(None, description="Inference set name")
+    pagerNum: Optional[int] = Field(1, description="Current page number")
+    pagerSize: Optional[int] = Field(10, description="Page size")
+
+
+class ProjectDetectOut(BaseModel):
+    id: Optional[int]
+    project_id: Optional[int]
+    detect_name: Optional[str]
+    detect_no: Optional[str]
+    detect_version: Optional[int]
+    file_type: Optional[str]
+    folder_url: Optional[str]
+    rtsp_url: Optional[str]
+    create_time: Optional[datetime]
+
+    model_config = ConfigDict(from_attributes=True)
+
+
+class ProjectDetectList(BaseModel):
+    id: Optional[int]
+    file_type: Optional[str]
+    detect_name: Optional[str]
+
+    model_config = ConfigDict(from_attributes=True)
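model_config = ConfigDict(from_attributes=True) is what lets these output models be built straight from ORM rows; a minimal pydantic v2 sketch (names are made up):

    from pydantic import BaseModel, ConfigDict

    class Demo(BaseModel):
        id: int | None = None
        detect_name: str | None = None

        model_config = ConfigDict(from_attributes=True)

    class RowLike:
        # stands in for a SQLAlchemy row object
        id = 7
        detect_name = "street-cam"

    print(Demo.model_validate(RowLike()))  # Demo(id=7, detect_name='street-cam')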
26  apps/business/detect/schemas/project_detect_file.py  Normal file
@@ -0,0 +1,26 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:30
# @File : project_detect_file.py
# @IDE : PyCharm
# @desc : pydantic models for database serialization
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional
from datetime import datetime


class ProjectDetectFilePager(BaseModel):
    detect_id: Optional[int] = Field(..., description="Inference set id")
    pagerNum: Optional[int] = Field(None, description="Current page number")
    pagerSize: Optional[int] = Field(None, description="Page size")


class ProjectDetectFileOut(BaseModel):
    id: Optional[int] = Field(None, description="Id")
    detect_id: Optional[int] = Field(..., description="Inference set id")
    file_name: Optional[str] = Field(None, description="File name")
    thumb_file_url: Optional[str] = Field(None, description="File path")
    create_time: Optional[datetime] = Field(None, description="Upload time")

    model_config = ConfigDict(from_attributes=True)
@@ -1,26 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:30
# @File : project_detect_img.py
# @IDE : PyCharm
# @desc : pydantic models for database serialization

from pydantic import BaseModel, Field, ConfigDict
from core.data_types import DatetimeStr


class ProjectDetectImg(BaseModel):
    detect_id: int = Field(..., title="None")
    file_name: str = Field(..., title="None")
    image_url: str = Field(..., title="None")
    thumb_image_url: str = Field(..., title="None")
    user_id: int = Field(..., title="None")


class ProjectDetectImgSimpleOut(ProjectDetectImg):
    model_config = ConfigDict(from_attributes=True)

    id: int = Field(..., title="Id")
    create_datetime: DatetimeStr = Field(..., title="Create time")
    update_datetime: DatetimeStr = Field(..., title="Update time")
@@ -7,25 +7,24 @@
 # @desc : pydantic models for database serialization

 from pydantic import BaseModel, Field, ConfigDict
-from core.data_types import DatetimeStr
+from typing import Optional
+from datetime import datetime


-class ProjectDetectLog(BaseModel):
-    detect_id: int = Field(..., title="None")
-    detect_version: str = Field(..., title="None")
-    detect_name: str = Field(..., title="None")
-    train_id: int = Field(..., title="None")
-    train_version: str = Field(..., title="None")
-    pt_type: str = Field(..., title="None")
-    pt_url: str = Field(..., title="None")
-    folder_url: str = Field(..., title="None")
-    detect_folder_url: str = Field(..., title="None")
-    user_id: int = Field(..., title="None")
+class ProjectDetectLogIn(BaseModel):
+    detect_id: Optional[int] = Field(..., description="Inference set id")
+    train_id: Optional[int] = Field(..., description="Training result id")
+    pt_type: Optional[str] = Field('best', description="Weight file type")


-class ProjectDetectLogSimpleOut(ProjectDetectLog):
-    model_config = ConfigDict(from_attributes=True)
-
-    id: int = Field(..., title="Id")
-    create_datetime: DatetimeStr = Field(..., title="Create time")
-    update_datetime: DatetimeStr = Field(..., title="Update time")
+class ProjectDetectLogOut(BaseModel):
+    id: Optional[int]
+    detect_id: Optional[int]
+    detect_version: Optional[str]
+    detect_name: Optional[str]
+    train_id: Optional[int]
+    train_version: Optional[str]
+    pt_type: Optional[str]
+    create_time: Optional[datetime]
+
+    model_config = ConfigDict(from_attributes=True)
20  apps/business/detect/schemas/project_detect_log_file.py  Normal file
@@ -0,0 +1,20 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:31
# @File : project_detect_log_file.py
# @IDE : PyCharm
# @desc : pydantic models for database serialization

from pydantic import BaseModel, ConfigDict
from typing import Optional
from datetime import datetime


class ProjectDetectLogFileOut(BaseModel):
    id: Optional[int]
    file_name: Optional[str]
    thumb_file_url: Optional[str]
    create_time: Optional[datetime]

    model_config = ConfigDict(from_attributes=True)
@@ -1,24 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2025/04/03 10:31
# @File : project_detect_log_img.py
# @IDE : PyCharm
# @desc : pydantic models for database serialization

from pydantic import BaseModel, Field, ConfigDict
from core.data_types import DatetimeStr


class ProjectDetectLogImg(BaseModel):
    log_id: int = Field(..., title="None")
    file_name: str = Field(..., title="None")
    image_url: str = Field(..., title="None")


class ProjectDetectLogImgSimpleOut(ProjectDetectLogImg):
    model_config = ConfigDict(from_attributes=True)

    id: int = Field(..., title="Id")
    create_datetime: DatetimeStr = Field(..., title="Create time")
    update_datetime: DatetimeStr = Field(..., title="Update time")
226  apps/business/detect/service.py  Normal file
@@ -0,0 +1,226 @@
from application.settings import yolo_url, detect_url
from utils.websocket_server import room_manager
from utils import os_utils as os
from . import models, crud, schemas
from apps.business.train import models as train_models
from utils.yolov5.models.common import DetectMultiBackend
from utils.yolov5.utils.torch_utils import select_device
from utils.yolov5.utils.dataloaders import LoadStreams
from utils.yolov5.utils.general import check_img_size, Profile, non_max_suppression, cv2, scale_boxes
from ultralytics.utils.plotting import Annotator, colors

import time
import torch
import asyncio
import subprocess
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession


async def before_detect(
        detect_in: schemas.ProjectDetectLogIn,
        detect: models.ProjectDetect,
        train: train_models.ProjectTrain,
        db: AsyncSession):
    """
    Prepare an inference run.
    :param detect: the inference set
    :param detect_in: input schema carrying the detect/train ids and weight type
    :param train: the training record supplying the weights
    :param db: database session
    :return: the created inference log
    """
    # Inference version
    version_path = 'v' + str(detect.detect_version + 1)

    # Weight file
    pt_url = train.best_pt if detect_in.pt_type == 'best' else train.last_pt

    # Folder holding the files of the inference set
    img_url = detect.folder_url

    out_url = os.file_path(detect_url, detect.detect_no, 'detect')

    # Build the inference log record
    detect_log = models.ProjectDetectLog()
    detect_log.detect_name = detect.detect_name
    detect_log.detect_id = detect.id
    detect_log.detect_version = version_path
    detect_log.train_id = train.id
    detect_log.train_version = train.train_version
    detect_log.pt_type = detect_in.pt_type
    detect_log.pt_url = pt_url
    detect_log.folder_url = img_url
    detect_log.detect_folder_url = out_url
    await crud.ProjectDetectLogDal(db).create_data(detect_log)
    return detect_log


async def run_detect_img(
        weights: str,
        source: str,
        project: str,
        name: str,
        log_id: int,
        detect_id: int,
        db: AsyncSession,
        rd: Redis):
    """
    Run yolov5 inference over images.
    :param weights: weight file
    :param source: folder containing the images
    :param project: output folder for the finished inference
    :param name: version name
    :param log_id: log id
    :param detect_id: inference set id
    :param db: database session
    :param rd: Redis
    :return:
    """
    yolo_path = os.file_path(yolo_url, 'detect.py')
    room = 'detect_' + str(detect_id)
    await room_manager.send_to_room(room, "AiCheck: inference started, please wait...\n")
    command = ["python", '-u', yolo_path, "--weights", weights, "--source", source, "--name", name, "--project",
               project, "--save-txt", "--conf-thres", "0.4"]
    is_gpu = await rd.get('is_gpu')
    # Use CUDA when a GPU build is available
    if is_gpu == 'True':
        command.extend(["--device", "0"])
    # Launch the child process
    with subprocess.Popen(
            command,
            bufsize=1,  # bufsize=0: unbuffered; bufsize=1: line-buffered; other positive values: roughly that many bytes
            shell=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # also forwards yolov5's progress bars and other stderr output
            text=True,  # treat the streams as text to avoid encoding issues later
            encoding='utf-8',
    ) as process:
        while process.poll() is None:
            line = process.stdout.readline()
            process.stdout.flush()  # flush so buffered output cannot pile up and stall the loop
            if line != '\n':
                await room_manager.send_to_room(room, line + '\n')

        # Wait for the process to finish and check the return code
        return_code = process.wait()
        if return_code != 0:
            await room_manager.send_to_room(room, 'error')
        else:
            await room_manager.send_to_room(room, 'success')
            detect_files = await crud.ProjectDetectFileDal(db).get_datas(
                v_where=[models.ProjectDetectFile.detect_id == detect_id])
            detect_log_files = []
            for detect_file in detect_files:
                detect_log_file = models.ProjectDetectLogFile()
                detect_log_file.log_id = log_id
                image_url = os.file_path(project, name, detect_file.file_name)
                detect_log_file.file_url = image_url
                detect_log_file.file_name = detect_file.file_name
                detect_log_files.append(detect_log_file)
            await crud.ProjectDetectLogFileDal(db).create_datas(detect_log_files)


async def run_detect_rtsp(weights_pt: str, rtsp_url: str, data: str, detect_id: int, rd: Redis):
    """
    Inference over an RTSP video stream.
    :param detect_id: inference set id
    :param weights_pt: weight file
    :param rtsp_url: stream URL
    :param data: yaml file
    :param rd: Redis
    :return:
    """
    room = 'detect_rtsp_' + str(detect_id)
    # Select the device (CPU or GPU)
    device = select_device('cpu')
    is_gpu = await rd.get('is_gpu')
    # Use CUDA when a GPU build is available
    if is_gpu == 'True':
        device = select_device('cuda:0')

    # Load the model
    model = DetectMultiBackend(weights_pt, device=device, dnn=False, data=data, fp16=False)

    stride, names, pt = model.stride, model.names, model.pt
    imgsz = check_img_size((640, 640), s=stride)  # check image size

    dataset = LoadStreams(rtsp_url, img_size=imgsz, stride=stride, auto=pt, vid_stride=1)
    bs = len(dataset)

    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))

    seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))

    time.sleep(3)  # wait 3s for the websocket client to join the room

    for path, im, im0s, vid_cap, s in dataset:
        if room_manager.rooms.get(room):
            with dt[0]:
                im = torch.from_numpy(im).to(model.device)
                im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
                im /= 255  # 0 - 255 to 0.0 - 1.0
                if len(im.shape) == 3:
                    im = im[None]  # expand for batch dim
                if model.xml and im.shape[0] > 1:
                    ims = torch.chunk(im, im.shape[0], 0)

            # Inference
            with dt[1]:
                if model.xml and im.shape[0] > 1:
                    pred = None
                    for image in ims:
                        if pred is None:
                            pred = model(image, augment=False, visualize=False).unsqueeze(0)
                        else:
                            pred = torch.cat((pred, model(image, augment=False, visualize=False).unsqueeze(0)),
                                             dim=0)
                    pred = [pred, None]
                else:
                    pred = model(im, augment=False, visualize=False)
            # NMS
            with dt[2]:
                pred = non_max_suppression(pred, 0.45, 0.45, None, False, max_det=1000)

            # Process predictions
            for i, det in enumerate(pred):  # per image
                p, im0, frame = path[i], im0s[i].copy(), dataset.count
                annotator = Annotator(im0, line_width=3, example=str(names))
                if len(det):
                    # Rescale boxes from img_size to im0 size
                    det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

                    # Write results
                    for *xyxy, conf, cls in reversed(det):
                        c = int(cls)  # integer class
                        label = f"{names[c]} {conf:.2f}"
                        annotator.box_label(xyxy, label, color=colors(c, True))

                # Stream results
                im0 = annotator.result()
                # Encode the frame as JPEG
                ret, jpeg = cv2.imencode('.jpg', im0)
                if ret:
                    frame_data = jpeg.tobytes()
                    await room_manager.send_stream_to_room(room, frame_data)
        else:
            print(room, 'inference finished')
            break


def run_img_loop(weights: str, source: str, project: str, name: str,
                 log_id: int, detect_id: int, db: AsyncSession, rd: Redis):
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # Drive the coroutine to completion
    loop.run_until_complete(run_detect_img(weights, source, project, name, log_id, detect_id, db, rd))
    # Optional: close the loop
    loop.close()


def run_rtsp_loop(weights_pt: str, rtsp_url: str, data: str, detect_id: int, rd: Redis):
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # Drive the coroutine to completion
    loop.run_until_complete(run_detect_rtsp(weights_pt, rtsp_url, data, detect_id, rd))
    # Optional: close the loop
    loop.close()
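The run_*_loop helpers exist because the inference runs on a worker thread, and a fresh thread has no running event loop; a stripped-down sketch of the pattern:

    import asyncio
    import threading

    async def job(tag: str):
        await asyncio.sleep(0.1)  # stand-in for the awaited inference work
        print(tag, "done")

    def thread_entry():
        # Each worker thread drives its own event loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(job("detect_1"))
        loop.close()

    threading.Thread(target=thread_entry).start()

asyncio.run() would perform the same create/run/close in one call. Handing the request-scoped AsyncSession and Redis client to another thread, as the views below do, generally needs care since neither is thread-safe.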
@@ -5,15 +5,21 @@
 # @File : views.py
 # @IDE : PyCharm
 # @desc : routes, view layer
-from core.dependencies import IdList
-from apps.vadmin.auth.utils.validation.auth import Auth
-from sqlalchemy.ext.asyncio import AsyncSession
-from apps.vadmin.auth.utils.current import AllUserAuth
-from core.database import db_getter
-from . import schemas, crud, models, params
-from fastapi import Depends, APIRouter
-from utils.response import SuccessResponse

+from . import service
+from . import schemas, crud, params
+from core.dependencies import IdList
+from core.database import redis_getter
+from utils.websocket_server import room_manager
+from apps.business.train.crud import ProjectTrainDal
+from apps.vadmin.auth.utils.current import AllUserAuth
+from apps.vadmin.auth.utils.validation.auth import Auth
+from utils.response import SuccessResponse, ErrorResponse

+import threading
+from redis.asyncio import Redis
+from sqlalchemy.ext.asyncio import AsyncSession
+from fastapi import Depends, APIRouter, Form, UploadFile


 app = APIRouter()
@@ -22,129 +28,120 @@ app = APIRouter()
 ###########################################################
 # Project inference set info
 ###########################################################
-@app.get("/project/detect", summary="Get the list of project inference sets", tags=["Project inference set info"])
-async def get_project_detect_list(p: params.ProjectDetectParams = Depends(), auth: Auth = Depends(AllUserAuth())):
+@app.get("/list", summary="Get the list of project inference sets")
+async def detect_list(
+        p: params.ProjectDetectParams = Depends(),
+        auth: Auth = Depends(AllUserAuth())):
     datas, count = await crud.ProjectDetectDal(auth.db).get_datas(**p.dict(), v_return_count=True)
     return SuccessResponse(datas, count=count)


-@app.post("/project/detect", summary="Create a project inference set", tags=["Project inference set info"])
-async def create_project_detect(data: schemas.ProjectDetect, auth: Auth = Depends(AllUserAuth())):
-    return SuccessResponse(await crud.ProjectDetectDal(auth.db).create_data(data=data))
+@app.post("/", summary="Create a project inference set")
+async def add_detect(
+        data: schemas.ProjectDetectIn,
+        auth: Auth = Depends(AllUserAuth())):
+    detect_dal = crud.ProjectDetectDal(auth.db)
+    if await detect_dal.check_name(data.detect_name, data.project_id):
+        return ErrorResponse(msg="A set with the same name already exists in this project")
+    await detect_dal.create_data(data=data)
+    return SuccessResponse(msg="Saved")


-@app.delete("/project/detect", summary="Delete project inference sets", description="hard delete", tags=["Project inference set info"])
-async def delete_project_detect_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
+@app.delete("/", summary="Delete project inference sets")
+async def delete_detect(
+        ids: IdList = Depends(),
+        auth: Auth = Depends(AllUserAuth())):
     await crud.ProjectDetectDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
     return SuccessResponse("Deleted")


-@app.put("/project/detect/{data_id}", summary="Update a project inference set", tags=["Project inference set info"])
-async def put_project_detect(data_id: int, data: schemas.ProjectDetect, auth: Auth = Depends(AllUserAuth())):
-    return SuccessResponse(await crud.ProjectDetectDal(auth.db).put_data(data_id, data))
-
-
-@app.get("/project/detect/{data_id}", summary="Get a project inference set", tags=["Project inference set info"])
-async def get_project_detect(data_id: int, db: AsyncSession = Depends(db_getter)):
-    schema = schemas.ProjectDetectSimpleOut
-    return SuccessResponse(await crud.ProjectDetectDal(db).get_data(data_id, v_schema=schema))
-
-
 ###########################################################
-# Project inference set image info
+# Project inference set file info
 ###########################################################
-@app.get("/project/detect/img", summary="Get the list of inference set images", tags=["Project inference set image info"])
-async def get_project_detect_img_list(p: params.ProjectDetectImgParams = Depends(), auth: Auth = Depends(AllUserAuth())):
-    datas, count = await crud.ProjectDetectImgDal(auth.db).get_datas(**p.dict(), v_return_count=True)
-    return SuccessResponse(datas, count=count)
+@app.get("/file", summary="Get the list of inference set files")
+async def file_list(
+        p: params.ProjectDetectFileParams = Depends(),
+        auth: Auth = Depends(AllUserAuth())):
+    if p.limit:
+        datas, count = await crud.ProjectDetectFileDal(auth.db).get_datas(**p.dict(), v_return_count=True)
+        return SuccessResponse(datas, count=count)
+    else:
+        datas = await crud.ProjectDetectFileDal(auth.db).get_datas(**p.dict(), v_return_count=False)
+        return SuccessResponse(datas)


-@app.post("/project/detect/img", summary="Create inference set image info", tags=["Project inference set image info"])
-async def create_project_detect_img(data: schemas.ProjectDetectImg, auth: Auth = Depends(AllUserAuth())):
-    return SuccessResponse(await crud.ProjectDetectImgDal(auth.db).create_data(data=data))
+@app.post("/file", summary="Upload inference set files")
+async def upload_file(
+        detect_id: int = Form(...),
+        files: list[UploadFile] = Form(...),
+        auth: Auth = Depends(AllUserAuth())):
+    file_dal = crud.ProjectDetectFileDal(auth.db)
+    detect_out = await crud.ProjectDetectDal(auth.db).get_data(data_id=detect_id)
+    if detect_out is None:
+        return ErrorResponse("Inference set lookup failed, please refresh and try again")
+    await file_dal.add_file(detect_out, files)
+    return SuccessResponse(msg="Uploaded")


-@app.delete("/project/detect/img", summary="Delete inference set image info", description="hard delete", tags=["Project inference set image info"])
-async def delete_project_detect_img_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
-    await crud.ProjectDetectImgDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
+@app.delete("/file", summary="Delete inference set files")
+async def delete_file(
+        ids: IdList = Depends(),
+        auth: Auth = Depends(AllUserAuth())):
+    await crud.ProjectDetectFileDal(auth.db).delete_files(ids=ids.ids)
     return SuccessResponse("Deleted")


-@app.put("/project/detect/img/{data_id}", summary="Update inference set image info", tags=["Project inference set image info"])
-async def put_project_detect_img(data_id: int, data: schemas.ProjectDetectImg, auth: Auth = Depends(AllUserAuth())):
-    return SuccessResponse(await crud.ProjectDetectImgDal(auth.db).put_data(data_id, data))
-
-
-@app.get("/project/detect/img/{data_id}", summary="Get inference set image info", tags=["Project inference set image info"])
-async def get_project_detect_img(data_id: int, db: AsyncSession = Depends(db_getter)):
-    schema = schemas.ProjectDetectImgSimpleOut
-    return SuccessResponse(await crud.ProjectDetectImgDal(db).get_data(data_id, v_schema=schema))
+@app.post("/detect", summary="Start inference")
+async def run_detect_yolo(
+        detect_log_in: schemas.ProjectDetectLogIn,
+        auth: Auth = Depends(AllUserAuth()),
+        rd: Redis = Depends(redis_getter)):
+    detect_dal = crud.ProjectDetectDal(auth.db)
+    train_dal = ProjectTrainDal(auth.db)
+    detect = await detect_dal.get_data(detect_log_in.detect_id)
+    if detect is None:
+        return ErrorResponse(msg="The inference set does not exist")
+    train = await train_dal.get_data(detect_log_in.train_id)
+    if train is None:
+        return ErrorResponse("The training weights do not exist")
+    file_count = await crud.ProjectDetectFileDal(auth.db).file_count(detect_log_in.detect_id)
+    if file_count == 0 and detect.rtsp_url is None:
+        return ErrorResponse("The inference set is empty, please upload images to it first")
+    if detect.file_type == 'img' or detect.file_type == 'video':
+        detect_log = await service.before_detect(detect_log_in, detect, train, auth.db)
+        thread_train = threading.Thread(target=service.run_img_loop,
+                                        args=(detect_log.pt_url, detect_log.folder_url,
+                                              detect_log.detect_folder_url, detect_log.detect_version,
+                                              detect_log.id, detect_log.detect_id, auth.db, rd,))
+        thread_train.start()
+    elif detect.file_type == 'rtsp':
+        room = 'detect_rtsp_' + str(detect.id)
+        if not room_manager.rooms.get(room):
+            if detect_log_in.pt_type == 'best':
+                weights_pt = train.best_pt
+            else:
+                weights_pt = train.last_pt
+            thread_train = threading.Thread(target=service.run_rtsp_loop,
+                                            args=(weights_pt, detect.rtsp_url, train.train_data, detect.id, rd,))
+            thread_train.start()
+    return SuccessResponse(msg="Started")


 ###########################################################
 # Project inference log info
 ###########################################################
-@app.get("/project/detect/log", summary="Get the list of project inference logs", tags=["Project inference log info"])
-async def get_project_detect_log_list(p: params.ProjectDetectLogParams = Depends(), auth: Auth = Depends(AllUserAuth())):
+@app.get("/log", summary="Get the paged list of project inference logs")
+async def log_pager(
+        p: params.ProjectDetectLogParams = Depends(),
+        auth: Auth = Depends(AllUserAuth())):
     datas, count = await crud.ProjectDetectLogDal(auth.db).get_datas(**p.dict(), v_return_count=True)
     return SuccessResponse(datas, count=count)


-@app.post("/project/detect/log", summary="Create project inference log info", tags=["Project inference log info"])
-async def create_project_detect_log(data: schemas.ProjectDetectLog, auth: Auth = Depends(AllUserAuth())):
-    return SuccessResponse(await crud.ProjectDetectLogDal(auth.db).create_data(data=data))
-
-
-@app.delete("/project/detect/log", summary="Delete project inference log info", description="hard delete", tags=["Project inference log info"])
-async def delete_project_detect_log_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
-    await crud.ProjectDetectLogDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
-    return SuccessResponse("Deleted")
-
-
-@app.put("/project/detect/log/{data_id}", summary="Update project inference log info", tags=["Project inference log info"])
-async def put_project_detect_log(data_id: int, data: schemas.ProjectDetectLog, auth: Auth = Depends(AllUserAuth())):
-    return SuccessResponse(await crud.ProjectDetectLogDal(auth.db).put_data(data_id, data))
-
-
-@app.get("/project/detect/log/{data_id}", summary="Get project inference log info", tags=["Project inference log info"])
-async def get_project_detect_log(data_id: int, db: AsyncSession = Depends(db_getter)):
-    schema = schemas.ProjectDetectLogSimpleOut
-    return SuccessResponse(await crud.ProjectDetectLogDal(db).get_data(data_id, v_schema=schema))
-
-
-###########################################################
-# Project inference log image info
-###########################################################
-@app.get("/project/detect/log/img", summary="Get the list of inference log images", tags=["Project inference log image info"])
-async def get_project_detect_log_img_list(p: params.ProjectDetectLogImgParams = Depends(), auth: Auth = Depends(AllUserAuth())):
-    datas, count = await crud.ProjectDetectLogImgDal(auth.db).get_datas(**p.dict(), v_return_count=True)
-    return SuccessResponse(datas, count=count)
-
-
-@app.post("/project/detect/log/img", summary="Create inference log image info", tags=["Project inference log image info"])
-async def create_project_detect_log_img(data: schemas.ProjectDetectLogImg, auth: Auth = Depends(AllUserAuth())):
-    return SuccessResponse(await crud.ProjectDetectLogImgDal(auth.db).create_data(data=data))
-
-
-@app.delete("/project/detect/log/img", summary="Delete inference log image info", description="hard delete", tags=["Project inference log image info"])
-async def delete_project_detect_log_img_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):
-    await crud.ProjectDetectLogImgDal(auth.db).delete_datas(ids=ids.ids, v_soft=False)
-    return SuccessResponse("Deleted")
-
-
-@app.put("/project/detect/log/img/{data_id}", summary="Update inference log image info", tags=["Project inference log image info"])
-async def put_project_detect_log_img(data_id: int, data: schemas.ProjectDetectLogImg, auth: Auth = Depends(AllUserAuth())):
-    return SuccessResponse(await crud.ProjectDetectLogImgDal(auth.db).put_data(data_id, data))
-
-
-@app.get("/project/detect/log/img/{data_id}", summary="Get inference log image info", tags=["Project inference log image info"])
-async def get_project_detect_log_img(data_id: int, db: AsyncSession = Depends(db_getter)):
-    schema = schemas.ProjectDetectLogImgSimpleOut
-    return SuccessResponse(await crud.ProjectDetectLogImgDal(db).get_data(data_id, v_schema=schema))
+@app.get("/log_files", summary="Get the list of inference log files")
+async def log_files(
+        p: params.ProjectDetectLogFileParams = Depends(),
+        auth: Auth = Depends(AllUserAuth())):
+    datas = await crud.ProjectDetectLogFileDal(auth.db).get_datas(**p.dict(), v_return_count=False)
+    return SuccessResponse(datas)
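A hypothetical client-side sketch of the new route layout (paths are relative to the /business/detect prefix registered above; the host/port and ids are assumptions, httpx is used purely for illustration):

    import httpx

    BASE = "http://localhost:9000/business/detect"  # host/port are assumptions

    with httpx.Client() as client:
        # Create a set, then upload files to it as multipart form data.
        client.post(BASE + "/", json={"project_id": 1, "detect_name": "street-cam", "file_type": "img"})
        client.post(BASE + "/file", data={"detect_id": 1},
                    files=[("files", ("demo.jpg", b"...", "image/jpeg"))])
        # Kick off inference with the chosen weights.
        client.post(BASE + "/detect", json={"detect_id": 1, "train_id": 3, "pt_type": "best"})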
@@ -10,7 +10,7 @@ from . import params, schemas, crud, models
 from core.dependencies import IdList

 from typing import List
-from fastapi import APIRouter, Depends, UploadFile, File, Form
+from fastapi import APIRouter, Depends, UploadFile, Form
 from apps.vadmin.auth.utils.current import FullAdminAuth
 from apps.vadmin.auth.utils.validation.auth import Auth

@@ -124,7 +124,7 @@ async def project_pager(
 @app.post("/img", summary="Upload images")
 async def up_img(
         project_id: int = Form(...),
-        files: List[UploadFile] = File(...),
+        files: List[UploadFile] = Form(...),
         img_type: str = Form(...),
         auth: Auth = Depends(FullAdminAuth())
 ):
@@ -1,14 +1,13 @@
-from . import schemas, models, crud
-from apps.business.project import schemas as proj_schemas, models as proj_models, crud as proj_crud
 from utils import os_utils as os
 from application.settings import *
+from . import schemas, models, crud
 from utils.websocket_server import room_manager
+from apps.business.project import models as proj_models, crud as proj_crud


 import yaml
 import asyncio
 import subprocess
-from typing import List
 from redis.asyncio import Redis
 from sqlalchemy.ext.asyncio import AsyncSession

@@ -73,7 +72,7 @@ async def before_train(proj_info: proj_models.ProjectInfo, db: AsyncSession):


 async def operate_img_label(
-        img_list: List[proj_models.ProjectImgLabel],
+        img_list: list[proj_models.ProjectImgLabel],
         img_path: str,
         label_path: str,
         db: AsyncSession,
@@ -1,7 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/12/9 15:26
# @File : __init__.py
# @IDE : PyCharm
# @desc : brief description
@@ -1,95 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/12/9 15:27
# @File : main.py
# @IDE : PyCharm
# @desc : brief description

import datetime
import os.path
from application.settings import BASE_DIR


class CreateApp:

    APPS_ROOT = os.path.join(BASE_DIR, "apps")
    SCRIPT_DIR = os.path.join(BASE_DIR, 'scripts', 'create_app')

    def __init__(self, path: str):
        """
        :param path: app path relative to the apps root, e.g. vadmin/auth
        """
        self.app_path = os.path.join(self.APPS_ROOT, path)
        self.path = path

    def run(self):
        """
        Create the initial APP structure automatically; does nothing if the path already exists.
        """
        if self.exist(self.app_path):
            print(f"{self.app_path} already exists and cannot be created automatically; delete it and run again.")
            return False
        print("Generating app directory:", self.path)
        path = []
        for item in self.path.split("/"):
            path.append(item)
            self.create_pag(os.path.join(self.APPS_ROOT, *path))
        self.create_pag(os.path.join(self.app_path, "models"))
        self.create_pag(os.path.join(self.app_path, "params"))
        self.create_pag(os.path.join(self.app_path, "schemas"))
        self.generate_file("views.py")
        self.generate_file("crud.py")
        print("App directory generated:", self.app_path)

    def create_pag(self, path: str) -> None:
        """
        Create a python package.

        :param path: absolute path
        """
        if self.exist(path):
            return
        os.makedirs(path)
        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        params = {
            "create_datetime": now,
            "filename": "__init__.py",
            "desc": "init file"
        }
        self.create_file(os.path.join(path, "__init__.py"), "init.py", **params)

    def generate_file(self, name: str) -> None:
        """
        Create a file.
        """
        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        params = {
            "create_datetime": now,
        }
        self.create_file(os.path.join(self.app_path, name), name, **params)

    def create_file(self, filepath: str, name: str, **kwargs):
        """
        Create a file from its template.
        """
        with open(filepath, "w", encoding="utf-8") as f:
            content = self.__get_template(name)
            f.write(content.format(**kwargs))

    @classmethod
    def exist(cls, path) -> bool:
        """
        Check whether the path already exists.
        """
        return os.path.exists(path)

    def __get_template(self, name: str) -> str:
        """
        Read the template content.
        """
        template = open(os.path.join(self.SCRIPT_DIR, "template", name), 'r')
        content = template.read()
        template.close()
        return content
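For the record, the deleted scaffolder was driven roughly like this (sketch only; the target path is an example):

    # Illustration of the removed scaffolder's entry point.
    app = CreateApp("business/example")
    app.run()  # creates apps/business/example with models/params/schemas packages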
@@ -1,7 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : {create_datetime}
# @File : crud.py
# @IDE : PyCharm
# @desc :
@@ -1,7 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : {create_datetime}
# @File : {filename}
# @IDE : PyCharm
# @desc : {desc}
@@ -1,16 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : {create_datetime}
# @File : views.py
# @IDE : PyCharm
# @desc :


from fastapi import APIRouter, Depends
from utils.response import SuccessResponse
from . import schemas, crud, models

app = APIRouter()
@@ -1,167 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/12/9 15:27
# @File : main.py
# @IDE : PyCharm
# @desc : brief description

import os.path
import sys
from typing import Type
from application.settings import BASE_DIR
import inspect
from pathlib import Path
from core.database import Base
from scripts.crud_generate.utils.generate_base import GenerateBase
from scripts.crud_generate.utils.schema_generate import SchemaGenerate
from scripts.crud_generate.utils.params_generate import ParamsGenerate
from scripts.crud_generate.utils.dal_generate import DalGenerate
from scripts.crud_generate.utils.view_generate import ViewGenerate


class CrudGenerate(GenerateBase):

    APPS_ROOT = os.path.join(BASE_DIR, "apps")
    SCRIPT_DIR = os.path.join(BASE_DIR, 'scripts', 'crud_generate')

    def __init__(self, model: Type[Base], zh_name: str, en_name: str = None):
        """
        Initialization.
        :param model: a predefined ORM model
        :param zh_name: Chinese feature name, used for descriptions and comments
        :param en_name: English feature name, used for schema/param file names, their class names,
            and dal/url names; defaults to the model class name.
            en_name notes:
                If en_name consists of several words, join them with underscores.
                File names use the underscore form.
                Class names convert the underscore form to CamelCase.
                URLs replace underscores with /.
        """
        self.model = model
        self.zh_name = zh_name
        # Path of the model file
        self.model_file_path = Path(inspect.getfile(sys.modules[model.__module__]))
        # App directory containing the model file
        self.app_dir_path = self.model_file_path.parent.parent
        # schemas directory
        self.schemas_dir_path = self.app_dir_path / "schemas"
        # params directory
        self.params_dir_path = self.app_dir_path / "params"
        # crud file path
        self.crud_file_path = self.app_dir_path / "crud.py"
        # view file path
        self.view_file_path = self.app_dir_path / "views.py"

        if en_name:
            self.en_name = en_name
        else:
            self.en_name = self.model.__name__

        self.schema_file_path = self.schemas_dir_path / f"{self.en_name}.py"
        self.param_file_path = self.params_dir_path / f"{self.en_name}.py"

        self.base_class_name = self.snake_to_camel(self.en_name)
        self.schema_simple_out_class_name = f"{self.base_class_name}SimpleOut"
        self.dal_class_name = f"{self.base_class_name}Dal"
        self.param_class_name = f"{self.base_class_name}Params"

    def generate_codes(self):
        """
        Generate the code without touching the project; only print it out.
        :return:
        """
        print(f"==========================={self.schema_file_path} code=================================")
        schema = SchemaGenerate(
            self.model,
            self.zh_name,
            self.en_name,
            self.schema_file_path,
            self.schemas_dir_path,
            self.base_class_name,
            self.schema_simple_out_class_name
        )
        print(schema.generate_code())

        print(f"==========================={self.dal_class_name} code=================================")
        dal = DalGenerate(
            self.model,
            self.zh_name,
            self.en_name,
            self.dal_class_name,
            self.schema_simple_out_class_name
        )
        print(dal.generate_code())

        print(f"==========================={self.param_file_path} code=================================")
        params = ParamsGenerate(
            self.model,
            self.zh_name,
            self.en_name,
            self.params_dir_path,
            self.param_file_path,
            self.param_class_name
        )
        print(params.generate_code())

        print(f"==========================={self.view_file_path} code=================================")
        view = ViewGenerate(
            self.model,
            self.zh_name,
            self.en_name,
            self.base_class_name,
            self.schema_simple_out_class_name,
            self.dal_class_name,
            self.param_class_name
        )
        print(view.generate_code())

    def main(self):
        """
        Generate the crud code and write it directly into the project (not finished yet):
        1. generate schemas code
        2. generate dal code
        3. generate params code
        4. generate views code
        :return:
        """
        schema = SchemaGenerate(
            self.model,
            self.zh_name,
            self.en_name,
            self.schema_file_path,
            self.schemas_dir_path,
            self.base_class_name,
            self.schema_simple_out_class_name
        )
        schema.write_generate_code()

        dal = DalGenerate(
            self.model,
            self.zh_name,
            self.en_name,
            self.dal_class_name,
            self.schema_simple_out_class_name
        )
        dal.write_generate_code()

        params = ParamsGenerate(
            self.model,
            self.zh_name,
            self.en_name,
            self.params_dir_path,
            self.param_file_path,
            self.param_class_name
        )
        params.write_generate_code()

        view = ViewGenerate(
            self.model,
            self.zh_name,
            self.en_name,
            self.base_class_name,
            self.schema_simple_out_class_name,
            self.dal_class_name,
            self.param_class_name
        )
        view.write_generate_code()
@@ -1,106 +0,0 @@
import inspect
import sys
from pathlib import Path
from typing import Type
from core.database import Base
from .generate_base import GenerateBase


class DalGenerate(GenerateBase):

    def __init__(
            self,
            model: Type[Base],
            zh_name: str,
            en_name: str,
            dal_class_name: str,
            schema_simple_out_class_name: str
    ):
        """
        Initialization.
        :param model: a predefined ORM model
        :param zh_name: Chinese feature name, used for descriptions and comments
        :param en_name: English feature name, used for schema/param file names, their class names,
            and dal/url names; defaults to the model class name.
            en_name notes:
                If en_name consists of several words, join them with underscores.
                File names use the underscore form.
                Class names convert the underscore form to CamelCase.
                URLs replace underscores with /.
        :param dal_class_name:
        :param schema_simple_out_class_name:
        """
        self.model = model
        self.dal_class_name = dal_class_name
        self.schema_simple_out_class_name = schema_simple_out_class_name
        self.zh_name = zh_name
        self.en_name = en_name
        # Path of the model file
        self.model_file_path = Path(inspect.getfile(sys.modules[model.__module__]))
        # App directory containing the model file
        self.app_dir_path = self.model_file_path.parent.parent
        # crud file path
        self.crud_file_path = self.app_dir_path / "crud.py"

    def write_generate_code(self):
        """
        Create the crud file and its code content.
        :return:
        """
        if self.crud_file_path.exists():
            codes = self.file_code_split_module(self.crud_file_path)
            if codes:
                print("==========the dal file already exists with content; appending the new code============")
                if not codes[0]:
                    # Add a file header comment when none exists
                    codes[0] = self.generate_file_desc(self.crud_file_path.name, "1.0", "data access layer")
                codes[1] = self.merge_dictionaries(codes[1], self.get_base_module_config())
                codes[2] += self.get_base_code_content()
                code = ''
                code += codes[0]
                code += self.generate_modules_code(codes[1])
                code += codes[2]
                self.crud_file_path.write_text(code, "utf-8")
                print("=================dal code created=======================")
                return
        self.crud_file_path.touch()
        code = self.generate_code()
        self.crud_file_path.write_text(code, "utf-8")
        print("===========================dal code created=================================")

    def generate_code(self):
        """
        Code generation.
        :return:
        """
        code = self.generate_file_desc(self.crud_file_path.name, "1.0", "data access layer")
        code += self.generate_modules_code(self.get_base_module_config())
        code += self.get_base_code_content()
        return code

    @staticmethod
    def get_base_module_config():
        """
        Base module import configuration.
        :return:
        """
        modules = {
            "sqlalchemy.ext.asyncio": ['AsyncSession'],
            "core.crud": ["DalBase"],
            ".": ["models", "schemas"],
        }
        return modules

    def get_base_code_content(self):
        """
        Base code content.
        :return:
        """
        base_code = f"\n\nclass {self.dal_class_name}(DalBase):\n"
        base_code += "\n\tdef __init__(self, db: AsyncSession):"
        base_code += f"\n\t\tsuper({self.dal_class_name}, self).__init__()"
        base_code += f"\n\t\tself.db = db"
        base_code += f"\n\t\tself.model = models.{self.model.__name__}"
        base_code += f"\n\t\tself.schema = schemas.{self.schema_simple_out_class_name}"
        base_code += "\n"
        return base_code.replace("\t", "    ")
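To make the template concrete: for a hypothetical model DemoRecord, get_base_code_content() emits a class of this shape (expanded from the f-strings above):

    class DemoRecordDal(DalBase):

        def __init__(self, db: AsyncSession):
            super(DemoRecordDal, self).__init__()
            self.db = db
            self.model = models.DemoRecord
            self.schema = schemas.DemoRecordSimpleOut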
@ -1,185 +0,0 @@
import datetime
import re
from pathlib import Path


class GenerateBase:

    @staticmethod
    def camel_to_snake(name: str) -> str:
        """
        Convert CamelCase to snake_case:
        insert an underscore before each capitalized group, then lowercase the result.
        :param name: a CamelCase name
        :return: the snake_case equivalent
        """
        s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
        return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()

    @staticmethod
    def snake_to_camel(name: str) -> str:
        """
        Convert snake_case to CamelCase:
        split on underscores, capitalize each word, and join.
        :param name: a snake_case name
        :return: the CamelCase equivalent
        """
        # Split on underscores
        words = name.split('_')
        # Capitalize each word, then join
        return ''.join(word.capitalize() for word in words)

    @staticmethod
    def generate_file_desc(filename: str, version: str = '1.0', desc: str = '') -> str:
        """
        Generate the file header comment.
        :param filename:
        :param version:
        :param desc:
        :return:
        """
        code = '#!/usr/bin/python\n# -*- coding: utf-8 -*-'
        code += f"\n# @version : {version}"
        code += f"\n# @Create Time : {datetime.datetime.now().strftime('%Y/%m/%d %H:%M')}"
        code += f"\n# @File : {filename}"
        code += "\n# @IDE : PyCharm"
        code += f"\n# @desc : {desc}"
        code += "\n"
        return code

    @staticmethod
    def generate_modules_code(modules: dict[str, list]) -> str:
        """
        Generate import statements.
        :param modules: the modules to import
        :return:
        """
        code = "\n"
        args = modules.pop("args", [])
        for k, v in modules.items():
            code += f"from {k} import {', '.join(v)}\n"
        if args:
            code += f"import {', '.join(args)}\n"
        return code

    @staticmethod
    def update_init_file(init_file: Path, code: str):
        """
        Append an import line to an __init__ file, unless it is already present.
        :param init_file:
        :param code:
        :return:
        """
        content = init_file.read_text()
        if content and code in content:
            return
        if content:
            if content.endswith("\n"):
                with init_file.open("a+", encoding="utf-8") as f:
                    f.write(f"{code}\n")
            else:
                with init_file.open("a+", encoding="utf-8") as f:
                    f.write(f"\n{code}\n")
        else:
            init_file.write_text(f"{code}\n", encoding="utf-8")

    @staticmethod
    def module_code_to_dict(code: str) -> dict:
        """
        Parse 'from ... import ...' / 'import ...' statements into dict form.
        :param code:
        :return:
        """
        # Split the code into lines
        lines = code.strip().split('\n')

        # Result dict
        modules = {}

        # Walk each line
        for line in lines:
            # Handle 'from ... import ...' statements
            if line.startswith('from'):
                parts = line.split(' import ')
                module = parts[0][5:]  # strip 'from ' to get the module path
                imports = parts[1].split(',')  # split the imported names on commas
                imports = [item.strip() for item in imports]  # trim whitespace
                if module in modules:
                    modules[module].extend(imports)
                else:
                    modules[module] = imports

            # Handle plain 'import ...' statements
            elif line.startswith('import'):
                imports = line.split('import ')[1]
                # Split multiple imported names
                imports = imports.split(', ')
                for imp in imports:
                    # Directly imported modules are collected under 'args'
                    modules.setdefault('args', []).append(imp)
        return modules

    @classmethod
    def file_code_split_module(cls, file: Path) -> list:
        """
        Split a file's source into three parts:
        1. the header comments at the top of the file;
        2. the top-level from/import statements, converted to dict form;
        3. everything else.
        :param file:
        :return:
        """
        content = file.read_text(encoding="utf-8")
        if not content:
            return []
        lines = content.split('\n')
        part1 = []  # header comments
        part2 = []  # from/import statements
        part3 = []  # remaining code

        # Whether we are past the header comment block
        past_comments = False

        for line in lines:
            # Still inside the header comment block?
            if line.startswith("#") and not past_comments:
                part1.append(line)
            else:
                # The header comment block is over
                past_comments = True
                # Import statement?
                if line.startswith("from ") or line.startswith("import "):
                    part2.append(line)
                else:
                    part3.append(line)

        part2 = cls.module_code_to_dict('\n'.join(part2))

        return ['\n'.join(part1), part2, '\n'.join(part3)]

    @staticmethod
    def merge_dictionaries(dict1, dict2):
        """
        Merge two dicts whose keys are strings and whose values are lists.
        :param dict1:
        :param dict2:
        :return:
        """
        # Result dict
        merged_dict = {}

        # Take the union of the two key sets, de-duplicating each merged value list
        for key in set(dict1) | set(dict2):
            merged_dict[key] = list(set(dict1.get(key, []) + dict2.get(key, [])))

        return merged_dict


if __name__ == '__main__':
    _modules = {
        "sqlalchemy.ext.asyncio": ['AsyncSession'],
        "core.crud": ["DalBase"],
        ".": ["models", "schemas"],
        "args": ["test", "test1"]
    }
    print(GenerateBase.generate_modules_code(_modules))
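
A short sketch of what these helpers produce (the inputs are illustrative; expected outputs are shown in comments):

print(GenerateBase.camel_to_snake("ProjectDetectFile"))    # project_detect_file
print(GenerateBase.snake_to_camel("project_detect_file"))  # ProjectDetectFile
code = "from core.crud import DalBase\nimport os, re"
print(GenerateBase.module_code_to_dict(code))
# {'core.crud': ['DalBase'], 'args': ['os', 're']}
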
@ -1,82 +0,0 @@
import inspect
import sys
from pathlib import Path
from typing import Type
from core.database import Base
from .generate_base import GenerateBase


class ParamsGenerate(GenerateBase):

    def __init__(
            self,
            model: Type[Base],
            zh_name: str,
            en_name: str,
            params_dir_path: Path,
            param_file_path: Path,
            param_class_name: str
    ):
        """
        Set-up.
        :param model: the pre-defined ORM model
        :param zh_name: Chinese name of the feature, used in descriptions and comments
        :param param_class_name:
        :param param_file_path:
        :param params_dir_path:
        :param en_name: English name of the feature, used to name the param file and its class as well as the dal and url;
                        defaults to the model class name.
            Conventions for en_name:
            if en_name consists of several words, join them with underscores;
            file names use the underscore form;
            class names convert the underscore form to CamelCase;
            urls replace underscores with /.
        """
        self.model = model
        self.param_class_name = param_class_name
        self.zh_name = zh_name
        self.en_name = en_name
        # path of the model file
        self.model_file_path = Path(inspect.getfile(sys.modules[model.__module__]))
        # app directory containing the model file
        self.app_dir_path = self.model_file_path.parent.parent
        # params directory
        self.params_dir_path = params_dir_path
        self.param_file_path = param_file_path

    def write_generate_code(self):
        """
        Create the params file and write the generated code into it.
        :return:
        """
        param_init_file_path = self.params_dir_path / "__init__.py"
        self.param_file_path.parent.mkdir(parents=True, exist_ok=True)
        if self.param_file_path.exists():
            self.param_file_path.unlink()
        self.param_file_path.touch()
        param_init_file_path.touch()

        code = self.generate_code()
        self.param_file_path.write_text(code, "utf-8")
        init_code = f"from .{self.en_name} import {self.param_class_name}"
        self.update_init_file(param_init_file_path, init_code)
        print("===========================param 代码创建完成=================================")

    def generate_code(self) -> str:
        """
        Generate the params code.
        :return:
        """
        code = self.generate_file_desc(self.param_file_path.name, "1.0", self.zh_name)

        modules = {
            "fastapi": ['Depends'],
            "core.dependencies": ['Paging', "QueryParams"],
        }
        code += self.generate_modules_code(modules)

        base_code = f"\n\nclass {self.param_class_name}(QueryParams):"
        base_code += "\n\tdef __init__(self, params: Paging = Depends()):"
        base_code += "\n\t\tsuper().__init__(params)"
        base_code += "\n"
        code += base_code
        return code.replace("\t", "    ")
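
For a hypothetical en_name of "project_detect" and param_class_name of "ProjectDetectParams" (illustrative names), the generated params file would look roughly like:

from fastapi import Depends
from core.dependencies import Paging, QueryParams


class ProjectDetectParams(QueryParams):
    def __init__(self, params: Paging = Depends()):
        super().__init__(params)
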
@ -1,11 +0,0 @@
from typing import Any
from pydantic import BaseModel, Field


class SchemaField(BaseModel):
    name: str = Field(..., title="字段名称")
    field_type: str = Field(..., title="字段类型")
    nullable: bool = Field(False, title="是否可以为空")
    default: Any = Field(None, title="默认值")
    title: str | None = Field(None, title="字段描述")
    max_length: int | None = Field(None, title="最大长度")
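
A quick sketch of this intermediate field model in use (the values are illustrative, not from this commit):

item = SchemaField(name="detect_name", field_type="str", nullable=False, title="推理集合名称", max_length=64)
print(item.model_dump())
# {'name': 'detect_name', 'field_type': 'str', 'nullable': False, 'default': None, 'title': '推理集合名称', 'max_length': 64}
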
@ -1,143 +0,0 @@
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2024/1/12 17:28
# @File : schema_generate.py
# @IDE : PyCharm
# @desc : schema 代码生成


import sys
from typing import Type
import inspect
from sqlalchemy import inspect as model_inspect
from pathlib import Path
from core.database import Base
from scripts.crud_generate.utils.schema import SchemaField
from sqlalchemy.sql.schema import Column as ColumnType
from scripts.crud_generate.utils.generate_base import GenerateBase


class SchemaGenerate(GenerateBase):

    BASE_FIELDS = ["id", "create_datetime", "update_datetime", "delete_datetime", "is_delete"]

    def __init__(
            self,
            model: Type[Base],
            zh_name: str,
            en_name: str,
            schema_file_path: Path,
            schemas_dir_path: Path,
            base_class_name: str,
            schema_simple_out_class_name: str
    ):
        """
        Set-up.
        :param model: the pre-defined ORM model
        :param zh_name: Chinese name of the feature, used in descriptions and comments
        :param schema_file_path:
        :param en_name: English name of the feature, used to name the schema/param files and their classes as well as the dal and url;
                        defaults to the model class name.
            Conventions for en_name:
            if en_name consists of several words, join them with underscores;
            file names use the underscore form;
            class names convert the underscore form to CamelCase;
            urls replace underscores with /.
        :param base_class_name:
        :param schema_simple_out_class_name:
        """
        self.model = model
        self.base_class_name = base_class_name
        self.schema_simple_out_class_name = schema_simple_out_class_name
        self.zh_name = zh_name
        # path of the model file
        self.model_file_path = Path(inspect.getfile(sys.modules[model.__module__]))
        # app directory containing the model file
        self.app_dir_path = self.model_file_path.parent.parent
        self.en_name = en_name
        self.schema_file_path = schema_file_path
        self.schemas_dir_path = schemas_dir_path
        self.schema_init_file_path = self.schemas_dir_path / "__init__.py"

    def write_generate_code(self):
        """
        Create the schema file and write the generated code into it.
        :return:
        """
        self.schema_file_path.parent.mkdir(parents=True, exist_ok=True)
        if self.schema_file_path.exists():
            # If the file already exists, delete it and recreate it from scratch
            self.schema_file_path.unlink()
        self.schema_file_path.touch()
        self.schema_init_file_path.touch()

        code = self.generate_code()
        self.schema_file_path.write_text(code, "utf-8")

        init_code = self.generate_init_code()
        self.update_init_file(self.schema_init_file_path, init_code)
        print("===========================schema 代码创建完成=================================")

    def generate_init_code(self):
        """
        Generate the import line for the __init__ file.
        todo: should return an empty string if the imported classes already exist
        :return:
        """
        init_code = f"from .{self.en_name} import {self.base_class_name}, {self.schema_simple_out_class_name}"
        return init_code

    def generate_code(self) -> str:
        """
        Generate the schema code.
        :return:
        """
        fields = []
        mapper = model_inspect(self.model)
        for attr_name, column_property in mapper.column_attrs.items():
            if attr_name in self.BASE_FIELDS:
                continue
            # Assume a single-column attribute
            column: ColumnType = column_property.columns[0]
            item = SchemaField(
                name=attr_name,
                field_type=column.type.python_type.__name__,
                nullable=column.nullable,
                default=column.default.__dict__.get("arg", None) if column.default else None,
                title=column.comment,
                max_length=column.type.__dict__.get("length", None)
            )
            fields.append(item)

        code = self.generate_file_desc(self.schema_file_path.name, "1.0", "pydantic 模型,用于数据库序列化操作")

        modules = {
            "pydantic": ['BaseModel', "Field", "ConfigDict"],
            "core.data_types": ['DatetimeStr'],
        }
        code += self.generate_modules_code(modules)

        base_schema_code = f"\n\nclass {self.base_class_name}(BaseModel):"
        for item in fields:
            field = f"\n\t{item.name}: {item.field_type} {'| None ' if item.nullable else ''}"
            default = None
            if item.default is not None:
                if item.field_type == "str":
                    default = f"\"{item.default}\""
                else:
                    default = item.default
            elif default is None and not item.nullable:
                default = "..."

            field += f"= Field({default}, title=\"{item.title}\")"
            base_schema_code += field
        base_schema_code += "\n"
        code += base_schema_code

        base_out_schema_code = f"\n\nclass {self.schema_simple_out_class_name}({self.base_class_name}):"
        base_out_schema_code += "\n\tmodel_config = ConfigDict(from_attributes=True)\n"
        base_out_schema_code += "\n\tid: int = Field(..., title=\"编号\")"
        base_out_schema_code += "\n\tcreate_datetime: DatetimeStr = Field(..., title=\"创建时间\")"
        base_out_schema_code += "\n\tupdate_datetime: DatetimeStr = Field(..., title=\"更新时间\")"
        base_out_schema_code += "\n"
        code += base_out_schema_code
        return code.replace("\t", "    ")
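
Given a hypothetical ORM column such as detect_name = Column(String(64), nullable=False, comment="推理集合名称"), the generator above would emit roughly the following (class names illustrative):

class ProjectDetect(BaseModel):
    detect_name: str = Field(..., title="推理集合名称")


class ProjectDetectSimpleOut(ProjectDetect):
    model_config = ConfigDict(from_attributes=True)

    id: int = Field(..., title="编号")
    create_datetime: DatetimeStr = Field(..., title="创建时间")
    update_datetime: DatetimeStr = Field(..., title="更新时间")
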
@ -1,143 +0,0 @@
import inspect
import sys
from pathlib import Path
from typing import Type
from core.database import Base
from .generate_base import GenerateBase


class ViewGenerate(GenerateBase):

    def __init__(
            self,
            model: Type[Base],
            zh_name: str,
            en_name: str,
            schema_class_name: str,
            schema_simple_out_class_name: str,
            dal_class_name: str,
            param_class_name: str
    ):
        """
        Set-up.
        :param model: the pre-defined ORM model
        :param zh_name: Chinese name of the feature, used in descriptions and comments
        :param schema_class_name:
        :param schema_simple_out_class_name:
        :param dal_class_name:
        :param param_class_name:
        :param en_name: English name of the feature, used to name the schema/param files and their classes as well as the dal and url;
                        defaults to the model class name.
            Conventions for en_name:
            if en_name consists of several words, join them with underscores;
            file names use the underscore form;
            class names convert the underscore form to CamelCase;
            urls replace underscores with /.
        """
        self.model = model
        self.schema_class_name = schema_class_name
        self.schema_simple_out_class_name = schema_simple_out_class_name
        self.dal_class_name = dal_class_name
        self.param_class_name = param_class_name
        self.zh_name = zh_name
        self.en_name = en_name
        # path of the model file
        self.model_file_path = Path(inspect.getfile(sys.modules[model.__module__]))
        # app directory containing the model file
        self.app_dir_path = self.model_file_path.parent.parent
        # views file path
        self.view_file_path = self.app_dir_path / "views.py"

    def write_generate_code(self):
        """
        Create the view file and write the generated code into it.
        :return:
        """
        if self.view_file_path.exists():
            codes = self.file_code_split_module(self.view_file_path)
            if codes:
                print("==========view 文件已存在并已有代码内容,正在追加新代码============")
                if not codes[0]:
                    # Add a file header comment if there is none
                    codes[0] = self.generate_file_desc(self.view_file_path.name, "1.0", "视图层")
                codes[1] = self.merge_dictionaries(codes[1], self.get_base_module_config())
                codes[2] += self.get_base_code_content()
                code = ''
                code += codes[0]
                code += self.generate_modules_code(codes[1])
                if "app = APIRouter()" not in codes[2]:
                    code += "\n\napp = APIRouter()"
                code += codes[2]
                self.view_file_path.write_text(code, "utf-8")
                print("=================view 代码已创建完成=====================")
                return
        else:
            self.view_file_path.touch()
        code = self.generate_code()
        self.view_file_path.write_text(code, encoding="utf-8")
        print("===============view 代码创建完成==================")

    def generate_code(self) -> str:
        """
        Generate the view code.
        :return:
        """
        code = self.generate_file_desc(self.view_file_path.name, "1.0", "路由,视图文件")
        code += self.generate_modules_code(self.get_base_module_config())
        code += "\n\napp = APIRouter()"
        code += self.get_base_code_content()

        return code.replace("\t", "    ")

    @staticmethod
    def get_base_module_config():
        """
        Base import configuration for the generated view module.
        :return:
        """
        modules = {
            "sqlalchemy.ext.asyncio": ['AsyncSession'],
            "fastapi": ["APIRouter", "Depends"],
            ".": ["models", "schemas", "crud", "params"],
            "core.dependencies": ["IdList"],
            "apps.vadmin.auth.utils.current": ["AllUserAuth"],
            "utils.response": ["SuccessResponse"],
            "apps.vadmin.auth.utils.validation.auth": ["Auth"],
            "core.database": ["db_getter"],
        }
        return modules

    def get_base_code_content(self):
        """
        Base route handlers for the generated view module.
        :return:
        """
        base_code = "\n\n\n###########################################################"
        base_code += f"\n# {self.zh_name}"
        base_code += "\n###########################################################"

        router = self.en_name.replace("_", "/")

        base_code += f"\n@app.get(\"/{router}\", summary=\"获取{self.zh_name}列表\", tags=[\"{self.zh_name}\"])"
        base_code += f"\nasync def get_{self.en_name}_list(p: params.{self.param_class_name} = Depends(), auth: Auth = Depends(AllUserAuth())):"
        base_code += f"\n\tdatas, count = await crud.{self.dal_class_name}(auth.db).get_datas(**p.dict(), v_return_count=True)"
        base_code += "\n\treturn SuccessResponse(datas, count=count)\n"

        base_code += f"\n\n@app.post(\"/{router}\", summary=\"创建{self.zh_name}\", tags=[\"{self.zh_name}\"])"
        base_code += f"\nasync def create_{self.en_name}(data: schemas.{self.schema_class_name}, auth: Auth = Depends(AllUserAuth())):"
        base_code += f"\n\treturn SuccessResponse(await crud.{self.dal_class_name}(auth.db).create_data(data=data))\n"

        base_code += f"\n\n@app.delete(\"/{router}\", summary=\"删除{self.zh_name}\", description=\"硬删除\", tags=[\"{self.zh_name}\"])"
        base_code += f"\nasync def delete_{self.en_name}_list(ids: IdList = Depends(), auth: Auth = Depends(AllUserAuth())):"
        base_code += f"\n\tawait crud.{self.dal_class_name}(auth.db).delete_datas(ids=ids.ids, v_soft=False)"
        base_code += "\n\treturn SuccessResponse(\"删除成功\")\n"

        base_code += f"\n\n@app.put(\"/{router}" + "/{data_id}\"" + f", summary=\"更新{self.zh_name}\", tags=[\"{self.zh_name}\"])"
        base_code += f"\nasync def put_{self.en_name}(data_id: int, data: schemas.{self.schema_class_name}, auth: Auth = Depends(AllUserAuth())):"
        base_code += f"\n\treturn SuccessResponse(await crud.{self.dal_class_name}(auth.db).put_data(data_id, data))\n"

        base_code += f"\n\n@app.get(\"/{router}" + "/{data_id}\"" + f", summary=\"获取{self.zh_name}信息\", tags=[\"{self.zh_name}\"])"
        base_code += f"\nasync def get_{self.en_name}(data_id: int, db: AsyncSession = Depends(db_getter)):"
        base_code += f"\n\tschema = schemas.{self.schema_simple_out_class_name}"
        base_code += f"\n\treturn SuccessResponse(await crud.{self.dal_class_name}(db).get_data(data_id, v_schema=schema))\n"
        base_code += "\n"
        return base_code.replace("\t", "    ")
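
A rough sketch of one route this template expands to, assuming hypothetical names en_name="project_detect", zh_name="推理管理", dal_class_name="ProjectDetectDal", and param_class_name="ProjectDetectParams":

@app.get("/project/detect", summary="获取推理管理列表", tags=["推理管理"])
async def get_project_detect_list(p: params.ProjectDetectParams = Depends(), auth: Auth = Depends(AllUserAuth())):
    datas, count = await crud.ProjectDetectDal(auth.db).get_datas(**p.dict(), v_return_count=True)
    return SuccessResponse(datas, count=count)
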
@ -1,7 +0,0 @@
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2021/10/19 15:47
# @File : initialize.py
# @IDE : PyCharm
# @desc : 初始化数据
Binary file not shown.
@ -1,180 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version : 1.0
# @Create Time : 2022/11/23 11:21
# @File : initialize.py
# @IDE : PyCharm
# @desc : 简要说明

from enum import Enum
from sqlalchemy import insert
from core.database import db_getter
from utils.excel.excel_manage import ExcelManage
from application.settings import BASE_DIR, VERSION
import os
from apps.vadmin.auth import models as auth_models
from apps.vadmin.system import models as system_models
from apps.vadmin.help import models as help_models
import subprocess


class Environment(str, Enum):
    dev = "dev"
    pro = "pro"


class InitializeData:
    """
    Initialize data.

    Steps:
    1. read the data;
    2. get a database session;
    3. create the records.
    """

    SCRIPT_DIR = os.path.join(BASE_DIR, 'scripts', 'initialize')

    def __init__(self):
        self.sheet_names = []
        self.datas = {}
        self.ex = None
        self.db = None
        self.__serializer_data()
        self.__get_sheet_data()

    @classmethod
    def migrate_model(cls, env: Environment = Environment.pro):
        """
        Migrate the models to the database via alembic.
        """
        subprocess.check_call(['alembic', '--name', f'{env.value}', 'revision', '--autogenerate', '-m', f'{VERSION}'], cwd=BASE_DIR)
        subprocess.check_call(['alembic', '--name', f'{env.value}', 'upgrade', 'head'], cwd=BASE_DIR)
        print(f"环境:{env} {VERSION} 数据库表迁移完成")

    def __serializer_data(self):
        """
        Deserialize the excel workbook into python objects.
        """
        self.ex = ExcelManage()
        self.ex.open_workbook(os.path.join(self.SCRIPT_DIR, 'data', 'init.xlsx'), read_only=True)
        self.sheet_names = self.ex.get_sheets()

    def __get_sheet_data(self):
        """
        Read the rows of every sheet.
        """
        for sheet in self.sheet_names:
            sheet_data = []
            self.ex.open_sheet(sheet)
            headers = self.ex.get_header()
            datas = self.ex.readlines(min_row=2, max_col=len(headers))
            for row in datas:
                sheet_data.append(dict(zip(headers, row)))
            self.datas[sheet] = sheet_data

    async def __generate_data(self, table_name: str, model):
        """
        Insert the rows of one sheet into its table.

        :param table_name: table name
        :param model: table model
        """
        async_session = db_getter()
        db = await async_session.__anext__()
        datas = self.datas.get(table_name)
        await db.execute(insert(model), datas)
        await db.flush()
        await db.commit()
        print(f"{table_name} 表数据已生成")

    async def generate_dept(self):
        """
        Generate department data.
        """
        await self.__generate_data("vadmin_auth_dept", auth_models.VadminDept)

    async def generate_user_dept(self):
        """
        Generate user-department association data.
        """
        await self.__generate_data("vadmin_auth_user_depts", auth_models.vadmin_auth_user_depts)

    async def generate_menu(self):
        """
        Generate menu data.
        """
        await self.__generate_data("vadmin_auth_menu", auth_models.VadminMenu)

    async def generate_role(self):
        """
        Generate role data.
        """
        await self.__generate_data("vadmin_auth_role", auth_models.VadminRole)

    async def generate_user(self):
        """
        Generate user data.
        """
        await self.__generate_data("vadmin_auth_user", auth_models.VadminUser)

    async def generate_user_role(self):
        """
        Generate user-role association data.
        """
        await self.__generate_data("vadmin_auth_user_roles", auth_models.vadmin_auth_user_roles)

    async def generate_system_tab(self):
        """
        Generate system settings tab data.
        """
        await self.__generate_data("vadmin_system_settings_tab", system_models.VadminSystemSettingsTab)

    async def generate_system_config(self):
        """
        Generate system settings data.
        """
        await self.__generate_data("vadmin_system_settings", system_models.VadminSystemSettings)

    async def generate_dict_type(self):
        """
        Generate dict type data.
        """
        await self.__generate_data("vadmin_system_dict_type", system_models.VadminDictType)

    async def generate_dict_details(self):
        """
        Generate dict details data.
        """
        await self.__generate_data("vadmin_system_dict_details", system_models.VadminDictDetails)

    async def generate_help_issue_category(self):
        """
        Generate FAQ category data.
        """
        await self.__generate_data("vadmin_help_issue_category", help_models.VadminIssueCategory)

    async def generate_help_issue(self):
        """
        Generate FAQ data.
        """
        await self.__generate_data("vadmin_help_issue", help_models.VadminIssue)

    async def run(self, env: Environment = Environment.pro):
        """
        Run the full initialization.
        """
        self.migrate_model(env)
        await self.generate_menu()
        await self.generate_role()
        await self.generate_dept()
        await self.generate_user()
        await self.generate_user_dept()
        await self.generate_user_role()
        await self.generate_system_tab()
        await self.generate_dict_type()
        await self.generate_system_config()
        await self.generate_dict_details()
        await self.generate_help_issue_category()
        await self.generate_help_issue()
        print(f"环境:{env} {VERSION} 数据已初始化完成")
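
A minimal usage sketch for this class (assumes it is run from the project root with the alembic configs and init.xlsx in place):

import asyncio

if __name__ == "__main__":
    asyncio.run(InitializeData().run(Environment.dev))
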
(6 file diffs suppressed because one or more lines are too long; 3 binary image files not shown: 1.1 KiB, 256 KiB, 20 KiB)
1149  utils/yolov5/models/common.py  Normal file
File diff suppressed because it is too large
130  utils/yolov5/models/experimental.py  Normal file
@ -0,0 +1,130 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Experimental modules."""

import math

import numpy as np
import torch
import torch.nn as nn

from app.util.yolov5.utils.downloads import attempt_download


class Sum(nn.Module):
    """Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070."""

    def __init__(self, n, weight=False):
        """Initializes a module to sum outputs of layers with number of inputs `n` and optional weighting, supporting 2+
        inputs.
        """
        super().__init__()
        self.weight = weight  # apply weights boolean
        self.iter = range(n - 1)  # iter object
        if weight:
            self.w = nn.Parameter(-torch.arange(1.0, n) / 2, requires_grad=True)  # layer weights

    def forward(self, x):
        """Processes input through a customizable weighted sum of `n` inputs, optionally applying learned weights."""
        y = x[0]  # no weight
        if self.weight:
            w = torch.sigmoid(self.w) * 2
            for i in self.iter:
                y = y + x[i + 1] * w[i]
        else:
            for i in self.iter:
                y = y + x[i + 1]
        return y


class MixConv2d(nn.Module):
    """Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595."""

    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
        """Initializes MixConv2d with mixed depth-wise convolutional layers, taking input and output channels (c1, c2),
        kernel sizes (k), stride (s), and channel distribution strategy (equal_ch).
        """
        super().__init__()
        n = len(k)  # number of convolutions
        if equal_ch:  # equal c_ per group
            i = torch.linspace(0, n - 1e-6, c2).floor()  # c2 indices
            c_ = [(i == g).sum() for g in range(n)]  # intermediate channels
        else:  # equal weight.numel() per group
            b = [c2] + [0] * n
            a = np.eye(n + 1, n, k=-1)
            a -= np.roll(a, 1, axis=1)
            a *= np.array(k) ** 2
            a[0] = 1
            c_ = np.linalg.lstsq(a, b, rcond=None)[0].round()  # solve for equal weight indices, ax = b

        self.m = nn.ModuleList(
            [nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)]
        )
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU()

    def forward(self, x):
        """Performs forward pass by applying SiLU activation on batch-normalized concatenated convolutional layer
        outputs.
        """
        return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))


class Ensemble(nn.ModuleList):
    """Ensemble of models."""

    def __init__(self):
        """Initializes an ensemble of models to be used for aggregated predictions."""
        super().__init__()

    def forward(self, x, augment=False, profile=False, visualize=False):
        """Performs forward pass aggregating outputs from an ensemble of models."""
        y = [module(x, augment, profile, visualize)[0] for module in self]
        # y = torch.stack(y).max(0)[0]  # max ensemble
        # y = torch.stack(y).mean(0)  # mean ensemble
        y = torch.cat(y, 1)  # nms ensemble
        return y, None  # inference, train output


def attempt_load(weights, device=None, inplace=True, fuse=True):
    """
    Loads and fuses an ensemble or single YOLOv5 model from weights, handling device placement and model adjustments.

    Example inputs: weights=[a,b,c] or a single model weights=[a] or weights=a.
    """
    from app.util.yolov5.models.yolo import Detect, Model

    model = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
        ckpt = torch.load(attempt_download(w), map_location="cpu")  # load
        ckpt = (ckpt.get("ema") or ckpt["model"]).to(device).float()  # FP32 model

        # Model compatibility updates
        if not hasattr(ckpt, "stride"):
            ckpt.stride = torch.tensor([32.0])
        if hasattr(ckpt, "names") and isinstance(ckpt.names, (list, tuple)):
            ckpt.names = dict(enumerate(ckpt.names))  # convert to dict

        model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, "fuse") else ckpt.eval())  # model in eval mode

    # Module updates
    for m in model.modules():
        t = type(m)
        if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
            m.inplace = inplace
        if t is Detect and not isinstance(m.anchor_grid, list):
            delattr(m, "anchor_grid")
            setattr(m, "anchor_grid", [torch.zeros(1)] * m.nl)
        elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"):
            m.recompute_scale_factor = None  # torch 1.11.0 compatibility

    # Return model
    if len(model) == 1:
        return model[-1]

    # Return detection ensemble
    print(f"Ensemble created with {weights}\n")
    for k in "names", "nc", "yaml":
        setattr(model, k, getattr(model[0], k))
    model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride  # max stride
    assert all(model[0].nc == m.nc for m in model), f"Models have different class counts: {[m.nc for m in model]}"
    return model
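
A minimal usage sketch for attempt_load (the weight file names are illustrative):

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
single = attempt_load("yolov5s.pt", device=device)  # one checkpoint -> fused single model
ensemble = attempt_load(["yolov5s.pt", "yolov5m.pt"], device=device)  # list -> NMS ensemble
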
57  utils/yolov5/models/hub/anchors.yaml  Normal file
@ -0,0 +1,57 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Default anchors for COCO data

# P5 -------------------------------------------------------------------------------------------------------------------
# P5-640:
anchors_p5_640:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# P6 -------------------------------------------------------------------------------------------------------------------
# P6-640: thr=0.25: 0.9964 BPR, 5.54 anchors past thr, n=12, img_size=640, metric_all=0.281/0.716-mean/best, past_thr=0.469-mean: 9,11, 21,19, 17,41, 43,32, 39,70, 86,64, 65,131, 134,130, 120,265, 282,180, 247,354, 512,387
anchors_p6_640:
  - [9, 11, 21, 19, 17, 41] # P3/8
  - [43, 32, 39, 70, 86, 64] # P4/16
  - [65, 131, 134, 130, 120, 265] # P5/32
  - [282, 180, 247, 354, 512, 387] # P6/64

# P6-1280: thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1280, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 19,27, 44,40, 38,94, 96,68, 86,152, 180,137, 140,301, 303,264, 238,542, 436,615, 739,380, 925,792
anchors_p6_1280:
  - [19, 27, 44, 40, 38, 94] # P3/8
  - [96, 68, 86, 152, 180, 137] # P4/16
  - [140, 301, 303, 264, 238, 542] # P5/32
  - [436, 615, 739, 380, 925, 792] # P6/64

# P6-1920: thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1920, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 28,41, 67,59, 57,141, 144,103, 129,227, 270,205, 209,452, 455,396, 358,812, 653,922, 1109,570, 1387,1187
anchors_p6_1920:
  - [28, 41, 67, 59, 57, 141] # P3/8
  - [144, 103, 129, 227, 270, 205] # P4/16
  - [209, 452, 455, 396, 358, 812] # P5/32
  - [653, 922, 1109, 570, 1387, 1187] # P6/64

# P7 -------------------------------------------------------------------------------------------------------------------
# P7-640: thr=0.25: 0.9962 BPR, 6.76 anchors past thr, n=15, img_size=640, metric_all=0.275/0.733-mean/best, past_thr=0.466-mean: 11,11, 13,30, 29,20, 30,46, 61,38, 39,92, 78,80, 146,66, 79,163, 149,150, 321,143, 157,303, 257,402, 359,290, 524,372
anchors_p7_640:
  - [11, 11, 13, 30, 29, 20] # P3/8
  - [30, 46, 61, 38, 39, 92] # P4/16
  - [78, 80, 146, 66, 79, 163] # P5/32
  - [149, 150, 321, 143, 157, 303] # P6/64
  - [257, 402, 359, 290, 524, 372] # P7/128

# P7-1280: thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1280, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 19,22, 54,36, 32,77, 70,83, 138,71, 75,173, 165,159, 148,334, 375,151, 334,317, 251,626, 499,474, 750,326, 534,814, 1079,818
anchors_p7_1280:
  - [19, 22, 54, 36, 32, 77] # P3/8
  - [70, 83, 138, 71, 75, 173] # P4/16
  - [165, 159, 148, 334, 375, 151] # P5/32
  - [334, 317, 251, 626, 499, 474] # P6/64
  - [750, 326, 534, 814, 1079, 818] # P7/128

# P7-1920: thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1920, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 29,34, 81,55, 47,115, 105,124, 207,107, 113,259, 247,238, 222,500, 563,227, 501,476, 376,939, 749,711, 1126,489, 801,1222, 1618,1227
anchors_p7_1920:
  - [29, 34, 81, 55, 47, 115] # P3/8
  - [105, 124, 207, 107, 113, 259] # P4/16
  - [247, 238, 222, 500, 563, 227] # P5/32
  - [501, 476, 376, 939, 749, 711] # P6/64
  - [1126, 489, 801, 1222, 1618, 1227] # P7/128
52  utils/yolov5/models/hub/yolov3-spp.yaml  Normal file
@ -0,0 +1,52 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# darknet53 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [32, 3, 1]], # 0
    [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
    [-1, 1, Bottleneck, [64]],
    [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
    [-1, 2, Bottleneck, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 5-P3/8
    [-1, 8, Bottleneck, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 7-P4/16
    [-1, 8, Bottleneck, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
    [-1, 4, Bottleneck, [1024]], # 10
  ]

# YOLOv3-SPP head
head: [
    [-1, 1, Bottleneck, [1024, False]],
    [-1, 1, SPP, [512, [5, 9, 13]]],
    [-1, 1, Conv, [1024, 3, 1]],
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)

    [-2, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 8], 1, Concat, [1]], # cat backbone P4
    [-1, 1, Bottleneck, [512, False]],
    [-1, 1, Bottleneck, [512, False]],
    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)

    [-2, 1, Conv, [128, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P3
    [-1, 1, Bottleneck, [256, False]],
    [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)

    [[27, 22, 15], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
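
These model definitions are plain YAML, so they can be inspected without building the network. A small sketch (the path assumes the repository layout added by this commit, and PyYAML being installed):

import yaml

with open("utils/yolov5/models/hub/yolov3-spp.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print(cfg["nc"])             # 80 classes
print(len(cfg["backbone"]))  # 11 backbone entries
print(len(cfg["anchors"]))   # 3 detection levels (P3, P4, P5)
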
42  utils/yolov5/models/hub/yolov3-tiny.yaml  Normal file
@ -0,0 +1,42 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple
anchors:
  - [10, 14, 23, 27, 37, 58] # P4/16
  - [81, 82, 135, 169, 344, 319] # P5/32

# YOLOv3-tiny backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [16, 3, 1]], # 0
    [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 1-P1/2
    [-1, 1, Conv, [32, 3, 1]],
    [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 3-P2/4
    [-1, 1, Conv, [64, 3, 1]],
    [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 5-P3/8
    [-1, 1, Conv, [128, 3, 1]],
    [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 7-P4/16
    [-1, 1, Conv, [256, 3, 1]],
    [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 9-P5/32
    [-1, 1, Conv, [512, 3, 1]],
    [-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]], # 11
    [-1, 1, nn.MaxPool2d, [2, 1, 0]], # 12
  ]

# YOLOv3-tiny head
head: [
    [-1, 1, Conv, [1024, 3, 1]],
    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, Conv, [512, 3, 1]], # 15 (P5/32-large)

    [-2, 1, Conv, [128, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 8], 1, Concat, [1]], # cat backbone P4
    [-1, 1, Conv, [256, 3, 1]], # 19 (P4/16-medium)

    [[19, 15], 1, Detect, [nc, anchors]], # Detect(P4, P5)
  ]
52  utils/yolov5/models/hub/yolov3.yaml  Normal file
@ -0,0 +1,52 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# darknet53 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [32, 3, 1]], # 0
    [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
    [-1, 1, Bottleneck, [64]],
    [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
    [-1, 2, Bottleneck, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 5-P3/8
    [-1, 8, Bottleneck, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 7-P4/16
    [-1, 8, Bottleneck, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
    [-1, 4, Bottleneck, [1024]], # 10
  ]

# YOLOv3 head
head: [
    [-1, 1, Bottleneck, [1024, False]],
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, Conv, [1024, 3, 1]],
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)

    [-2, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 8], 1, Concat, [1]], # cat backbone P4
    [-1, 1, Bottleneck, [512, False]],
    [-1, 1, Bottleneck, [512, False]],
    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)

    [-2, 1, Conv, [128, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P3
    [-1, 1, Bottleneck, [256, False]],
    [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)

    [[27, 22, 15], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
49  utils/yolov5/models/hub/yolov5-bifpn.yaml  Normal file
@ -0,0 +1,49 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 BiFPN head
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14, 6], 1, Concat, [1]], # cat P4 <--- BiFPN change
    [-1, 3, C3, [512, False]], # 20 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 23 (P5/32-large)

    [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
43  utils/yolov5/models/hub/yolov5-fpn.yaml  Normal file
@ -0,0 +1,43 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 FPN head
head: [
    [-1, 3, C3, [1024, False]], # 10 (P5/32-large)

    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 3, C3, [512, False]], # 14 (P4/16-medium)

    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 1, Conv, [256, 1, 1]],
    [-1, 3, C3, [256, False]], # 18 (P3/8-small)

    [[18, 14, 10], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
55  utils/yolov5/models/hub/yolov5-p2.yaml  Normal file
@ -0,0 +1,55 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple
anchors: 3 # AutoAnchor evolves 3 anchors per P output layer

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head with (P2, P3, P4, P5) outputs
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)

    [-1, 1, Conv, [128, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 2], 1, Concat, [1]], # cat backbone P2
    [-1, 1, C3, [128, False]], # 21 (P2/4-xsmall)

    [-1, 1, Conv, [128, 3, 2]],
    [[-1, 18], 1, Concat, [1]], # cat head P3
    [-1, 3, C3, [256, False]], # 24 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 27 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 30 (P5/32-large)

    [[21, 24, 27, 30], 1, Detect, [nc, anchors]], # Detect(P2, P3, P4, P5)
  ]
42  utils/yolov5/models/hub/yolov5-p34.yaml  Normal file
@ -0,0 +1,42 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors: 3 # AutoAnchor evolves 3 anchors per P output layer

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head with (P3, P4) outputs
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 20 (P4/16-medium)

    [[17, 20], 1, Detect, [nc, anchors]], # Detect(P3, P4)
  ]
57  utils/yolov5/models/hub/yolov5-p6.yaml  Normal file
@ -0,0 +1,57 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple
anchors: 3 # AutoAnchor evolves 3 anchors per P output layer

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [768, 3, 2]], # 7-P5/32
    [-1, 3, C3, [768]],
    [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 11
  ]

# YOLOv5 v6.0 head with (P3, P4, P5, P6) outputs
head: [
    [-1, 1, Conv, [768, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 8], 1, Concat, [1]], # cat backbone P5
    [-1, 3, C3, [768, False]], # 15

    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 19

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 23 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 20], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 26 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 16], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [768, False]], # 29 (P5/32-large)

    [-1, 1, Conv, [768, 3, 2]],
    [[-1, 12], 1, Concat, [1]], # cat head P6
    [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge)

    [[23, 26, 29, 32], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5, P6)
  ]
68  utils/yolov5/models/hub/yolov5-p7.yaml  Normal file
@ -0,0 +1,68 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple
anchors: 3 # AutoAnchor evolves 3 anchors per P output layer

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [768, 3, 2]], # 7-P5/32
    [-1, 3, C3, [768]],
    [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64
    [-1, 3, C3, [1024]],
    [-1, 1, Conv, [1280, 3, 2]], # 11-P7/128
    [-1, 3, C3, [1280]],
    [-1, 1, SPPF, [1280, 5]], # 13
  ]

# YOLOv5 v6.0 head with (P3, P4, P5, P6, P7) outputs
head: [
    [-1, 1, Conv, [1024, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 10], 1, Concat, [1]], # cat backbone P6
    [-1, 3, C3, [1024, False]], # 17

    [-1, 1, Conv, [768, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 8], 1, Concat, [1]], # cat backbone P5
    [-1, 3, C3, [768, False]], # 21

    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 25

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 29 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 26], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 32 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 22], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [768, False]], # 35 (P5/32-large)

    [-1, 1, Conv, [768, 3, 2]],
    [[-1, 18], 1, Concat, [1]], # cat head P6
    [-1, 3, C3, [1024, False]], # 38 (P6/64-xlarge)

    [-1, 1, Conv, [1024, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P7
    [-1, 3, C3, [1280, False]], # 41 (P7/128-xxlarge)

    [[29, 32, 35, 38, 41], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5, P6, P7)
  ]
49  utils/yolov5/models/hub/yolov5-panet.yaml  Normal file
@ -0,0 +1,49 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 PANet head
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 20 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 23 (P5/32-large)

    [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
61  utils/yolov5/models/hub/yolov5l6.yaml  Normal file
@ -0,0 +1,61 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple
anchors:
  - [19, 27, 44, 40, 38, 94] # P3/8
  - [96, 68, 86, 152, 180, 137] # P4/16
  - [140, 301, 303, 264, 238, 542] # P5/32
  - [436, 615, 739, 380, 925, 792] # P6/64

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [768, 3, 2]], # 7-P5/32
    [-1, 3, C3, [768]],
    [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 11
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [768, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 8], 1, Concat, [1]], # cat backbone P5
    [-1, 3, C3, [768, False]], # 15

    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 19

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 23 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 20], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 26 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 16], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [768, False]], # 29 (P5/32-large)

    [-1, 1, Conv, [768, 3, 2]],
    [[-1, 12], 1, Concat, [1]], # cat head P6
    [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge)

    [[23, 26, 29, 32], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5, P6)
  ]
61  utils/yolov5/models/hub/yolov5m6.yaml  Normal file
@ -0,0 +1,61 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 0.67 # model depth multiple
width_multiple: 0.75 # layer channel multiple
anchors:
  - [19, 27, 44, 40, 38, 94] # P3/8
  - [96, 68, 86, 152, 180, 137] # P4/16
  - [140, 301, 303, 264, 238, 542] # P5/32
  - [436, 615, 739, 380, 925, 792] # P6/64

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [768, 3, 2]], # 7-P5/32
    [-1, 3, C3, [768]],
    [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 11
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [768, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 8], 1, Concat, [1]], # cat backbone P5
    [-1, 3, C3, [768, False]], # 15

    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 19

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 23 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 20], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 26 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 16], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [768, False]], # 29 (P5/32-large)

    [-1, 1, Conv, [768, 3, 2]],
    [[-1, 12], 1, Concat, [1]], # cat head P6
    [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge)

    [[23, 26, 29, 32], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5, P6)
  ]
61 utils/yolov5/models/hub/yolov5n6.yaml Normal file
@ -0,0 +1,61 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.25 # layer channel multiple
anchors:
  - [19, 27, 44, 40, 38, 94] # P3/8
  - [96, 68, 86, 152, 180, 137] # P4/16
  - [140, 301, 303, 264, 238, 542] # P5/32
  - [436, 615, 739, 380, 925, 792] # P6/64

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [768, 3, 2]], # 7-P5/32
    [-1, 3, C3, [768]],
    [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 11
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [768, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 8], 1, Concat, [1]], # cat backbone P5
    [-1, 3, C3, [768, False]], # 15

    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 19

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 23 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 20], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 26 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 16], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [768, False]], # 29 (P5/32-large)

    [-1, 1, Conv, [768, 3, 2]],
    [[-1, 12], 1, Concat, [1]], # cat head P6
    [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge)

    [[23, 26, 29, 32], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5, P6)
  ]
50 utils/yolov5/models/hub/yolov5s-LeakyReLU.yaml Normal file
@ -0,0 +1,50 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
activation: nn.LeakyReLU(0.1) # <----- Conv() activation used throughout entire YOLOv5 model
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 20 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 23 (P5/32-large)

    [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
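The activation: key above is the one place these configs override the default SiLU; every Conv in the model then uses nn.LeakyReLU(0.1). On the TensorFlow side that PyTorch module is mapped by the activations() helper in tf.py further down; a minimal sketch of the equivalent Keras call (example values only):

import tensorflow as tf
from tensorflow import keras

leaky = lambda x: keras.activations.relu(x, alpha=0.1)  # Keras equivalent of nn.LeakyReLU(0.1)
print(leaky(tf.constant([-1.0, 2.0])).numpy())          # [-0.1  2.]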
49 utils/yolov5/models/hub/yolov5s-ghost.yaml Normal file
@ -0,0 +1,49 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, GhostConv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3Ghost, [128]],
    [-1, 1, GhostConv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3Ghost, [256]],
    [-1, 1, GhostConv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3Ghost, [512]],
    [-1, 1, GhostConv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3Ghost, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, GhostConv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3Ghost, [512, False]], # 13

    [-1, 1, GhostConv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3Ghost, [256, False]], # 17 (P3/8-small)

    [-1, 1, GhostConv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3Ghost, [512, False]], # 20 (P4/16-medium)

    [-1, 1, GhostConv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3Ghost, [1024, False]], # 23 (P5/32-large)

    [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
49 utils/yolov5/models/hub/yolov5s-transformer.yaml Normal file
@ -0,0 +1,49 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3TR, [1024]], # 8 <--- C3TR() Transformer module
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 20 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 23 (P5/32-large)

    [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
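The -transformer variant above changes only layer 8: C3TR replaces C3 at the P5 scale, wrapping its bottlenecks around multi-head self-attention over the flattened feature map. A minimal sketch of that idea, mirroring the upstream shape handling but not its exact TransformerLayer code (channel count and head count are example values):

import torch
import torch.nn as nn

class TinySelfAttention(nn.Module):
    def __init__(self, c, num_heads=4):
        super().__init__()
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc = nn.Linear(c, c)

    def forward(self, x):                  # x: (B, C, H, W)
        b, c, h, w = x.shape
        s = x.flatten(2).permute(2, 0, 1)  # -> (H*W, B, C) token sequence
        s = self.ma(s, s, s)[0] + s        # self-attention with residual
        s = self.fc(s) + s                 # position-wise projection with residual
        return s.permute(1, 2, 0).view(b, c, h, w)

x = torch.randn(1, 256, 20, 20)            # P5-sized feature map
print(TinySelfAttention(256)(x).shape)     # torch.Size([1, 256, 20, 20])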
61 utils/yolov5/models/hub/yolov5s6.yaml Normal file
@ -0,0 +1,61 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:
  - [19, 27, 44, 40, 38, 94] # P3/8
  - [96, 68, 86, 152, 180, 137] # P4/16
  - [140, 301, 303, 264, 238, 542] # P5/32
  - [436, 615, 739, 380, 925, 792] # P6/64

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [768, 3, 2]], # 7-P5/32
    [-1, 3, C3, [768]],
    [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 11
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [768, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 8], 1, Concat, [1]], # cat backbone P5
    [-1, 3, C3, [768, False]], # 15

    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 19

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 23 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 20], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 26 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 16], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [768, False]], # 29 (P5/32-large)

    [-1, 1, Conv, [768, 3, 2]],
    [[-1, 12], 1, Concat, [1]], # cat head P6
    [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge)

    [[23, 26, 29, 32], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5, P6)
  ]
61 utils/yolov5/models/hub/yolov5x6.yaml Normal file
@ -0,0 +1,61 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 1.33 # model depth multiple
width_multiple: 1.25 # layer channel multiple
anchors:
  - [19, 27, 44, 40, 38, 94] # P3/8
  - [96, 68, 86, 152, 180, 137] # P4/16
  - [140, 301, 303, 264, 238, 542] # P5/32
  - [436, 615, 739, 380, 925, 792] # P6/64

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [768, 3, 2]], # 7-P5/32
    [-1, 3, C3, [768]],
    [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 11
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [768, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 8], 1, Concat, [1]], # cat backbone P5
    [-1, 3, C3, [768, False]], # 15

    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 19

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 23 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 20], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 26 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 16], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [768, False]], # 29 (P5/32-large)

    [-1, 1, Conv, [768, 3, 2]],
    [[-1, 12], 1, Concat, [1]], # cat head P6
    [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge)

    [[23, 26, 29, 32], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5, P6)
  ]
49 utils/yolov5/models/segment/yolov5l-seg.yaml Normal file
@ -0,0 +1,49 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 20 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 23 (P5/32-large)

    [[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5)
  ]
49 utils/yolov5/models/segment/yolov5m-seg.yaml Normal file
@ -0,0 +1,49 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 0.67 # model depth multiple
width_multiple: 0.75 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 20 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 23 (P5/32-large)

    [[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5)
  ]
49 utils/yolov5/models/segment/yolov5n-seg.yaml Normal file
@ -0,0 +1,49 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.25 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 20 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 23 (P5/32-large)

    [[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5)
  ]
49 utils/yolov5/models/segment/yolov5s-seg.yaml Normal file
@ -0,0 +1,49 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.5 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 20 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 23 (P5/32-large)

    [[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5)
  ]
49 utils/yolov5/models/segment/yolov5x-seg.yaml Normal file
@ -0,0 +1,49 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 1.33 # model depth multiple
width_multiple: 1.25 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 20 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 23 (P5/32-large)

    [[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5)
  ]
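In the segment configs the final row passes [nc, anchors, 32, 256] to Segment: 32 mask coefficients per box and 256 prototype channels. Per-anchor output width therefore grows from 5 + nc to 5 + nc + nm, which matches the TFSegment math in tf.py below; a quick check of that arithmetic:

nc, nm, na = 80, 32, 3   # classes, mask coefficients, anchors per layer (values above)
no = 5 + nc + nm         # box(4) + objectness(1) + classes + mask coefficients
print(no, no * na)       # 117 351 -> channels of each Segment output conv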
797 utils/yolov5/models/tf.py Normal file
@ -0,0 +1,797 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
TensorFlow, Keras and TFLite versions of YOLOv5
Authored by https://github.com/zldrobit in PR https://github.com/ultralytics/yolov5/pull/1127.

Usage:
    $ python models/tf.py --weights yolov5s.pt

Export:
    $ python export.py --weights yolov5s.pt --include saved_model pb tflite tfjs
"""

import argparse
import sys
from copy import deepcopy
from pathlib import Path

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
# ROOT = ROOT.relative_to(Path.cwd())  # relative

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
from tensorflow import keras

from models.common import (
    C3,
    SPP,
    SPPF,
    Bottleneck,
    BottleneckCSP,
    C3x,
    Concat,
    Conv,
    CrossConv,
    DWConv,
    DWConvTranspose2d,
    Focus,
    autopad,
)
from models.experimental import MixConv2d, attempt_load
from models.yolo import Detect, Segment
from utils.activations import SiLU
from utils.general import LOGGER, make_divisible, print_args


class TFBN(keras.layers.Layer):
    """TensorFlow BatchNormalization wrapper for initializing with optional pretrained weights."""

    def __init__(self, w=None):
        """Initializes a TensorFlow BatchNormalization layer with optional pretrained weights."""
        super().__init__()
        self.bn = keras.layers.BatchNormalization(
            beta_initializer=keras.initializers.Constant(w.bias.numpy()),
            gamma_initializer=keras.initializers.Constant(w.weight.numpy()),
            moving_mean_initializer=keras.initializers.Constant(w.running_mean.numpy()),
            moving_variance_initializer=keras.initializers.Constant(w.running_var.numpy()),
            epsilon=w.eps,
        )

    def call(self, inputs):
        """Applies batch normalization to the inputs."""
        return self.bn(inputs)


class TFPad(keras.layers.Layer):
    """Pads input tensors in spatial dimensions 1 and 2 with specified integer or tuple padding values."""

    def __init__(self, pad):
        """
        Initializes a padding layer for spatial dimensions 1 and 2 with specified padding, supporting both int and tuple
        inputs.

        Inputs are pad.
        """
        super().__init__()
        if isinstance(pad, int):
            self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
        else:  # tuple/list
            self.pad = tf.constant([[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]])

    def call(self, inputs):
        """Pads input tensor with zeros using specified padding, suitable for int and tuple pad dimensions."""
        return tf.pad(inputs, self.pad, mode="constant", constant_values=0)


class TFConv(keras.layers.Layer):
    """Implements a standard convolutional layer with optional batch normalization and activation for TensorFlow."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
        """
        Initializes a standard convolution layer with optional batch normalization and activation; supports only
        group=1.

        Inputs are ch_in, ch_out, weights, kernel, stride, padding, groups.
        """
        super().__init__()
        assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
        # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
        # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
        conv = keras.layers.Conv2D(
            filters=c2,
            kernel_size=k,
            strides=s,
            padding="SAME" if s == 1 else "VALID",
            use_bias=not hasattr(w, "bn"),
            kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
            bias_initializer="zeros" if hasattr(w, "bn") else keras.initializers.Constant(w.conv.bias.numpy()),
        )
        self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
        self.bn = TFBN(w.bn) if hasattr(w, "bn") else tf.identity
        self.act = activations(w.act) if act else tf.identity

    def call(self, inputs):
        """Applies convolution, batch normalization, and activation function to input tensors."""
        return self.act(self.bn(self.conv(inputs)))


class TFDWConv(keras.layers.Layer):
    """Initializes a depthwise convolution layer with optional batch normalization and activation for TensorFlow."""

    def __init__(self, c1, c2, k=1, s=1, p=None, act=True, w=None):
        """
        Initializes a depthwise convolution layer with optional batch normalization and activation for TensorFlow
        models.

        Inputs are ch_in, ch_out, weights, kernel, stride, padding, groups.
        """
        super().__init__()
        assert c2 % c1 == 0, f"TFDWConv() output={c2} must be a multiple of input={c1} channels"
        conv = keras.layers.DepthwiseConv2D(
            kernel_size=k,
            depth_multiplier=c2 // c1,
            strides=s,
            padding="SAME" if s == 1 else "VALID",
            use_bias=not hasattr(w, "bn"),
            depthwise_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
            bias_initializer="zeros" if hasattr(w, "bn") else keras.initializers.Constant(w.conv.bias.numpy()),
        )
        self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
        self.bn = TFBN(w.bn) if hasattr(w, "bn") else tf.identity
        self.act = activations(w.act) if act else tf.identity

    def call(self, inputs):
        """Applies convolution, batch normalization, and activation function to input tensors."""
        return self.act(self.bn(self.conv(inputs)))


class TFDWConvTranspose2d(keras.layers.Layer):
    """Implements a depthwise ConvTranspose2D layer for TensorFlow with specific settings."""

    def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None):
        """
        Initializes depthwise ConvTranspose2D layer with specific channel, kernel, stride, and padding settings.

        Inputs are ch_in, ch_out, weights, kernel, stride, padding, groups.
        """
        super().__init__()
        assert c1 == c2, f"TFDWConv() output={c2} must be equal to input={c1} channels"
        assert k == 4 and p1 == 1, "TFDWConv() only valid for k=4 and p1=1"
        weight, bias = w.weight.permute(2, 3, 1, 0).numpy(), w.bias.numpy()
        self.c1 = c1
        self.conv = [
            keras.layers.Conv2DTranspose(
                filters=1,
                kernel_size=k,
                strides=s,
                padding="VALID",
                output_padding=p2,
                use_bias=True,
                kernel_initializer=keras.initializers.Constant(weight[..., i : i + 1]),
                bias_initializer=keras.initializers.Constant(bias[i]),
            )
            for i in range(c1)
        ]

    def call(self, inputs):
        """Processes input through parallel convolutions and concatenates results, trimming border pixels."""
        return tf.concat([m(x) for m, x in zip(self.conv, tf.split(inputs, self.c1, 3))], 3)[:, 1:-1, 1:-1]


class TFFocus(keras.layers.Layer):
    """Focuses spatial information into channel space using pixel shuffling and convolution for TensorFlow models."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
        """
        Initializes TFFocus layer to focus width and height information into channel space with custom convolution
        parameters.

        Inputs are ch_in, ch_out, kernel, stride, padding, groups.
        """
        super().__init__()
        self.conv = TFConv(c1 * 4, c2, k, s, p, g, act, w.conv)

    def call(self, inputs):
        """
        Performs pixel shuffling and convolution on input tensor, downsampling by 2 and expanding channels by 4.

        Example x(b,w,h,c) -> y(b,w/2,h/2,4c).
        """
        inputs = [inputs[:, ::2, ::2, :], inputs[:, 1::2, ::2, :], inputs[:, ::2, 1::2, :], inputs[:, 1::2, 1::2, :]]
        return self.conv(tf.concat(inputs, 3))


class TFBottleneck(keras.layers.Layer):
    """Implements a TensorFlow bottleneck layer with optional shortcut connections for efficient feature extraction."""

    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None):
        """
        Initializes a standard bottleneck layer for TensorFlow models, expanding and contracting channels with optional
        shortcut.

        Arguments are ch_in, ch_out, shortcut, groups, expansion.
        """
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
        self.cv2 = TFConv(c_, c2, 3, 1, g=g, w=w.cv2)
        self.add = shortcut and c1 == c2

    def call(self, inputs):
        """Performs forward pass; if shortcut is True & input/output channels match, adds input to the convolution
        result.
        """
        return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))


class TFCrossConv(keras.layers.Layer):
    """Implements a cross convolutional layer with optional expansion, grouping, and shortcut for TensorFlow."""

    def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False, w=None):
        """Initializes cross convolution layer with optional expansion, grouping, and shortcut addition capabilities."""
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = TFConv(c1, c_, (1, k), (1, s), w=w.cv1)
        self.cv2 = TFConv(c_, c2, (k, 1), (s, 1), g=g, w=w.cv2)
        self.add = shortcut and c1 == c2

    def call(self, inputs):
        """Passes input through two convolutions optionally adding the input if channel dimensions match."""
        return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))


class TFConv2d(keras.layers.Layer):
    """Implements a TensorFlow 2D convolution layer, mimicking PyTorch's nn.Conv2D for specified filters and stride."""

    def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
        """Initializes a TensorFlow 2D convolution layer, mimicking PyTorch's nn.Conv2D functionality for given filter
        sizes and stride.
        """
        super().__init__()
        assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
        self.conv = keras.layers.Conv2D(
            filters=c2,
            kernel_size=k,
            strides=s,
            padding="VALID",
            use_bias=bias,
            kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()),
            bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None,
        )

    def call(self, inputs):
        """Applies a convolution operation to the inputs and returns the result."""
        return self.conv(inputs)


class TFBottleneckCSP(keras.layers.Layer):
    """Implements a CSP bottleneck layer for TensorFlow models to enhance gradient flow and efficiency."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
        """
        Initializes CSP bottleneck layer with specified channel sizes, count, shortcut option, groups, and expansion
        ratio.

        Inputs are ch_in, ch_out, number, shortcut, groups, expansion.
        """
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
        self.cv2 = TFConv2d(c1, c_, 1, 1, bias=False, w=w.cv2)
        self.cv3 = TFConv2d(c_, c_, 1, 1, bias=False, w=w.cv3)
        self.cv4 = TFConv(2 * c_, c2, 1, 1, w=w.cv4)
        self.bn = TFBN(w.bn)
        self.act = lambda x: keras.activations.swish(x)
        self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])

    def call(self, inputs):
        """Processes input through the model layers, concatenates, normalizes, activates, and reduces the output
        dimensions.
        """
        y1 = self.cv3(self.m(self.cv1(inputs)))
        y2 = self.cv2(inputs)
        return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))


class TFC3(keras.layers.Layer):
    """CSP bottleneck layer with 3 convolutions for TensorFlow, supporting optional shortcuts and group convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
        """
        Initializes CSP Bottleneck with 3 convolutions, supporting optional shortcuts and group convolutions.

        Inputs are ch_in, ch_out, number, shortcut, groups, expansion.
        """
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
        self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
        self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
        self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])

    def call(self, inputs):
        """
        Processes input through a sequence of transformations for object detection (YOLOv5).

        See https://github.com/ultralytics/yolov5.
        """
        return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))


class TFC3x(keras.layers.Layer):
    """A TensorFlow layer for enhanced feature extraction using cross-convolutions in object detection models."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
        """
        Initializes layer with cross-convolutions for enhanced feature extraction in object detection models.

        Inputs are ch_in, ch_out, number, shortcut, groups, expansion.
        """
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
        self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
        self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
        self.m = keras.Sequential(
            [TFCrossConv(c_, c_, k=3, s=1, g=g, e=1.0, shortcut=shortcut, w=w.m[j]) for j in range(n)]
        )

    def call(self, inputs):
        """Processes input through cascaded convolutions and merges features, returning the final tensor output."""
        return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))


class TFSPP(keras.layers.Layer):
    """Implements spatial pyramid pooling for YOLOv3-SPP with specific channels and kernel sizes."""

    def __init__(self, c1, c2, k=(5, 9, 13), w=None):
        """Initializes a YOLOv3-SPP layer with specific input/output channels and kernel sizes for pooling."""
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
        self.cv2 = TFConv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2)
        self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding="SAME") for x in k]

    def call(self, inputs):
        """Processes input through two TFConv layers and concatenates with max-pooled outputs at intermediate stage."""
        x = self.cv1(inputs)
        return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))


class TFSPPF(keras.layers.Layer):
    """Implements a fast spatial pyramid pooling layer for TensorFlow with optimized feature extraction."""

    def __init__(self, c1, c2, k=5, w=None):
        """Initializes a fast spatial pyramid pooling layer with customizable in/out channels, kernel size, and
        weights.
        """
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
        self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2)
        self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding="SAME")

    def call(self, inputs):
        """Executes the model's forward pass, concatenating input features with three max-pooled versions before final
        convolution.
        """
        x = self.cv1(inputs)
        y1 = self.m(x)
        y2 = self.m(y1)
        return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))


class TFDetect(keras.layers.Layer):
    """Implements YOLOv5 object detection layer in TensorFlow for predicting bounding boxes and class probabilities."""

    def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None):
        """Initializes YOLOv5 detection layer for TensorFlow with configurable classes, anchors, channels, and image
        size.
        """
        super().__init__()
        self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [tf.zeros(1)] * self.nl  # init grid
        self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
        self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]), [self.nl, 1, -1, 1, 2])
        self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
        self.training = False  # set to False after building model
        self.imgsz = imgsz
        for i in range(self.nl):
            ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
            self.grid[i] = self._make_grid(nx, ny)

    def call(self, inputs):
        """Performs forward pass through the model layers to predict object bounding boxes and classifications."""
        z = []  # inference output
        x = []
        for i in range(self.nl):
            x.append(self.m[i](inputs[i]))
            # x(bs,20,20,255) to x(bs,3,20,20,85)
            ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
            x[i] = tf.reshape(x[i], [-1, ny * nx, self.na, self.no])

            if not self.training:  # inference
                y = x[i]
                grid = tf.transpose(self.grid[i], [0, 2, 1, 3]) - 0.5
                anchor_grid = tf.transpose(self.anchor_grid[i], [0, 2, 1, 3]) * 4
                xy = (tf.sigmoid(y[..., 0:2]) * 2 + grid) * self.stride[i]  # xy
                wh = tf.sigmoid(y[..., 2:4]) ** 2 * anchor_grid
                # Normalize xywh to 0-1 to reduce calibration error
                xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
                wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
                y = tf.concat([xy, wh, tf.sigmoid(y[..., 4 : 5 + self.nc]), y[..., 5 + self.nc :]], -1)
                z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))

        return tf.transpose(x, [0, 2, 1, 3]) if self.training else (tf.concat(z, 1),)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        """Generates a 2D grid of coordinates in (x, y) format with shape [1, 1, ny*nx, 2]."""
        # return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
        xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
        return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)


class TFSegment(TFDetect):
    """YOLOv5 segmentation head for TensorFlow, combining detection and segmentation."""

    def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), imgsz=(640, 640), w=None):
        """Initializes YOLOv5 Segment head with specified channel depths, anchors, and input size for segmentation
        models.
        """
        super().__init__(nc, anchors, ch, imgsz, w)
        self.nm = nm  # number of masks
        self.npr = npr  # number of protos
        self.no = 5 + nc + self.nm  # number of outputs per anchor
        self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]  # output conv
        self.proto = TFProto(ch[0], self.npr, self.nm, w=w.proto)  # protos
        self.detect = TFDetect.call

    def call(self, x):
        """Applies detection and proto layers on input, returning detections and optionally protos if training."""
        p = self.proto(x[0])
        # p = TFUpsample(None, scale_factor=4, mode='nearest')(self.proto(x[0]))  # (optional) full-size protos
        p = tf.transpose(p, [0, 3, 1, 2])  # from shape(1,160,160,32) to shape(1,32,160,160)
        x = self.detect(self, x)
        return (x, p) if self.training else (x[0], p)


class TFProto(keras.layers.Layer):
    """Implements convolutional and upsampling layers for feature extraction in YOLOv5 segmentation."""

    def __init__(self, c1, c_=256, c2=32, w=None):
        """Initializes TFProto layer with convolutional and upsampling layers for feature extraction and
        transformation.
        """
        super().__init__()
        self.cv1 = TFConv(c1, c_, k=3, w=w.cv1)
        self.upsample = TFUpsample(None, scale_factor=2, mode="nearest")
        self.cv2 = TFConv(c_, c_, k=3, w=w.cv2)
        self.cv3 = TFConv(c_, c2, w=w.cv3)

    def call(self, inputs):
        """Performs forward pass through the model, applying convolutions and upscaling on input tensor."""
        return self.cv3(self.cv2(self.upsample(self.cv1(inputs))))


class TFUpsample(keras.layers.Layer):
    """Implements a TensorFlow upsampling layer with specified size, scale factor, and interpolation mode."""

    def __init__(self, size, scale_factor, mode, w=None):
        """
        Initializes a TensorFlow upsampling layer with specified size, scale_factor, and mode, ensuring scale_factor is
        even.

        Warning: all arguments needed including 'w'
        """
        super().__init__()
        assert scale_factor % 2 == 0, "scale_factor must be multiple of 2"
        self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * scale_factor, x.shape[2] * scale_factor), mode)
        # self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
        # with default arguments: align_corners=False, half_pixel_centers=False
        # self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,
        #                                                            size=(x.shape[1] * 2, x.shape[2] * 2))

    def call(self, inputs):
        """Applies upsample operation to inputs using nearest neighbor interpolation."""
        return self.upsample(inputs)


class TFConcat(keras.layers.Layer):
    """Implements TensorFlow's version of torch.concat() for concatenating tensors along the last dimension."""

    def __init__(self, dimension=1, w=None):
        """Initializes a TensorFlow layer for NCHW to NHWC concatenation, requiring dimension=1."""
        super().__init__()
        assert dimension == 1, "convert only NCHW to NHWC concat"
        self.d = 3

    def call(self, inputs):
        """Concatenates a list of tensors along the last dimension, used for NCHW to NHWC conversion."""
        return tf.concat(inputs, self.d)


def parse_model(d, ch, model, imgsz):
    """Parses a model definition dict `d` to create YOLOv5 model layers, including dynamic channel adjustments."""
    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")
    anchors, nc, gd, gw, ch_mul = (
        d["anchors"],
        d["nc"],
        d["depth_multiple"],
        d["width_multiple"],
        d.get("channel_multiple"),
    )
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)
    if not ch_mul:
        ch_mul = 8

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
        m_str = m
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except NameError:
                pass

        n = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in [
            nn.Conv2d,
            Conv,
            DWConv,
            DWConvTranspose2d,
            Bottleneck,
            SPP,
            SPPF,
            MixConv2d,
            Focus,
            CrossConv,
            BottleneckCSP,
            C3,
            C3x,
        ]:
            c1, c2 = ch[f], args[0]
            c2 = make_divisible(c2 * gw, ch_mul) if c2 != no else c2

            args = [c1, c2, *args[1:]]
            if m in [BottleneckCSP, C3, C3x]:
                args.insert(2, n)
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
        elif m in [Detect, Segment]:
            args.append([ch[x + 1] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
            if m is Segment:
                args[3] = make_divisible(args[3] * gw, ch_mul)
            args.append(imgsz)
        else:
            c2 = ch[f]

        tf_m = eval("TF" + m_str.replace("nn.", ""))
        m_ = (
            keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)])
            if n > 1
            else tf_m(*args, w=model.model[i])
        )  # module

        torch_m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace("__main__.", "")  # module type
        np = sum(x.numel() for x in torch_m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info(f"{i:>3}{str(f):>18}{str(n):>3}{np:>10}  {t:<40}{str(args):<30}")  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        ch.append(c2)
    return keras.Sequential(layers), sorted(save)


class TFModel:
    """Implements YOLOv5 model in TensorFlow, supporting TensorFlow, Keras, and TFLite formats for object detection."""

    def __init__(self, cfg="yolov5s.yaml", ch=3, nc=None, model=None, imgsz=(640, 640)):
        """Initializes TF YOLOv5 model with specified configuration, channels, classes, model instance, and input
        size.
        """
        super().__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg  # model dict
        else:  # is *.yaml
            import yaml  # for torch hub

            self.yaml_file = Path(cfg).name
            with open(cfg) as f:
                self.yaml = yaml.load(f, Loader=yaml.FullLoader)  # model dict

        # Define model
        if nc and nc != self.yaml["nc"]:
            LOGGER.info(f"Overriding {cfg} nc={self.yaml['nc']} with nc={nc}")
            self.yaml["nc"] = nc  # override yaml value
        self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model, imgsz=imgsz)

    def predict(
        self,
        inputs,
        tf_nms=False,
        agnostic_nms=False,
        topk_per_class=100,
        topk_all=100,
        iou_thres=0.45,
        conf_thres=0.25,
    ):
        """Runs inference on input data, with an option for TensorFlow NMS."""
        y = []  # outputs
        x = inputs
        for m in self.model.layers:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers

            x = m(x)  # run
            y.append(x if m.i in self.savelist else None)  # save output

        # Add TensorFlow NMS
        if tf_nms:
            boxes = self._xywh2xyxy(x[0][..., :4])
            probs = x[0][:, :, 4:5]
            classes = x[0][:, :, 5:]
            scores = probs * classes
            if agnostic_nms:
                nms = AgnosticNMS()((boxes, classes, scores), topk_all, iou_thres, conf_thres)
            else:
                boxes = tf.expand_dims(boxes, 2)
                nms = tf.image.combined_non_max_suppression(
                    boxes, scores, topk_per_class, topk_all, iou_thres, conf_thres, clip_boxes=False
                )
            return (nms,)
        return x  # output [1,6300,85] = [xywh, conf, class0, class1, ...]
        # x = x[0]  # [x(1,6300,85), ...] to x(6300,85)
        # xywh = x[..., :4]  # x(6300,4) boxes
        # conf = x[..., 4:5]  # x(6300,1) confidences
        # cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1))  # x(6300,1) classes
        # return tf.concat([conf, cls, xywh], 1)

    @staticmethod
    def _xywh2xyxy(xywh):
        """Converts bounding box format from [x, y, w, h] to [x1, y1, x2, y2], where xy1=top-left and xy2=bottom-
        right.
        """
        x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
        return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)


class AgnosticNMS(keras.layers.Layer):
    """Performs agnostic non-maximum suppression (NMS) on detected objects using IoU and confidence thresholds."""

    def call(self, input, topk_all, iou_thres, conf_thres):
        """Performs agnostic NMS on input tensors using given thresholds and top-K selection."""
        return tf.map_fn(
            lambda x: self._nms(x, topk_all, iou_thres, conf_thres),
            input,
            fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
            name="agnostic_nms",
        )

    @staticmethod
    def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25):
        """Performs agnostic non-maximum suppression (NMS) on detected objects, filtering based on IoU and confidence
        thresholds.
        """
        boxes, classes, scores = x
        class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
        scores_inp = tf.reduce_max(scores, -1)
        selected_inds = tf.image.non_max_suppression(
            boxes, scores_inp, max_output_size=topk_all, iou_threshold=iou_thres, score_threshold=conf_thres
        )
        selected_boxes = tf.gather(boxes, selected_inds)
        padded_boxes = tf.pad(
            selected_boxes,
            paddings=[[0, topk_all - tf.shape(selected_boxes)[0]], [0, 0]],
            mode="CONSTANT",
            constant_values=0.0,
        )
        selected_scores = tf.gather(scores_inp, selected_inds)
        padded_scores = tf.pad(
            selected_scores,
            paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
            mode="CONSTANT",
            constant_values=-1.0,
        )
        selected_classes = tf.gather(class_inds, selected_inds)
        padded_classes = tf.pad(
            selected_classes,
            paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
            mode="CONSTANT",
            constant_values=-1.0,
        )
        valid_detections = tf.shape(selected_inds)[0]
        return padded_boxes, padded_scores, padded_classes, valid_detections


def activations(act=nn.SiLU):
    """Converts PyTorch activations to TensorFlow equivalents, supporting LeakyReLU, Hardswish, and SiLU/Swish."""
    if isinstance(act, nn.LeakyReLU):
        return lambda x: keras.activations.relu(x, alpha=0.1)
    elif isinstance(act, nn.Hardswish):
        return lambda x: x * tf.nn.relu6(x + 3) * 0.166666667
    elif isinstance(act, (nn.SiLU, SiLU)):
        return lambda x: keras.activations.swish(x)
    else:
        raise Exception(f"no matching TensorFlow activation found for PyTorch activation {act}")


def representative_dataset_gen(dataset, ncalib=100):
    """Generates a representative dataset for calibration by yielding transformed numpy arrays from the input
    dataset.
    """
    for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
        im = np.transpose(img, [1, 2, 0])
        im = np.expand_dims(im, axis=0).astype(np.float32)
        im /= 255
        yield [im]
        if n >= ncalib:
            break


def run(
    weights=ROOT / "yolov5s.pt",  # weights path
    imgsz=(640, 640),  # inference size h,w
    batch_size=1,  # batch size
    dynamic=False,  # dynamic batch size
):
    # PyTorch model
    """Exports YOLOv5 model from PyTorch to TensorFlow and Keras formats, performing inference for validation."""
    im = torch.zeros((batch_size, 3, *imgsz))  # BCHW image
    model = attempt_load(weights, device=torch.device("cpu"), inplace=True, fuse=False)
    _ = model(im)  # inference
    model.info()

    # TensorFlow model
    im = tf.zeros((batch_size, *imgsz, 3))  # BHWC image
    tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
    _ = tf_model.predict(im)  # inference

    # Keras model
    im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
    keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
    keras_model.summary()

    LOGGER.info("PyTorch, TensorFlow and Keras models successfully verified.\nUse export.py for TF model export.")


def parse_opt():
    """Parses and returns command-line options for model inference, including weights path, image size, batch size, and
    dynamic batching.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", type=str, default=ROOT / "yolov5s.pt", help="weights path")
    parser.add_argument("--imgsz", "--img", "--img-size", nargs="+", type=int, default=[640], help="inference size h,w")
    parser.add_argument("--batch-size", type=int, default=1, help="batch size")
    parser.add_argument("--dynamic", action="store_true", help="dynamic batch size")
    opt = parser.parse_args()
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
    print_args(vars(opt))
    return opt


def main(opt):
    """Executes the YOLOv5 model run function with parsed command line options."""
    run(**vars(opt))


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)
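For reference, the TFDetect decode above follows the standard YOLOv5 box parameterization: center = (2·sigmoid(t_xy) + cell − 0.5)·stride, size = (2·sigmoid(t_wh))²·anchor. A NumPy restating of those lines, with example values for the stride, anchor, cell, and logits:

import numpy as np

def sigmoid(t):
    return 1 / (1 + np.exp(-t))

stride, anchor = 8.0, np.array([10.0, 13.0])  # P3 stride; one anchor (w, h) in pixels
t = np.array([0.2, -0.1, 0.3, 0.4])           # raw xywh logits for one cell
cell = np.array([4.0, 7.0])                   # grid cell index (x, y)

xy = (sigmoid(t[:2]) * 2 + (cell - 0.5)) * stride  # box center in pixels
wh = (sigmoid(t[2:]) * 2) ** 2 * anchor            # box size in pixels
print(xy, wh)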
495 utils/yolov5/models/yolo.py Normal file
@ -0,0 +1,495 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
YOLO-specific modules.

Usage:
    $ python models/yolo.py --cfg yolov5s.yaml
"""

import argparse
import contextlib
import math
import os
import platform
import sys
from copy import deepcopy
from pathlib import Path

import torch
import torch.nn as nn

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
if platform.system() != "Windows":
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.common import (
    C3,
    C3SPP,
    C3TR,
    SPP,
    SPPF,
    Bottleneck,
    BottleneckCSP,
    C3Ghost,
    C3x,
    Classify,
    Concat,
    Contract,
    Conv,
    CrossConv,
    DetectMultiBackend,
    DWConv,
    DWConvTranspose2d,
    Expand,
    Focus,
    GhostBottleneck,
    GhostConv,
    Proto,
)
from models.experimental import MixConv2d
from utils.autoanchor import check_anchor_order
from utils.general import LOGGER, check_version, check_yaml, colorstr, make_divisible, print_args
from utils.plots import feature_visualization
from utils.torch_utils import (
    fuse_conv_and_bn,
    initialize_weights,
    model_info,
    profile,
    scale_img,
    select_device,
    time_sync,
)

try:
    import thop  # for FLOPs computation
except ImportError:
    thop = None


class Detect(nn.Module):
    """YOLOv5 Detect head for processing input tensors and generating detection outputs in object detection models."""

    stride = None  # strides computed during build
    dynamic = False  # force grid reconstruction
    export = False  # export mode

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):
        """Initializes YOLOv5 detection layer with specified classes, anchors, channels, and inplace operations."""
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.empty(0) for _ in range(self.nl)]  # init grid
        self.anchor_grid = [torch.empty(0) for _ in range(self.nl)]  # init anchor grid
        self.register_buffer("anchors", torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.inplace = inplace  # use inplace ops (e.g. slice assignment)

    def forward(self, x):
        """Processes input through YOLOv5 layers, altering shape for detection: `x(bs, 3, ny, nx, 85)`."""
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                if isinstance(self, Segment):  # (boxes + masks)
                    xy, wh, conf, mask = x[i].split((2, 2, self.nc + 1, self.no - self.nc - 5), 4)
                    xy = (xy.sigmoid() * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh.sigmoid() * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf.sigmoid(), mask), 4)
                else:  # Detect (boxes only)
                    xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)
                    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, self.na * nx * ny, self.no))

        return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
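
    # Box decode used above (for reference; the grid already carries a -0.5 offset from _make_grid):
    #   xy = (2 * sigmoid(t_xy) - 0.5 + cell_xy) * stride
    #   wh = (2 * sigmoid(t_wh)) ** 2 * anchor_wh
    # which bounds xy to about [-0.5, 1.5] of its cell and wh to [0, 4] anchor multiples.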

    def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, "1.10.0")):
        """Generates a mesh grid for anchor boxes with optional compatibility for torch versions < 1.10."""
        d = self.anchors[i].device
        t = self.anchors[i].dtype
        shape = 1, self.na, ny, nx, 2  # grid shape
        y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
        yv, xv = torch.meshgrid(y, x, indexing="ij") if torch_1_10 else torch.meshgrid(y, x)  # torch>=0.7 compatibility
        grid = torch.stack((xv, yv), 2).expand(shape) - 0.5  # add grid offset, i.e. y = 2.0 * x - 0.5
        anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
        return grid, anchor_grid


class Segment(Detect):
    """YOLOv5 Segment head for segmentation models, extending Detect with mask and prototype layers."""

    def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), inplace=True):
        """Initializes YOLOv5 Segment head with options for mask count, protos, and channel adjustments."""
        super().__init__(nc, anchors, ch, inplace)
        self.nm = nm  # number of masks
        self.npr = npr  # number of protos
        self.no = 5 + nc + self.nm  # number of outputs per anchor
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.proto = Proto(ch[0], self.npr, self.nm)  # protos
        self.detect = Detect.forward

    def forward(self, x):
        """Processes input through the network, returning detections and prototypes; adjusts output based on
        training/export mode.
        """
        p = self.proto(x[0])
        x = self.detect(self, x)
        return (x, p) if self.training else (x[0], p) if self.export else (x[0], p, x[1])


class BaseModel(nn.Module):
    """YOLOv5 base model."""

    def forward(self, x, profile=False, visualize=False):
        """Executes a single-scale inference or training pass on the YOLOv5 base model, with options for profiling
        and visualization.
        """
        return self._forward_once(x, profile, visualize)  # single-scale inference, train

    def _forward_once(self, x, profile=False, visualize=False):
        """Performs a forward pass on the YOLOv5 model, enabling profiling and feature visualization options."""
        y, dt = [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
        return x

    def _profile_one_layer(self, m, x, dt):
        """Profiles a single layer's performance by computing GFLOPs, execution time, and parameters."""
        c = m == self.model[-1]  # is final layer, copy input as inplace fix
        o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1e9 * 2 if thop else 0  # FLOPs
        t = time_sync()
        for _ in range(10):
            m(x.copy() if c else x)
        dt.append((time_sync() - t) * 100)
        if m == self.model[0]:
            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
        LOGGER.info(f"{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}")
        if c:
            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")

    def fuse(self):
        """Fuses Conv2d() and BatchNorm2d() layers in the model to improve inference speed."""
        LOGGER.info("Fusing layers... ")
        for m in self.model.modules():
            if isinstance(m, (Conv, DWConv)) and hasattr(m, "bn"):
                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                delattr(m, "bn")  # remove batchnorm
                m.forward = m.forward_fuse  # update forward
        self.info()
        return self
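
    # Conv+BN fusion identity (standard result, not specific to this repo):
    #   w' = w * gamma / sqrt(var + eps), b' = beta + (b - mean) * gamma / sqrt(var + eps),
    # so the fused convolution reproduces the conv+BN output exactly at inference time.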

    def info(self, verbose=False, img_size=640):
        """Prints model information given verbosity and image size, e.g., `info(verbose=True, img_size=640)`."""
        model_info(self, verbose, img_size)

    def _apply(self, fn):
        """Applies transformations like to(), cpu(), cuda(), half() to model tensors excluding parameters or
        registered buffers.
        """
        self = super()._apply(fn)
        m = self.model[-1]  # Detect()
        if isinstance(m, (Detect, Segment)):
            m.stride = fn(m.stride)
            m.grid = list(map(fn, m.grid))
            if isinstance(m.anchor_grid, list):
                m.anchor_grid = list(map(fn, m.anchor_grid))
        return self


class DetectionModel(BaseModel):
    """YOLOv5 detection model class for object detection tasks, supporting custom configurations and anchors."""

    def __init__(self, cfg="yolov5s.yaml", ch=3, nc=None, anchors=None):
        """Initializes YOLOv5 model with configuration file, input channels, number of classes, and custom anchors."""
        super().__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg  # model dict
        else:  # is *.yaml
            import yaml  # for torch hub

            self.yaml_file = Path(cfg).name
            with open(cfg, encoding="ascii", errors="ignore") as f:
                self.yaml = yaml.safe_load(f)  # model dict

        # Define model
        ch = self.yaml["ch"] = self.yaml.get("ch", ch)  # input channels
        if nc and nc != self.yaml["nc"]:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml["nc"] = nc  # override yaml value
        if anchors:
            LOGGER.info(f"Overriding model.yaml anchors with anchors={anchors}")
            self.yaml["anchors"] = round(anchors)  # override yaml value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
        self.names = [str(i) for i in range(self.yaml["nc"])]  # default names
        self.inplace = self.yaml.get("inplace", True)

        # Build strides, anchors
        m = self.model[-1]  # Detect()
        if isinstance(m, (Detect, Segment)):

            def _forward(x):
                """Passes the input 'x' through the model and returns the processed output."""
                return self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)

            s = 256  # 2x min stride
            m.inplace = self.inplace
            m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))])  # forward
            check_anchor_order(m)
            m.anchors /= m.stride.view(-1, 1, 1)
            self.stride = m.stride
            self._initialize_biases()  # only run once
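            # Stride derivation above: the 256x256 dummy forward yields feature maps of 32/16/8 cells
            # for the standard P3/P4/P5 heads, so stride = 256 / map_size = [8, 16, 32].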

        # Init weights, biases
        initialize_weights(self)
        self.info()
        LOGGER.info("")

    def forward(self, x, augment=False, profile=False, visualize=False):
        """Performs single-scale or augmented inference and may include profiling or visualization."""
        if augment:
            return self._forward_augment(x)  # augmented inference, None
        return self._forward_once(x, profile, visualize)  # single-scale inference, train

    def _forward_augment(self, x):
        """Performs augmented inference across different scales and flips, returning combined detections."""
        img_size = x.shape[-2:]  # height, width
        s = [1, 0.83, 0.67]  # scales
        f = [None, 3, None]  # flips (2-ud, 3-lr)
        y = []  # outputs
        for si, fi in zip(s, f):
            xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
            yi = self._forward_once(xi)[0]  # forward
            # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
            yi = self._descale_pred(yi, fi, si, img_size)
            y.append(yi)
        y = self._clip_augmented(y)  # clip augmented tails
        return torch.cat(y, 1), None  # augmented inference, train

    def _descale_pred(self, p, flips, scale, img_size):
        """De-scales predictions from augmented inference, adjusting for flips and image size."""
        if self.inplace:
            p[..., :4] /= scale  # de-scale
            if flips == 2:
                p[..., 1] = img_size[0] - p[..., 1]  # de-flip ud
            elif flips == 3:
                p[..., 0] = img_size[1] - p[..., 0]  # de-flip lr
        else:
            x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale  # de-scale
            if flips == 2:
                y = img_size[0] - y  # de-flip ud
            elif flips == 3:
                x = img_size[1] - x  # de-flip lr
            p = torch.cat((x, y, wh, p[..., 4:]), -1)
        return p

    def _clip_augmented(self, y):
        """Clips augmented inference tails for YOLOv5 models, affecting first and last tensors based on grid points
        and layer counts.
        """
        nl = self.model[-1].nl  # number of detection layers (P3-P5)
        g = sum(4**x for x in range(nl))  # grid points
        e = 1  # exclude layer count
        i = (y[0].shape[1] // g) * sum(4**x for x in range(e))  # indices
        y[0] = y[0][:, :-i]  # large
        i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # indices
        y[-1] = y[-1][:, i:]  # small
        return y

    def _initialize_biases(self, cf=None):
        """
        Initializes biases for YOLOv5's Detect() module, optionally using class frequencies (cf).

        For details see https://arxiv.org/abs/1708.02002 section 3.3.
        """
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Detect() module
        for mi, s in zip(m.m, m.stride):  # from
            b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
            b.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            b.data[:, 5 : 5 + m.nc] += (
                math.log(0.6 / (m.nc - 0.99999)) if cf is None else torch.log(cf / cf.sum())
            )  # cls
            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)


Model = DetectionModel  # retain YOLOv5 'Model' class for backwards compatibility


class SegmentationModel(DetectionModel):
    """YOLOv5 segmentation model for object detection and segmentation tasks with configurable parameters."""

    def __init__(self, cfg="yolov5s-seg.yaml", ch=3, nc=None, anchors=None):
        """Initializes a YOLOv5 segmentation model with configurable params: cfg (str) for configuration, ch (int)
        for channels, nc (int) for num classes, anchors (list)."""
        super().__init__(cfg, ch, nc, anchors)


class ClassificationModel(BaseModel):
    """YOLOv5 classification model for image classification tasks, initialized with a config file or detection model."""

    def __init__(self, cfg=None, model=None, nc=1000, cutoff=10):
        """Initializes a YOLOv5 classification model from a config file `cfg` or an existing detection `model`,
        with `nc` classes and a backbone `cutoff` index.
        """
        super().__init__()
        self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)

    def _from_detection_model(self, model, nc=1000, cutoff=10):
        """Creates a classification model from a YOLOv5 detection model, slicing at `cutoff` and adding a
        classification layer.
        """
        if isinstance(model, DetectMultiBackend):
            model = model.model  # unwrap DetectMultiBackend
        model.model = model.model[:cutoff]  # backbone
        m = model.model[-1]  # last layer
        ch = m.conv.in_channels if hasattr(m, "conv") else m.cv1.conv.in_channels  # ch into module
        c = Classify(ch, nc)  # Classify()
        c.i, c.f, c.type = m.i, m.f, "models.common.Classify"  # index, from, type
        model.model[-1] = c  # replace
        self.model = model.model
        self.stride = model.stride
        self.save = []
        self.nc = nc

    def _from_yaml(self, cfg):
        """Creates a YOLOv5 classification model from a specified *.yaml configuration file."""
        self.model = None


def parse_model(d, ch):
    """Parses a YOLOv5 model from a dict `d`, configuring layers based on input channels `ch` and model architecture."""
    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
    anchors, nc, gd, gw, act, ch_mul = (
        d["anchors"],
        d["nc"],
        d["depth_multiple"],
        d["width_multiple"],
        d.get("activation"),
        d.get("channel_multiple"),
    )
    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        LOGGER.info(f"{colorstr('activation:')} {act}")  # print
    if not ch_mul:
        ch_mul = 8
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            with contextlib.suppress(NameError):
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in {
            Conv,
            GhostConv,
            Bottleneck,
            GhostBottleneck,
            SPP,
            SPPF,
            DWConv,
            MixConv2d,
            Focus,
            CrossConv,
            BottleneckCSP,
            C3,
            C3TR,
            C3SPP,
            C3Ghost,
            nn.ConvTranspose2d,
            DWConvTranspose2d,
            C3x,
        }:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, ch_mul)

            args = [c1, c2, *args[1:]]
            if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3x}:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        # TODO: channel, gw, gd
        elif m in {Detect, Segment}:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
            if m is Segment:
                args[3] = make_divisible(args[3] * gw, ch_mul)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace("__main__.", "")  # module type
        np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info(f"{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}")  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
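

# Worked example (illustrative): with yolov5s multiples gd=0.33 and gw=0.50, a backbone row
# [-1, 6, C3, [256]] is parsed as n = max(round(6 * 0.33), 1) = 2 repeats with
# c2 = make_divisible(256 * 0.50, 8) = 128 output channels.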


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg", type=str, default="yolov5s.yaml", help="model.yaml")
    parser.add_argument("--batch-size", type=int, default=1, help="total batch size for all GPUs")
    parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
    parser.add_argument("--profile", action="store_true", help="profile model speed")
    parser.add_argument("--line-profile", action="store_true", help="profile model speed layer by layer")
    parser.add_argument("--test", action="store_true", help="test all yolo*.yaml")
    opt = parser.parse_args()
    opt.cfg = check_yaml(opt.cfg)  # check YAML
    print_args(vars(opt))
    device = select_device(opt.device)

    # Create model
    im = torch.rand(opt.batch_size, 3, 640, 640).to(device)
    model = Model(opt.cfg).to(device)

    # Options
    if opt.line_profile:  # profile layer by layer
        model(im, profile=True)

    elif opt.profile:  # profile forward-backward
        results = profile(input=im, ops=[model], n=3)

    elif opt.test:  # test all models
        for cfg in Path(ROOT / "models").rglob("yolo*.yaml"):
            try:
                _ = Model(cfg)
            except Exception as e:
                print(f"Error in {cfg}: {e}")

    else:  # report fused model summary
        model.fuse()
utils/yolov5/models/yolov5l.yaml
@ -0,0 +1,49 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32
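  # each row holds three (w, h) anchor pairs in pixels at 640x640 input, one row per detection stride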

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 20 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 23 (P5/32-large)

    [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
utils/yolov5/models/yolov5m.yaml
@ -0,0 +1,49 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 0.67 # model depth multiple
width_multiple: 0.75 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 20 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 23 (P5/32-large)

    [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
utils/yolov5/models/yolov5n.yaml
@ -0,0 +1,49 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.25 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 20 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 23 (P5/32-large)

    [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
utils/yolov5/models/yolov5s.yaml
@ -0,0 +1,49 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 20 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 23 (P5/32-large)

    [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
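
# Scaling note: the n/s/m/l/x variants share this layout and differ only in the two multiples above;
# here (yolov5s) a module count of 9 becomes max(round(9 * 0.33), 1) = 3 repeats, and 1024 channels
# become make_divisible(1024 * 0.50, 8) = 512.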
utils/yolov5/models/yolov5x.yaml
@ -0,0 +1,49 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Parameters
nc: 80 # number of classes
depth_multiple: 1.33 # model depth multiple
width_multiple: 1.25 # layer channel multiple
anchors:
  - [10, 13, 16, 30, 33, 23] # P3/8
  - [30, 61, 62, 45, 59, 119] # P4/16
  - [116, 90, 156, 198, 373, 326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 20 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 23 (P5/32-large)

    [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
utils/yolov5/utils/__init__.py
@ -0,0 +1,97 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""utils/initialization."""

import contextlib
import platform
import threading


def emojis(str=""):
    """Returns an emoji-safe version of a string, stripped of emojis on Windows platforms."""
    return str.encode().decode("ascii", "ignore") if platform.system() == "Windows" else str


class TryExcept(contextlib.ContextDecorator):
    """A context manager and decorator for error handling that prints an optional message with emojis on exception."""

    def __init__(self, msg=""):
        """Initializes TryExcept with an optional message, used as a decorator or context manager for error handling."""
        self.msg = msg

    def __enter__(self):
        """Enter the runtime context related to this object for error handling with an optional message."""
        pass

    def __exit__(self, exc_type, value, traceback):
        """Context manager exit method that prints an error message with emojis if an exception occurred, always
        returns True.
        """
        if value:
            print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
        return True


def threaded(func):
    """Decorator @threaded to run a function in a separate thread, returning the thread instance."""

    def wrapper(*args, **kwargs):
        """Runs the decorated function in a separate daemon thread and returns the thread instance."""
        thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
        thread.start()
        return thread

    return wrapper
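

# Usage sketch (illustrative, names are hypothetical): a @threaded function returns its Thread
# handle immediately, so a caller that must wait should keep the handle and join it:
#     t = some_threaded_fn(args)
#     t.join()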


def join_threads(verbose=False):
    """
    Joins all daemon threads, optionally printing their names if verbose is True.

    Example: atexit.register(lambda: join_threads())
    """
    main_thread = threading.current_thread()
    for t in threading.enumerate():
        if t is not main_thread:
            if verbose:
                print(f"Joining thread {t.name}")
            t.join()


def notebook_init(verbose=True):
    """Initializes notebook environment by checking requirements, cleaning up, and displaying system info."""
    print("Checking setup...")

    import os
    import shutil

    from ultralytics.utils.checks import check_requirements

    from utils.general import check_font, is_colab
    from utils.torch_utils import select_device  # imports

    check_font()

    import psutil

    if check_requirements("wandb", install=False):
        os.system("pip uninstall -y wandb")  # eliminate unexpected account creation prompt with infinite hang
    if is_colab():
        shutil.rmtree("/content/sample_data", ignore_errors=True)  # remove colab /sample_data directory

    # System info
    display = None
    if verbose:
        gb = 1 << 30  # bytes to GiB (1024 ** 3)
        ram = psutil.virtual_memory().total
        total, used, free = shutil.disk_usage("/")
        with contextlib.suppress(Exception):  # clear display if ipython is installed
            from IPython import display

            display.clear_output()
        s = f"({os.cpu_count()} CPUs, {ram / gb:.1f} GB RAM, {(total - free) / gb:.1f}/{total / gb:.1f} GB disk)"
    else:
        s = ""

    select_device(newline=False)
    print(emojis(f"Setup complete ✅ {s}"))
    return display
utils/yolov5/utils/activations.py
@ -0,0 +1,134 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Activation functions."""

import torch
import torch.nn as nn
import torch.nn.functional as F


class SiLU(nn.Module):
    """Applies the Sigmoid-weighted Linear Unit (SiLU) activation function, also known as Swish."""

    @staticmethod
    def forward(x):
        """
        Applies the Sigmoid-weighted Linear Unit (SiLU) activation function.

        https://arxiv.org/pdf/1606.08415.pdf.
        """
        return x * torch.sigmoid(x)


class Hardswish(nn.Module):
    """Applies the Hardswish activation function, which is efficient for mobile and embedded devices."""

    @staticmethod
    def forward(x):
        """
        Applies the Hardswish activation function, compatible with TorchScript, CoreML, and ONNX.

        Equivalent to x * F.hardsigmoid(x).
        """
        return x * F.hardtanh(x + 3, 0.0, 6.0) / 6.0  # for TorchScript, CoreML and ONNX


class Mish(nn.Module):
    """Mish activation https://github.com/digantamisra98/Mish."""

    @staticmethod
    def forward(x):
        """Applies the Mish activation function, a smooth alternative to ReLU."""
        return x * F.softplus(x).tanh()


class MemoryEfficientMish(nn.Module):
    """Efficiently applies the Mish activation function using custom autograd for reduced memory usage."""

    class F(torch.autograd.Function):
        """Implements a custom autograd function for memory-efficient Mish activation."""

        @staticmethod
        def forward(ctx, x):
            """Applies the Mish activation function, a smooth ReLU alternative, to the input tensor `x`."""
            ctx.save_for_backward(x)
            return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))

        @staticmethod
        def backward(ctx, grad_output):
            """Computes the gradient of the Mish activation function with respect to input `x`."""
            x = ctx.saved_tensors[0]
            sx = torch.sigmoid(x)
            fx = F.softplus(x).tanh()
            return grad_output * (fx + x * sx * (1 - fx * fx))
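
        # Gradient identity used above: with f(x) = tanh(softplus(x)),
        # d/dx [x * f(x)] = f(x) + x * sigmoid(x) * (1 - f(x)^2),
        # since d/dx softplus(x) = sigmoid(x) and tanh'(u) = 1 - tanh(u)^2.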

    def forward(self, x):
        """Applies the Mish activation function to the input tensor `x`."""
        return self.F.apply(x)


class FReLU(nn.Module):
    """FReLU activation https://arxiv.org/abs/2007.11824."""

    def __init__(self, c1, k=3):  # ch_in, kernel
        """Initializes FReLU activation with channel `c1` and kernel size `k`."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c1, k, 1, 1, groups=c1, bias=False)
        self.bn = nn.BatchNorm2d(c1)

    def forward(self, x):
        """
        Applies FReLU activation with max operation between input and BN-convolved input.

        https://arxiv.org/abs/2007.11824
        """
        return torch.max(x, self.bn(self.conv(x)))


class AconC(nn.Module):
    """
    ACON activation (activate or not) function.

    AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter
    See "Activate or Not: Learning Customized Activation" https://arxiv.org/pdf/2009.04759.pdf.
    """

    def __init__(self, c1):
        """Initializes AconC with learnable parameters p1, p2, and beta for channel-wise activation control."""
        super().__init__()
        self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.beta = nn.Parameter(torch.ones(1, c1, 1, 1))

    def forward(self, x):
        """Applies AconC activation function with learnable parameters for channel-wise control on input tensor x."""
        dpx = (self.p1 - self.p2) * x
        return dpx * torch.sigmoid(self.beta * dpx) + self.p2 * x


class MetaAconC(nn.Module):
    """
    ACON activation (activate or not) function.

    MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, where beta is generated by a small network
    See "Activate or Not: Learning Customized Activation" https://arxiv.org/pdf/2009.04759.pdf.
    """

    def __init__(self, c1, k=1, s=1, r=16):
        """Initializes MetaAconC with params: channel_in (c1), kernel size (k=1), stride (s=1), reduction (r=16)."""
        super().__init__()
        c2 = max(r, c1 // r)
        self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
        self.fc1 = nn.Conv2d(c1, c2, k, s, bias=True)
        self.fc2 = nn.Conv2d(c2, c1, k, s, bias=True)
        # self.bn1 = nn.BatchNorm2d(c2)
        # self.bn2 = nn.BatchNorm2d(c1)

    def forward(self, x):
        """Applies a forward pass transforming input `x` using learnable parameters and sigmoid activation."""
        y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)
        # batch-size 1 bug/instabilities https://github.com/ultralytics/yolov5/issues/2891
        # beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y)))))  # bug/unstable
        beta = torch.sigmoid(self.fc2(self.fc1(y)))  # bug patch BN layers removed
        dpx = (self.p1 - self.p2) * x
        return dpx * torch.sigmoid(beta * dpx) + self.p2 * x
utils/yolov5/utils/augmentations.py
@ -0,0 +1,440 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Image augmentation functions."""

import math
import random

import cv2
import numpy as np
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from utils.yolov5.utils.general import LOGGER, check_version, colorstr, resample_segments, segment2box, xywhn2xyxy
from utils.yolov5.utils.metrics import bbox_ioa

IMAGENET_MEAN = 0.485, 0.456, 0.406  # RGB mean
IMAGENET_STD = 0.229, 0.224, 0.225  # RGB standard deviation


class Albumentations:
    """Provides optional data augmentation for YOLOv5 using Albumentations library if installed."""

    def __init__(self, size=640):
        """Initializes Albumentations class for optional data augmentation in YOLOv5 with specified input size."""
        self.transform = None
        prefix = colorstr("albumentations: ")
        try:
            import albumentations as A

            check_version(A.__version__, "1.0.3", hard=True)  # version requirement

            T = [
                A.RandomResizedCrop(height=size, width=size, scale=(0.8, 1.0), ratio=(0.9, 1.11), p=0.0),
                A.Blur(p=0.01),
                A.MedianBlur(p=0.01),
                A.ToGray(p=0.01),
                A.CLAHE(p=0.01),
                A.RandomBrightnessContrast(p=0.0),
                A.RandomGamma(p=0.0),
                A.ImageCompression(quality_lower=75, p=0.0),
            ]  # transforms
            self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))

            LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
        except ImportError:  # package not installed, skip
            pass
        except Exception as e:
            LOGGER.info(f"{prefix}{e}")

    def __call__(self, im, labels, p=1.0):
        """Applies transformations to an image and labels with probability `p`, returning updated image and labels."""
        if self.transform and random.random() < p:
            new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0])  # transformed
            im, labels = new["image"], np.array([[c, *b] for c, b in zip(new["class_labels"], new["bboxes"])])
        return im, labels


def normalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD, inplace=False):
    """
    Applies ImageNet normalization to RGB images in BCHW format, modifying them in-place if specified.

    Example: y = (x - mean) / std
    """
    return TF.normalize(x, mean, std, inplace=inplace)


def denormalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Reverses ImageNet normalization for BCHW format RGB images by applying `x = x * std + mean`."""
    for i in range(3):
        x[:, i] = x[:, i] * std[i] + mean[i]
    return x


def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
    """Applies HSV color-space augmentation to an image with random gains for hue, saturation, and value."""
    if hgain or sgain or vgain:
        r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
        hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV))
        dtype = im.dtype  # uint8

        x = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

        im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
        cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=im)  # no return needed


def hist_equalize(im, clahe=True, bgr=False):
    """Equalizes image histogram, with optional CLAHE, for BGR or RGB image with shape (n,m,3) and range 0-255."""
    yuv = cv2.cvtColor(im, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
    if clahe:
        c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        yuv[:, :, 0] = c.apply(yuv[:, :, 0])
    else:
        yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0])  # equalize Y channel histogram
    return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB)  # convert YUV image to RGB


def replicate(im, labels):
    """
    Replicates half of the smallest object labels in an image for data augmentation.

    Returns augmented image and labels.
    """
    h, w = im.shape[:2]
    boxes = labels[:, 1:].astype(int)
    x1, y1, x2, y2 = boxes.T
    s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
    for i in s.argsort()[: round(s.size * 0.5)]:  # smallest indices
        x1b, y1b, x2b, y2b = boxes[i]
        bh, bw = y2b - y1b, x2b - x1b
        yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
        x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
        im[y1a:y2a, x1a:x2a] = im[y1b:y2b, x1b:x2b]  # im4[ymin:ymax, xmin:xmax]
        labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)

    return im, labels


def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    """Resizes and pads image to new_shape with stride-multiple constraints, returns resized image, ratio, padding."""
    shape = im.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better val mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return im, ratio, (dw, dh)
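

# Worked example (for reference): letterboxing a 720x1280 frame to 640 with auto=True gives
# r = 0.5 and new_unpad = (640, 360); the 280 px of height padding reduces to 280 % 32 = 24,
# split as 12 px top and bottom, so the padded output is 384x640.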


def random_perspective(
    im, targets=(), segments=(), degrees=10, translate=0.1, scale=0.1, shear=10, perspective=0.0, border=(0, 0)
):
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]
    """Applies random perspective transformation to an image, modifying the image and corresponding labels."""
    height = im.shape[0] + border[0] * 2  # shape(h,w,c)
    width = im.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -im.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -im.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined rotation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            im = cv2.warpPerspective(im, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            im = cv2.warpAffine(im, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    if n := len(targets):
        use_segments = any(x.any() for x in segments) and len(segments) == n
        new = np.zeros((n, 4))
        if use_segments:  # warp segments
            segments = resample_segments(segments)  # upsample
            for i, segment in enumerate(segments):
                xy = np.ones((len(segment), 3))
                xy[:, :2] = segment
                xy = xy @ M.T  # transform
                xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]  # perspective rescale or affine

                # clip
                new[i] = segment2box(xy, width, height)

        else:  # warp boxes
            xy = np.ones((n * 4, 3))
            xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            xy = xy @ M.T  # transform
            xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine

            # create new boxes
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

            # clip
            new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
            new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)

        # filter candidates
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
        targets = targets[i]
        targets[:, 1:5] = new[i]

    return im, targets


def copy_paste(im, labels, segments, p=0.5):
    """
    Applies Copy-Paste augmentation by flipping and merging segments and labels on an image.

    Details at https://arxiv.org/abs/2012.07177.
    """
    n = len(segments)
    if p and n:
        h, w, c = im.shape  # height, width, channels
        im_new = np.zeros(im.shape, np.uint8)
        for j in random.sample(range(n), k=round(p * n)):
            l, s = labels[j], segments[j]
            box = w - l[3], l[2], w - l[1], l[4]
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            if (ioa < 0.30).all():  # allow 30% obscuration of existing labels
                labels = np.concatenate((labels, [[l[0], *box]]), 0)
                segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1))
                cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (1, 1, 1), cv2.FILLED)

        result = cv2.flip(im, 1)  # augment segments (flip left-right)
        i = cv2.flip(im_new, 1).astype(bool)
        im[i] = result[i]  # cv2.imwrite('debug.jpg', im)  # debug

    return im, labels, segments


def cutout(im, labels, p=0.5):
    """
    Applies cutout augmentation to an image with optional label adjustment, using random masks of varying sizes.

    Details at https://arxiv.org/abs/1708.04552.
    """
    if random.random() < p:
        h, w = im.shape[:2]
        scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
        for s in scales:
            mask_h = random.randint(1, int(h * s))  # create random masks
            mask_w = random.randint(1, int(w * s))

            # box
            xmin = max(0, random.randint(0, w) - mask_w // 2)
            ymin = max(0, random.randint(0, h) - mask_h // 2)
            xmax = min(w, xmin + mask_w)
            ymax = min(h, ymin + mask_h)

            # apply random color mask
            im[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

            # return unobscured labels
            if len(labels) and s > 0.03:
                box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
                ioa = bbox_ioa(box, xywhn2xyxy(labels[:, 1:5], w, h))  # intersection over area
                labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels


def mixup(im, labels, im2, labels2):
    """
    Applies MixUp augmentation by blending images and labels.

    See https://arxiv.org/pdf/1710.09412.pdf for details.
    """
    r = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0
    im = (im * r + im2 * (1 - r)).astype(np.uint8)
    labels = np.concatenate((labels, labels2), 0)
    return im, labels
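

# Mixing-ratio note: Beta(32, 32) is symmetric around 0.5 with standard deviation
# 1 / sqrt(260) ≈ 0.062, so blends stay close to an even 50/50 mix of the two images.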


def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
    """
    Filters bounding box candidates by minimum width-height threshold `wh_thr` (pixels), aspect ratio threshold
    `ar_thr`, and area ratio threshold `area_thr`.

    box1(4,n) is before augmentation, box2(4,n) is after augmentation.
    """
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates


def classify_albumentations(
    augment=True,
    size=224,
    scale=(0.08, 1.0),
    ratio=(0.75, 1.0 / 0.75),  # 0.75, 1.33
    hflip=0.5,
    vflip=0.0,
    jitter=0.4,
    mean=IMAGENET_MEAN,
    std=IMAGENET_STD,
    auto_aug=False,
):
    # YOLOv5 classification Albumentations (optional, only used if package is installed)
    """Sets up and returns Albumentations transforms for YOLOv5 classification tasks depending on augmentation
    settings.
    """
    prefix = colorstr("albumentations: ")
    try:
        import albumentations as A
        from albumentations.pytorch import ToTensorV2

        check_version(A.__version__, "1.0.3", hard=True)  # version requirement
        if augment:  # Resize and crop
            T = [A.RandomResizedCrop(height=size, width=size, scale=scale, ratio=ratio)]
            if auto_aug:
                # TODO: implement AugMix, AutoAug & RandAug in albumentation
                LOGGER.info(f"{prefix}auto augmentations are currently not supported")
            else:
                if hflip > 0:
                    T += [A.HorizontalFlip(p=hflip)]
                if vflip > 0:
                    T += [A.VerticalFlip(p=vflip)]
                if jitter > 0:
                    color_jitter = (float(jitter),) * 3  # repeat value for brightness, contrast, saturation, 0 hue
                    T += [A.ColorJitter(*color_jitter, 0)]
        else:  # Use fixed crop for eval set (reproducibility)
            T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
        T += [A.Normalize(mean=mean, std=std), ToTensorV2()]  # Normalize and convert to Tensor
        LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
        return A.Compose(T)

    except ImportError:  # package not installed, skip
        LOGGER.warning(f"{prefix}⚠️ not found, install with `pip install albumentations` (recommended)")
    except Exception as e:
        LOGGER.info(f"{prefix}{e}")


def classify_transforms(size=224):
    """Applies a series of transformations including center crop, ToTensor, and normalization for classification."""
    assert isinstance(size, int), f"ERROR: classify_transforms size {size} must be integer, not (list, tuple)"
    # T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
    return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])


class LetterBox:
    """Resizes and pads images to specified dimensions while maintaining aspect ratio for YOLOv5 preprocessing."""

    def __init__(self, size=(640, 640), auto=False, stride=32):
        """Initializes a LetterBox object for YOLOv5 image preprocessing with optional auto sizing and stride
        adjustment.
        """
        super().__init__()
        self.h, self.w = (size, size) if isinstance(size, int) else size
        self.auto = auto  # pass max size integer, automatically solve for short side using stride
        self.stride = stride  # used with auto

    def __call__(self, im):
        """
        Resizes and pads input image `im` (HWC format) to specified dimensions, maintaining aspect ratio.

        im = np.array HWC
        """
        imh, imw = im.shape[:2]
        r = min(self.h / imh, self.w / imw)  # ratio of new/old
        h, w = round(imh * r), round(imw * r)  # resized image
        # parenthesized so both branches unpack to (hs, ws); without parens, ws would always be self.w
        hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else (self.h, self.w)
|
||||
top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
|
||||
im_out = np.full((self.h, self.w, 3), 114, dtype=im.dtype)
|
||||
im_out[top : top + h, left : left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
|
||||
return im_out
|
||||
|
||||
|
||||
class CenterCrop:
|
||||
"""Applies center crop to an image, resizing it to the specified size while maintaining aspect ratio."""
|
||||
|
||||
def __init__(self, size=640):
|
||||
"""Initializes CenterCrop for image preprocessing, accepting single int or tuple for size, defaults to 640."""
|
||||
super().__init__()
|
||||
self.h, self.w = (size, size) if isinstance(size, int) else size
|
||||
|
||||
def __call__(self, im):
|
||||
"""
|
||||
Applies center crop to the input image and resizes it to a specified size, maintaining aspect ratio.
|
||||
|
||||
im = np.array HWC
|
||||
"""
|
||||
imh, imw = im.shape[:2]
|
||||
m = min(imh, imw) # min dimension
|
||||
top, left = (imh - m) // 2, (imw - m) // 2
|
||||
return cv2.resize(im[top : top + m, left : left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
|
||||
class ToTensor:
|
||||
"""Converts BGR np.array image from HWC to RGB CHW format, normalizes to [0, 1], and supports FP16 if half=True."""
|
||||
|
||||
def __init__(self, half=False):
|
||||
"""Initializes ToTensor for YOLOv5 image preprocessing, with optional half precision (half=True for FP16)."""
|
||||
super().__init__()
|
||||
self.half = half
|
||||
|
||||
def __call__(self, im):
|
||||
"""
|
||||
Converts BGR np.array image from HWC to RGB CHW format, and normalizes to [0, 1], with support for FP16 if
|
||||
`half=True`.
|
||||
|
||||
im = np.array HWC in BGR order
|
||||
"""
|
||||
im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
|
||||
im = torch.from_numpy(im) # to torch
|
||||
im = im.half() if self.half else im.float() # uint8 to fp16/32
|
||||
im /= 255.0 # 0-255 to 0.0-1.0
|
||||
return im
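A minimal usage sketch for the classification preprocessing above (the `bus.jpg` test image is a hypothetical local file, not part of this commit):

```python
# Sketch: compose classify_transforms on an arbitrary BGR image read by OpenCV.
import cv2

t = classify_transforms(size=224)  # CenterCrop -> ToTensor -> T.Normalize
im = cv2.imread("bus.jpg")  # hypothetical test image, HWC uint8, BGR order
x = t(im)  # CHW float32 tensor, ImageNet-normalized
print(x.shape)  # torch.Size([3, 224, 224])
```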
175 utils/yolov5/utils/autoanchor.py Normal file
@ -0,0 +1,175 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""AutoAnchor utils."""

import random

import numpy as np
import torch
import yaml
from tqdm import tqdm

from utils.yolov5.utils import TryExcept
from utils.yolov5.utils.general import LOGGER, TQDM_BAR_FORMAT, colorstr

PREFIX = colorstr("AutoAnchor: ")


def check_anchor_order(m):
    """Checks and corrects anchor order against stride in YOLOv5 Detect() module if necessary."""
    a = m.anchors.prod(-1).mean(-1).view(-1)  # mean anchor area per output layer
    da = a[-1] - a[0]  # delta a
    ds = m.stride[-1] - m.stride[0]  # delta s
    if da and (da.sign() != ds.sign()):  # same order
        LOGGER.info(f"{PREFIX}Reversing anchor order")
        m.anchors[:] = m.anchors.flip(0)


@TryExcept(f"{PREFIX}ERROR")
def check_anchors(dataset, model, thr=4.0, imgsz=640):
    """Evaluates anchor fit to dataset and adjusts if necessary, supporting customizable threshold and image size."""
    m = model.module.model[-1] if hasattr(model, "module") else model.model[-1]  # Detect()
    shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1))  # augment scale
    wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float()  # wh

    def metric(k):  # compute metric
        """Computes ratio metric, anchors above threshold, and best possible recall for YOLOv5 anchor evaluation."""
        r = wh[:, None] / k[None]
        x = torch.min(r, 1 / r).min(2)[0]  # ratio metric
        best = x.max(1)[0]  # best_x
        aat = (x > 1 / thr).float().sum(1).mean()  # anchors above threshold
        bpr = (best > 1 / thr).float().mean()  # best possible recall
        return bpr, aat

    stride = m.stride.to(m.anchors.device).view(-1, 1, 1)  # model strides
    anchors = m.anchors.clone() * stride  # current anchors
    bpr, aat = metric(anchors.cpu().view(-1, 2))
    s = f"\n{PREFIX}{aat:.2f} anchors/target, {bpr:.3f} Best Possible Recall (BPR). "
    if bpr > 0.98:  # threshold to recompute
        LOGGER.info(f"{s}Current anchors are a good fit to dataset ✅")
    else:
        LOGGER.info(f"{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...")
        na = m.anchors.numel() // 2  # number of anchors
        anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
        new_bpr = metric(anchors)[0]
        if new_bpr > bpr:  # replace anchors
            anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
            m.anchors[:] = anchors.clone().view_as(m.anchors)
            check_anchor_order(m)  # must be in pixel-space (not grid-space)
            m.anchors /= stride
            s = f"{PREFIX}Done ✅ (optional: update model *.yaml to use these anchors in the future)"
        else:
            s = f"{PREFIX}Done ⚠️ (original anchors better than new anchors, proceeding with original anchors)"
        LOGGER.info(s)


def kmean_anchors(dataset="./data/coco128.yaml", n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
    """
    Creates kmeans-evolved anchors from training dataset.

    Arguments:
        dataset: path to data.yaml, or a loaded dataset
        n: number of anchors
        img_size: image size used for training
        thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
        gen: generations to evolve anchors using genetic algorithm
        verbose: print all results

    Return:
        k: kmeans evolved anchors

    Usage:
        from utils.autoanchor import *; _ = kmean_anchors()
    """
    from scipy.cluster.vq import kmeans

    npr = np.random
    thr = 1 / thr

    def metric(k, wh):  # compute metrics
        """Computes ratio metric, anchors above threshold, and best possible recall for YOLOv5 anchor evaluation."""
        r = wh[:, None] / k[None]
        x = torch.min(r, 1 / r).min(2)[0]  # ratio metric
        # x = wh_iou(wh, torch.tensor(k))  # iou metric
        return x, x.max(1)[0]  # x, best_x

    def anchor_fitness(k):  # mutation fitness
        """Evaluates fitness of YOLOv5 anchors by computing recall and ratio metrics for an anchor evolution process."""
        _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
        return (best * (best > thr).float()).mean()  # fitness

    def print_results(k, verbose=True):
        """Sorts and logs kmeans-evolved anchor metrics and best possible recall values for YOLOv5 anchor evaluation."""
        k = k[np.argsort(k.prod(1))]  # sort small to large
        x, best = metric(k, wh0)
        bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n  # best possible recall, anch > thr
        s = (
            f"{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n"
            f"{PREFIX}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, "
            f"past_thr={x[x > thr].mean():.3f}-mean: "
        )
        for x in k:
            s += "%i,%i, " % (round(x[0]), round(x[1]))
        if verbose:
            LOGGER.info(s[:-2])
        return k

    if isinstance(dataset, str):  # *.yaml file
        with open(dataset, errors="ignore") as f:
            data_dict = yaml.safe_load(f)  # model dict
        from utils.dataloaders import LoadImagesAndLabels

        dataset = LoadImagesAndLabels(data_dict["train"], augment=True, rect=True)

    # Get label wh
    shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])  # wh

    # Filter
    i = (wh0 < 3.0).any(1).sum()
    if i:
        LOGGER.info(f"{PREFIX}WARNING ⚠️ Extremely small objects found: {i} of {len(wh0)} labels are <3 pixels in size")
    wh = wh0[(wh0 >= 2.0).any(1)].astype(np.float32)  # filter > 2 pixels
    # wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1)  # multiply by random scale 0-1

    # Kmeans init
    try:
        LOGGER.info(f"{PREFIX}Running kmeans for {n} anchors on {len(wh)} points...")
        assert n <= len(wh)  # apply overdetermined constraint
        s = wh.std(0)  # sigmas for whitening
        k = kmeans(wh / s, n, iter=30)[0] * s  # points
        assert n == len(k)  # kmeans may return fewer points than requested if wh is insufficient or too similar
    except Exception:
        LOGGER.warning(f"{PREFIX}WARNING ⚠️ switching strategies from kmeans to random init")
        k = np.sort(npr.rand(n * 2)).reshape(n, 2) * img_size  # random init
    wh, wh0 = (torch.tensor(x, dtype=torch.float32) for x in (wh, wh0))
    k = print_results(k, verbose=False)

    # Plot
    # k, d = [None] * 20, [None] * 20
    # for i in tqdm(range(1, 21)):
    #     k[i-1], d[i-1] = kmeans(wh / s, i)  # points, mean distance
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
    # ax = ax.ravel()
    # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))  # plot wh
    # ax[0].hist(wh[wh[:, 0]<100, 0], 400)
    # ax[1].hist(wh[wh[:, 1]<100, 1], 400)
    # fig.savefig('wh.png', dpi=200)

    # Evolve
    f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1  # fitness, shape, mutation probability, sigma
    pbar = tqdm(range(gen), bar_format=TQDM_BAR_FORMAT)  # progress bar
    for _ in pbar:
        v = np.ones(sh)
        while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
            v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
        kg = (k.copy() * v).clip(min=2.0)
        fg = anchor_fitness(kg)
        if fg > f:
            f, k = fg, kg.copy()
            pbar.desc = f"{PREFIX}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}"
            if verbose:
                print_results(k, verbose)

    return print_results(k).astype(np.float32)
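A hedged usage sketch for the anchor evolution above, following the docstring's `Usage` note (assumes a COCO128-style data YAML and scipy are available):

```python
# Sketch: evolve 9 anchors for 640-px training from a data YAML (assumed present).
from utils.yolov5.utils.autoanchor import kmean_anchors

k = kmean_anchors("./data/coco128.yaml", n=9, img_size=640, thr=4.0, gen=1000)
print(k.reshape(3, 3, 2))  # 3 detection layers x 3 anchors x (w, h), sorted small to large
```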
70 utils/yolov5/utils/autobatch.py Normal file
@ -0,0 +1,70 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Auto-batch utils."""

from copy import deepcopy

import numpy as np
import torch

from utils.yolov5.utils.general import LOGGER, colorstr
from utils.yolov5.utils.torch_utils import profile


def check_train_batch_size(model, imgsz=640, amp=True):
    """Checks and computes optimal training batch size for YOLOv5 model, given image size and AMP setting."""
    with torch.cuda.amp.autocast(amp):
        return autobatch(deepcopy(model).train(), imgsz)  # compute optimal batch size


def autobatch(model, imgsz=640, fraction=0.8, batch_size=16):
    """Estimates optimal YOLOv5 batch size using `fraction` of CUDA memory."""
    # Usage:
    #     import torch
    #     from utils.autobatch import autobatch
    #     model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False)
    #     print(autobatch(model))

    # Check device
    prefix = colorstr("AutoBatch: ")
    LOGGER.info(f"{prefix}Computing optimal batch size for --imgsz {imgsz}")
    device = next(model.parameters()).device  # get model device
    if device.type == "cpu":
        LOGGER.info(f"{prefix}CUDA not detected, using default CPU batch-size {batch_size}")
        return batch_size
    if torch.backends.cudnn.benchmark:
        LOGGER.info(f"{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}")
        return batch_size

    # Inspect CUDA memory
    gb = 1 << 30  # bytes to GiB (1024 ** 3)
    d = str(device).upper()  # 'CUDA:0'
    properties = torch.cuda.get_device_properties(device)  # device properties
    t = properties.total_memory / gb  # GiB total
    r = torch.cuda.memory_reserved(device) / gb  # GiB reserved
    a = torch.cuda.memory_allocated(device) / gb  # GiB allocated
    f = t - (r + a)  # GiB free
    LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free")

    # Profile batch sizes
    batch_sizes = [1, 2, 4, 8, 16]
    try:
        img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
        results = profile(img, model, n=3, device=device)
    except Exception as e:
        LOGGER.warning(f"{prefix}{e}")

    # Fit a solution
    y = [x[2] for x in results if x]  # memory [2]
    p = np.polyfit(batch_sizes[: len(y)], y, deg=1)  # first degree polynomial fit
    b = int((f * fraction - p[1]) / p[0])  # y intercept (optimal batch size)
    if None in results:  # some sizes failed
        i = results.index(None)  # first fail index
        if b >= batch_sizes[i]:  # y intercept above failure point
            b = batch_sizes[max(i - 1, 0)]  # select prior safe point
    if b < 1 or b > 1024:  # b outside of safe range
        b = batch_size
        LOGGER.warning(f"{prefix}WARNING ⚠️ CUDA anomaly detected, recommend restart environment and retry command.")

    fraction = (np.polyval(p, b) + r + a) / t  # actual fraction predicted
    LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅")
    return b
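The fit above models memory use as y ≈ p[0]·b + p[1] and solves p[0]·b + p[1] = f·fraction for the batch size b expected to hit the target memory fraction. A standalone sketch of that calculation with synthetic measurements (all numbers hypothetical):

```python
# Sketch of the AutoBatch linear fit on synthetic profiling data.
import numpy as np

batch_sizes = [1, 2, 4, 8, 16]
mem_gib = [0.5, 0.9, 1.7, 3.3, 6.5]  # hypothetical GiB used per batch size
f_free, fraction = 10.0, 0.8  # hypothetical free GiB and target utilization

p = np.polyfit(batch_sizes, mem_gib, deg=1)  # mem ~= p[0] * b + p[1]
b = int((f_free * fraction - p[1]) / p[0])  # batch size at target memory
print(b)  # -> 19 with these numbers
```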
1 utils/yolov5/utils/aws/__init__.py Normal file
@ -0,0 +1 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
26 utils/yolov5/utils/aws/mime.sh Normal file
@ -0,0 +1,26 @@
# AWS EC2 instance startup 'MIME' script https://aws.amazon.com/premiumsupport/knowledge-center/execute-user-data-ec2/
# This script will run on every instance restart, not only on first start
# --- DO NOT COPY ABOVE COMMENTS WHEN PASTING INTO USERDATA ---

Content-Type: multipart/mixed; boundary="//"
MIME-Version: 1.0

--//
Content-Type: text/cloud-config; charset="us-ascii"
MIME-Version: 1.0
Content-Transfer-Encoding: 7bit
Content-Disposition: attachment; filename="cloud-config.txt"

#cloud-config
cloud_final_modules:
  - [scripts-user, always]

--//
Content-Type: text/x-shellscript; charset="us-ascii"
MIME-Version: 1.0
Content-Transfer-Encoding: 7bit
Content-Disposition: attachment; filename="userdata.txt"

#!/bin/bash
# --- paste contents of userdata.sh here ---
--//
42 utils/yolov5/utils/aws/resume.py Normal file
@ -0,0 +1,42 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Resume all interrupted trainings in yolov5/ dir including DDP trainings
# Usage: $ python utils/aws/resume.py

import os
import sys
from pathlib import Path

import torch
import yaml

FILE = Path(__file__).resolve()
ROOT = FILE.parents[2]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH

port = 0  # --master_port
path = Path("").resolve()
for last in path.rglob("*/**/last.pt"):
    ckpt = torch.load(last)
    if ckpt["optimizer"] is None:
        continue

    # Load opt.yaml
    with open(last.parent.parent / "opt.yaml", errors="ignore") as f:
        opt = yaml.safe_load(f)

    # Get device count
    d = opt["device"].split(",")  # devices
    nd = len(d)  # number of devices
    ddp = nd > 1 or (nd == 0 and torch.cuda.device_count() > 1)  # distributed data parallel

    if ddp:  # multi-GPU
        port += 1
        cmd = f"python -m torch.distributed.run --nproc_per_node {nd} --master_port {port} train.py --resume {last}"
    else:  # single-GPU
        cmd = f"python train.py --resume {last}"

    cmd += " > /dev/null 2>&1 &"  # redirect output to dev/null and run in daemon thread
    print(cmd)
    os.system(cmd)
27 utils/yolov5/utils/aws/userdata.sh Normal file
@ -0,0 +1,27 @@
#!/bin/bash
# AWS EC2 instance startup script https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/user-data.html
# This script will run only once on first instance start (for a re-start script see mime.sh)
# /home/ubuntu (ubuntu) or /home/ec2-user (amazon-linux) is working dir
# Use >300 GB SSD

cd home/ubuntu
if [ ! -d yolov5 ]; then
  echo "Running first-time script."  # install dependencies, download COCO, pull Docker
  git clone https://github.com/ultralytics/yolov5 -b master && sudo chmod -R 777 yolov5
  cd yolov5
  bash data/scripts/get_coco.sh && echo "COCO done." &
  sudo docker pull ultralytics/yolov5:latest && echo "Docker done." &
  python -m pip install --upgrade pip && pip install -r requirements.txt && python detect.py && echo "Requirements done." &
  wait && echo "All tasks done."  # finish background tasks
else
  echo "Running re-start script."  # resume interrupted runs
  i=0
  list=$(sudo docker ps -qa)  # container list i.e. $'one\ntwo\nthree\nfour'
  while IFS= read -r id; do
    ((i++))
    echo "restarting container $i: $id"
    sudo docker start $id
    # sudo docker exec -it $id python train.py --resume  # single-GPU
    sudo docker exec -d $id python utils/aws/resume.py  # multi-scenario
  done <<<"$list"
fi
72 utils/yolov5/utils/callbacks.py Normal file
@ -0,0 +1,72 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Callback utils."""

import threading


class Callbacks:
    """Handles all registered callbacks for YOLOv5 Hooks."""

    def __init__(self):
        """Initializes a Callbacks object to manage registered YOLOv5 training event hooks."""
        self._callbacks = {
            "on_pretrain_routine_start": [],
            "on_pretrain_routine_end": [],
            "on_train_start": [],
            "on_train_epoch_start": [],
            "on_train_batch_start": [],
            "optimizer_step": [],
            "on_before_zero_grad": [],
            "on_train_batch_end": [],
            "on_train_epoch_end": [],
            "on_val_start": [],
            "on_val_batch_start": [],
            "on_val_image_end": [],
            "on_val_batch_end": [],
            "on_val_end": [],
            "on_fit_epoch_end": [],  # fit = train + val
            "on_model_save": [],
            "on_train_end": [],
            "on_params_update": [],
            "teardown": [],
        }
        self.stop_training = False  # set True to interrupt training

    def register_action(self, hook, name="", callback=None):
        """
        Register a new action to a callback hook.

        Args:
            hook: The callback hook name to register the action to
            name: The name of the action for later reference
            callback: The callback to fire
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        assert callable(callback), f"callback '{callback}' is not callable"
        self._callbacks[hook].append({"name": name, "callback": callback})

    def get_registered_actions(self, hook=None):
        """
        Returns all the registered actions by callback hook.

        Args:
            hook: The name of the hook to check, defaults to all
        """
        return self._callbacks[hook] if hook else self._callbacks

    def run(self, hook, *args, thread=False, **kwargs):
        """
        Loop through the registered actions and fire all callbacks, on the main thread or, if `thread=True`, in a
        daemon thread.

        Args:
            hook: The name of the hook whose callbacks should fire
            args: Arguments to receive from YOLOv5
            thread: (boolean) Run callbacks in daemon thread
            kwargs: Keyword Arguments to receive from YOLOv5
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        for logger in self._callbacks[hook]:
            if thread:
                threading.Thread(target=logger["callback"], args=args, kwargs=kwargs, daemon=True).start()
            else:
                logger["callback"](*args, **kwargs)
1378 utils/yolov5/utils/dataloaders.py Normal file
File diff suppressed because it is too large
73 utils/yolov5/utils/docker/Dockerfile Normal file
@ -0,0 +1,73 @@
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
# Builds ultralytics/yolov5:latest image on DockerHub https://hub.docker.com/r/ultralytics/yolov5
# Image is CUDA-optimized for YOLOv5 single/multi-GPU training and inference

# Start FROM PyTorch image https://hub.docker.com/r/pytorch/pytorch
FROM pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime

# Downloads to user config dir
ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Arial.Unicode.ttf /root/.config/Ultralytics/

# Install linux packages
ENV DEBIAN_FRONTEND noninteractive
RUN apt update
RUN TZ=Etc/UTC apt install -y tzdata
RUN apt install --no-install-recommends -y gcc git zip curl htop libgl1 libglib2.0-0 libpython3-dev gnupg
# RUN alias python=python3

# Security updates
# https://security.snyk.io/vuln/SNYK-UBUNTU1804-OPENSSL-3314796
RUN apt upgrade --no-install-recommends -y openssl

# Create working directory
RUN rm -rf /usr/src/app && mkdir -p /usr/src/app
WORKDIR /usr/src/app

# Copy contents
COPY . /usr/src/app

# Install pip packages
COPY requirements.txt .
RUN python3 -m pip install --upgrade pip wheel
RUN pip install --no-cache -r requirements.txt albumentations comet gsutil notebook \
    coremltools onnx onnx-simplifier onnxruntime 'openvino-dev>=2023.0'
    # tensorflow tensorflowjs \

# Set environment variables
ENV OMP_NUM_THREADS=1

# Cleanup
ENV DEBIAN_FRONTEND teletype


# Usage Examples -------------------------------------------------------------------------------------------------------

# Build and Push
# t=ultralytics/yolov5:latest && sudo docker build -f utils/docker/Dockerfile -t $t . && sudo docker push $t

# Pull and Run
# t=ultralytics/yolov5:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all $t

# Pull and Run with local directory access
# t=ultralytics/yolov5:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all -v "$(pwd)"/datasets:/usr/src/datasets $t

# Kill all
# sudo docker kill $(sudo docker ps -q)

# Kill all image-based
# sudo docker kill $(sudo docker ps -qa --filter ancestor=ultralytics/yolov5:latest)

# DockerHub tag update
# t=ultralytics/yolov5:latest tnew=ultralytics/yolov5:v6.2 && sudo docker pull $t && sudo docker tag $t $tnew && sudo docker push $tnew

# Clean up
# sudo docker system prune -a --volumes

# Update Ubuntu drivers
# https://www.maketecheasier.com/install-nvidia-drivers-ubuntu/

# DDP test
# python -m torch.distributed.run --nproc_per_node 2 --master_port 1 train.py --epochs 3

# GCP VM from Image
# docker.io/ultralytics/yolov5:latest
40 utils/yolov5/utils/docker/Dockerfile-arm64 Normal file
@ -0,0 +1,40 @@
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
# Builds ultralytics/yolov5:latest-arm64 image on DockerHub https://hub.docker.com/r/ultralytics/yolov5
# Image is aarch64-compatible for Apple M1 and other ARM architectures i.e. Jetson Nano and Raspberry Pi

# Start FROM Ubuntu image https://hub.docker.com/_/ubuntu
FROM arm64v8/ubuntu:22.10

# Downloads to user config dir
ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Arial.Unicode.ttf /root/.config/Ultralytics/

# Install linux packages
ENV DEBIAN_FRONTEND noninteractive
RUN apt update
RUN TZ=Etc/UTC apt install -y tzdata
RUN apt install --no-install-recommends -y python3-pip git zip curl htop gcc libgl1 libglib2.0-0 libpython3-dev
# RUN alias python=python3

# Install pip packages
COPY requirements.txt .
RUN python3 -m pip install --upgrade pip wheel
RUN pip install --no-cache -r requirements.txt albumentations gsutil notebook \
    coremltools onnx onnxruntime
    # tensorflow-aarch64 tensorflowjs \

# Create working directory
RUN mkdir -p /usr/src/app
WORKDIR /usr/src/app

# Copy contents
COPY . /usr/src/app
ENV DEBIAN_FRONTEND teletype


# Usage Examples -------------------------------------------------------------------------------------------------------

# Build and Push
# t=ultralytics/yolov5:latest-arm64 && sudo docker build --platform linux/arm64 -f utils/docker/Dockerfile-arm64 -t $t . && sudo docker push $t

# Pull and Run
# t=ultralytics/yolov5:latest-arm64 && sudo docker pull $t && sudo docker run -it --ipc=host -v "$(pwd)"/datasets:/usr/src/datasets $t
42 utils/yolov5/utils/docker/Dockerfile-cpu Normal file
@ -0,0 +1,42 @@
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
# Builds ultralytics/yolov5:latest-cpu image on DockerHub https://hub.docker.com/r/ultralytics/yolov5
# Image is CPU-optimized for ONNX, OpenVINO and PyTorch YOLOv5 deployments

# Start FROM Ubuntu image https://hub.docker.com/_/ubuntu
FROM ubuntu:23.10

# Downloads to user config dir
ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Arial.Unicode.ttf /root/.config/Ultralytics/

# Install linux packages
# g++ required to build 'tflite_support' and 'lap' packages, libusb-1.0-0 required for 'tflite_support' package
RUN apt update \
    && apt install --no-install-recommends -y python3-pip git zip curl htop libgl1 libglib2.0-0 libpython3-dev gnupg g++ libusb-1.0-0
# RUN alias python=python3

# Remove python3.11/EXTERNALLY-MANAGED or use 'pip install --break-system-packages' to avoid 'externally-managed-environment' Ubuntu nightly error
RUN rm -rf /usr/lib/python3.11/EXTERNALLY-MANAGED

# Install pip packages
COPY requirements.txt .
RUN python3 -m pip install --upgrade pip wheel
RUN pip install --no-cache -r requirements.txt albumentations gsutil notebook \
    coremltools onnx onnx-simplifier onnxruntime 'openvino-dev>=2023.0' \
    # tensorflow tensorflowjs \
    --extra-index-url https://download.pytorch.org/whl/cpu

# Create working directory
RUN mkdir -p /usr/src/app
WORKDIR /usr/src/app

# Copy contents
COPY . /usr/src/app


# Usage Examples -------------------------------------------------------------------------------------------------------

# Build and Push
# t=ultralytics/yolov5:latest-cpu && sudo docker build -f utils/docker/Dockerfile-cpu -t $t . && sudo docker push $t

# Pull and Run
# t=ultralytics/yolov5:latest-cpu && sudo docker pull $t && sudo docker run -it --ipc=host -v "$(pwd)"/datasets:/usr/src/datasets $t
136 utils/yolov5/utils/downloads.py Normal file
@ -0,0 +1,136 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Download utils."""

import logging
import subprocess
import urllib
from pathlib import Path

import requests
import torch


def is_url(url, check=True):
    """Determines if a string is a URL and optionally checks its existence online, returning a boolean."""
    try:
        url = str(url)
        result = urllib.parse.urlparse(url)
        assert all([result.scheme, result.netloc])  # check if is url
        return (urllib.request.urlopen(url).getcode() == 200) if check else True  # check if exists online
    except (AssertionError, urllib.request.HTTPError):
        return False


def gsutil_getsize(url=""):
    """
    Returns the size in bytes of a file at a Google Cloud Storage URL using `gsutil du`.

    Returns 0 if the command fails or output is empty.
    """
    output = subprocess.check_output(["gsutil", "du", url], shell=True, encoding="utf-8")
    return int(output.split()[0]) if output else 0


def url_getsize(url="https://ultralytics.com/images/bus.jpg"):
    """Returns the size in bytes of a downloadable file at a given URL; defaults to -1 if not found."""
    response = requests.head(url, allow_redirects=True)
    return int(response.headers.get("content-length", -1))


def curl_download(url, filename, *, silent: bool = False) -> bool:
    """Download a file from a url to a filename using curl."""
    silent_option = "sS" if silent else ""  # silent
    proc = subprocess.run(
        [
            "curl",
            "-#",
            f"-{silent_option}L",
            url,
            "--output",
            filename,
            "--retry",
            "9",
            "-C",
            "-",
        ]
    )
    return proc.returncode == 0


def safe_download(file, url, url2=None, min_bytes=1e0, error_msg=""):
    """
    Downloads a file from a URL (or alternate URL) to a specified path if file is above a minimum size.

    Removes incomplete downloads.
    """
    from utils.general import LOGGER

    file = Path(file)
    assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
    try:  # url1
        LOGGER.info(f"Downloading {url} to {file}...")
        torch.hub.download_url_to_file(url, str(file), progress=LOGGER.level <= logging.INFO)
        assert file.exists() and file.stat().st_size > min_bytes, assert_msg  # check
    except Exception as e:  # url2
        if file.exists():
            file.unlink()  # remove partial downloads
        LOGGER.info(f"ERROR: {e}\nRe-attempting {url2 or url} to {file}...")
        # curl download, retry and resume on fail
        curl_download(url2 or url, file)
    finally:
        if not file.exists() or file.stat().st_size < min_bytes:  # check
            if file.exists():
                file.unlink()  # remove partial downloads
            LOGGER.info(f"ERROR: {assert_msg}\n{error_msg}")
        LOGGER.info("")


def attempt_download(file, repo="ultralytics/yolov5", release="v7.0"):
    """Downloads a file from GitHub release assets or via direct URL if not found locally, supporting backup
    versions.
    """
    from utils.general import LOGGER

    def github_assets(repository, version="latest"):
        """Fetches GitHub repository release tag and asset names using the GitHub API."""
        if version != "latest":
            version = f"tags/{version}"  # i.e. tags/v7.0
        response = requests.get(f"https://api.github.com/repos/{repository}/releases/{version}").json()  # github api
        return response["tag_name"], [x["name"] for x in response["assets"]]  # tag, assets

    file = Path(str(file).strip().replace("'", ""))
    if not file.exists():
        # URL specified
        name = Path(urllib.parse.unquote(str(file))).name  # decode '%2F' to '/' etc.
        if str(file).startswith(("http:/", "https:/")):  # download
            url = str(file).replace(":/", "://")  # Pathlib turns :// -> :/
            file = name.split("?")[0]  # parse authentication https://url.com/file.txt?auth...
            if Path(file).is_file():
                LOGGER.info(f"Found {url} locally at {file}")  # file already exists
            else:
                safe_download(file=file, url=url, min_bytes=1e5)
            return file

        # GitHub assets
        assets = [f"yolov5{size}{suffix}.pt" for size in "nsmlx" for suffix in ("", "6", "-cls", "-seg")]  # default
        try:
            tag, assets = github_assets(repo, release)
        except Exception:
            try:
                tag, assets = github_assets(repo)  # latest release
            except Exception:
                try:
                    tag = subprocess.check_output("git tag", shell=True, stderr=subprocess.STDOUT).decode().split()[-1]
                except Exception:
                    tag = release

        if name in assets:
            file.parent.mkdir(parents=True, exist_ok=True)  # make parent dir (if required)
            safe_download(
                file,
                url=f"https://github.com/{repo}/releases/download/{tag}/{name}",
                min_bytes=1e5,
                error_msg=f"{file} missing, try downloading from https://github.com/{repo}/releases/{tag}",
            )

    return str(file)
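A hedged usage sketch (network access assumed; the checkpoint is fetched from the repo's GitHub release assets if not found locally):

```python
# Sketch: resolve a checkpoint name to a local path, downloading if needed.
from utils.yolov5.utils.downloads import attempt_download

weights = attempt_download("yolov5s.pt")  # downloads from GitHub releases on miss
print(weights)  # local path to the checkpoint
```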
70 utils/yolov5/utils/flask_rest_api/README.md Normal file
@ -0,0 +1,70 @@
# Flask REST API

[REST](https://en.wikipedia.org/wiki/Representational_state_transfer) [API](https://en.wikipedia.org/wiki/API)s are commonly used to expose Machine Learning (ML) models to other services. This folder contains an example REST API created using Flask to expose the YOLOv5s model from [PyTorch Hub](https://pytorch.org/hub/ultralytics_yolov5/).

## Requirements

[Flask](https://palletsprojects.com/projects/flask/) is required. Install with:

```shell
$ pip install Flask
```

## Run

After Flask installation run:

```shell
$ python3 restapi.py --port 5000
```

Then use [curl](https://curl.se/) to perform a request:

```shell
$ curl -X POST -F image=@zidane.jpg 'http://localhost:5000/v1/object-detection/yolov5s'
```

The model inference results are returned as a JSON response:

```json
[
  {
    "class": 0,
    "confidence": 0.8900438547,
    "height": 0.9318675399,
    "name": "person",
    "width": 0.3264600933,
    "xcenter": 0.7438579798,
    "ycenter": 0.5207948685
  },
  {
    "class": 0,
    "confidence": 0.8440024257,
    "height": 0.7155083418,
    "name": "person",
    "width": 0.6546785235,
    "xcenter": 0.427829951,
    "ycenter": 0.6334488392
  },
  {
    "class": 27,
    "confidence": 0.3771208823,
    "height": 0.3902671337,
    "name": "tie",
    "width": 0.0696444362,
    "xcenter": 0.3675483763,
    "ycenter": 0.7991207838
  },
  {
    "class": 27,
    "confidence": 0.3527112305,
    "height": 0.1540903747,
    "name": "tie",
    "width": 0.0336618312,
    "xcenter": 0.7814827561,
    "ycenter": 0.5065554976
  }
]
```

An example Python script to perform inference using [requests](https://docs.python-requests.org/en/master/) is given in `example_request.py`.
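A minimal sketch of such a request script (assumes the server above is running and a local `zidane.jpg`; the repo's `example_request.py` is the authoritative version):

```python
# Sketch: POST an image to the running Flask API and pretty-print the detections.
import pprint

import requests

DETECTION_URL = "http://localhost:5000/v1/object-detection/yolov5s"

with open("zidane.jpg", "rb") as f:  # assumed local test image
    image_data = f.read()

response = requests.post(DETECTION_URL, files={"image": image_data}).json()
pprint.pprint(response)
```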
Some files were not shown because too many files have changed in this diff