优化目前版本中的问题

This commit is contained in:
2025-03-10 13:56:25 +08:00
parent 758082db14
commit b4b1085403
11 changed files with 211 additions and 20 deletions

View File

@ -132,8 +132,7 @@ def get_log_pager(detect_log_pager: ProjectDetectLogPager, session: Session = De
:return:
"""
result = pdc.get_log_pager(detect_log_pager, session)
result = jsonable_encoder(result)
return rc.response_success(data=result)
return rc.response_success_pager(result)
@detect.get("/get_log_imgs/{log_id}")

View File

@ -141,9 +141,11 @@ def del_label(label_id: int, session: Session = Depends(get_db)):
@project.post("/up_proj_img")
def upload_project_image(project_id: int = Form(...),
files: List[UploadFile] = File(...),
img_type: str = Form(...),
session: Session = Depends(get_db)):
"""
上传项目图片
:param img_type:
:param files: 文件图片
:param project_id:
:param session:
@ -152,10 +154,10 @@ def upload_project_image(project_id: int = Form(...),
project_info = pic.get_project_by_id(project_id, session)
if project_info is None:
return rc.response_error("项目查询错误,请刷新页面后再试")
is_check, file_name = ps.check_image_name(project_id, files, session)
is_check, file_name = ps.check_image_name(project_id, img_type, files, session)
if not is_check:
return rc.response_error(msg="存在重名的图片文件:" + file_name)
ps.upload_project_image(project_info, files, session)
ps.upload_project_image(project_info,img_type, files, session)
return rc.response_success(msg="上传成功")
@ -228,6 +230,16 @@ async def run_train(project_id: int, session: Session = Depends(get_db)):
return rc.response_error("项目查询错误")
if project_info.project_status == '1':
return rc.response_error("项目当前存在训练进程,请稍后再试")
train_img_count = pimc.get_image_count(project_id, 'train', session)
if train_img_count == 0:
return rc.response_error("请先上传训练图片")
if train_img_count < 10:
return rc.response_error("训练图片少于10张请继续上传训练图片")
val_img_count = pimc.get_image_count(project_id, 'val', session)
if val_img_count == 0:
return rc.response_error("请先上传验证图片")
if val_img_count < 5:
return rc.response_error("验证图片少于5张请继续上传验证图片")
data, project_name, name = ps.run_train_yolo(project_info, session)
return StreamingResponse(
ps.run_commend(data, project_name, name, 50, project_id, session),

View File

@ -0,0 +1,27 @@
from fastapi import APIRouter, WebSocket
from app.websocket.web_socket_server import room_manager
web_socket = APIRouter()
@web_socket.websocket("/{room}")
async def websocket_room(websocket: WebSocket, room: str):
"""
websocket 房间管理
:param websocket:
:param room:
:return:
"""
await websocket.accept()
await room_manager.add_to_room(room, websocket)
try:
while True:
data = await websocket.receive_text()
await room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
except Exception as e:
print(f"连接关闭: {e}")
finally:
await room_manager.remove_from_room(room, websocket)
await websocket.close()

View File

@ -1,5 +1,6 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import asyncio
from app.application.token_middleware import TokenMiddleware
from app.application.logger_middleware import LoggerMiddleware
@ -9,6 +10,9 @@ from app.api.sys.sys_user_api import user
from app.api.business.project_train_api import project
from app.api.common.view_img import view
from app.api.business.project_detect_api import detect
from app.api.business.websocket_api import web_socket
from app.util.ps_util import get_server_json
from app.websocket.web_socket_server import room_manager
my_app = FastAPI()
@ -30,10 +34,27 @@ my_app.add_middleware(
my_app.add_middleware(LoggerMiddleware)
my_app.add_middleware(TokenMiddleware)
my_app.include_router(view, tags=["查看图片"])
my_app.include_router(login, prefix="/login", tags=["用户登录接口"])
my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"])
my_app.include_router(view, tags=["查看图片"])
my_app.include_router(user, prefix="/user", tags=["用户管理API"])
my_app.include_router(project, prefix="/proj", tags=["项目训练API"])
my_app.include_router(detect, prefix="/detect", tags=["项目推理API"])
my_app.include_router(web_socket, prefix="/ws", tags=["websocket管理"])
# fastapi定时执行任务
async def periodic_task():
while True:
server_json = get_server_json()
await room_manager.send_to_room('server-info', server_json)
await asyncio.sleep(2) # 每 10 秒执行
@my_app.on_event("startup")
async def start_periodic_task():
# 在后台启动异步任务
asyncio.create_task(periodic_task())

View File

@ -45,6 +45,7 @@ class ProjectImage(DbCommon):
项目图片表
"""
__tablename__ = "project_image"
img_type: Mapped[str] = mapped_column(String(10))
file_name: Mapped[str] = mapped_column(String(64), nullable=False)
image_url: Mapped[str] = mapped_column(String(255), nullable=False)
thumb_image_url: Mapped[str] = mapped_column(String(255), nullable=False)

View File

@ -33,7 +33,8 @@ def get_image_pager2(image: ProjectImage, session: Session):
)
.outerjoin(subquery, piModel.id == subquery.c.image_id)
)
query = query.filter(piModel.project_id == image.project_id).order_by(asc(piModel.id))
query = query.filter(piModel.project_id == image.project_id)\
.filter(piModel.img_type == image.img_type).order_by(asc(piModel.id))
pager = get_pager(query, image.pagerNum, image.pagerSize)
datas = []
for result in pager.data:
@ -44,16 +45,20 @@ def get_image_pager2(image: ProjectImage, session: Session):
return pager
def check_img_name(project_id: int, file_name: str, session: Session):
def check_img_name(project_id: int, img_type: str, file_name: str, session: Session):
"""
根据项目id和文件名称进行查重
:param project_id:
:param img_type:
:param file_name:
:param session:
:return:
"""
image = session.query(piModel)\
.filter(piModel.project_id == project_id).filter(piModel.file_name == file_name).first()
.filter(piModel.project_id == project_id)\
.filter(piModel.file_name == file_name)\
.filter(piModel.img_type == img_type)\
.first()
if image is not None:
return False
return True
@ -72,17 +77,26 @@ def get_img_url(image_id: int, session: Session):
return sour_url, thumb_url
def get_image_list(project_id: int, session: Session):
query = session.query(piModel).filter(piModel.project_id == project_id).order_by(asc(piModel.id))
def get_image_list(project_id: int, img_type: str, session: Session):
query = session.query(piModel).filter(piModel.project_id == project_id)\
.filter(piModel.img_type == img_type)\
.order_by(asc(piModel.id))
image_list = [ProjectImage.from_orm(image) for image in query.all()]
return image_list
def get_images(project_id: int, session: Session):
query = session.query(piModel).filter(piModel.project_id == project_id).order_by(asc(piModel.id))
def get_images(project_id: int, img_type: str, session: Session):
query = session.query(piModel).filter(piModel.project_id == project_id)\
.filter(piModel.img_type == img_type).order_by(asc(piModel.id))
return query.all()
def get_image_count(project_id: int, img_type: str, session: Session):
query = session.query(piModel).filter(piModel.project_id == project_id)\
.filter(piModel.img_type == img_type).order_by(asc(piModel.id))
return query.count()
def add_image(image: ProjectImage, session: Session):
session.add(image)
session.commit()

View File

@ -6,6 +6,7 @@ from typing import Optional, List
class ProjectImage(BaseModel):
id: Optional[int] = Field(None, description="id")
project_id: Optional[int] = Field(..., description="项目id")
img_type: Optional[str] = Field(None, description="图片类别")
file_name: Optional[str] = Field(None, description="文件名称")
create_time: Optional[datetime] = Field(None, description="上传时间")
@ -32,6 +33,7 @@ class ProjectImageOut(BaseModel):
class ProjectImagePager(BaseModel):
project_id: Optional[int] = Field(..., description="项目id")
img_type: Optional[str] = Field(None, description="图片类别")
pagerNum: Optional[int] = Field(None, description="当前页码")
pagerSize: Optional[int] = Field(None, description="每页数量")

View File

@ -36,17 +36,18 @@ def add_project(info: ProjectInfoIn, session: Session, user_id: int):
return project_info.id
def check_image_name(project_id: int, files: List[UploadFile], session: Session):
def check_image_name(project_id: int, img_type: str, files: List[UploadFile], session: Session):
for file in files:
if not pimc.check_img_name(project_id, file.filename, session):
if not pimc.check_img_name(project_id, img_type, file.filename, session):
return False, file.filename
return True, None
def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile], session: Session):
def upload_project_image(project_info: ProjectInfoOut, img_type: str, files: List[UploadFile], session: Session):
"""
上传项目的图片
:param files: 上传的图片
:param img_type: 上传的图片类别
:param project_info: 项目信息
:param session:
:return:
@ -56,6 +57,7 @@ def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile],
image = ProjectImage()
image.project_id = project_info.id
image.file_name = file.filename
image.img_type = img_type
# 保存原图
path = os.save_images(images_url, project_info.project_no, file=file)
image.image_url = path
@ -124,11 +126,9 @@ def run_train_yolo(project_info: ProjectInfoOut, session: Session):
:param session: 数据库session
:return:
"""
# 先获取项目的所有图片
project_images = pimc.get_images(project_info.id, session)
# 将图片根据,根据3:1的比例将图片分成trainval的两个数组
project_images_train, project_images_val = split_array(project_images)
# 先查询两个图片列表
project_images_train = pimc.get_images(project_info.id, 'train', session)
project_images_val = pimc.get_images(project_info.id, 'val', session)
# 得到训练版本
version_path = 'v' + str(project_info.train_version + 1)

77
app/util/ps_util.py Normal file
View File

@ -0,0 +1,77 @@
import psutil
import platform
import json
from datetime import datetime
def get_server_info():
info = {}
# 1. 系统基本信息
info["system"] = {
"OS": platform.system(),
"OS版本": platform.version(),
"主机名": platform.node(),
"架构": platform.machine(),
"处理器型号": platform.processor(),
"启动时间": datetime.fromtimestamp(psutil.boot_time()).strftime("%Y-%m-%d %H:%M:%S")
}
# 2. CPU 信息
cpu_usage = psutil.cpu_percent(interval=1)
info["cpu"] = {
"物理核心数": psutil.cpu_count(logical=False),
"逻辑核心数": psutil.cpu_count(logical=True),
"当前使用率 (%)": cpu_usage,
"每个核心使用率": psutil.cpu_percent(interval=1, percpu=True)
}
# 3. 内存信息
mem = psutil.virtual_memory()
info["memory"] = {
"总内存 (GB)": round(mem.total / (1024**3), 2),
"可用内存 (GB)": round(mem.available / (1024**3), 2),
"使用率 (%)": mem.percent
}
# 4. 磁盘信息
disks = []
for partition in psutil.disk_partitions():
usage = psutil.disk_usage(partition.mountpoint)
disks.append({
"设备": partition.device,
"挂载点": partition.mountpoint,
"文件系统": partition.fstype,
"总空间 (GB)": round(usage.total / (1024**3), 2),
"已用空间 (GB)": round(usage.used / (1024**3), 2),
"使用率 (%)": usage.percent
})
info["disks"] = disks
# 5. 网络信息
net = psutil.net_io_counters()
info["network"] = {
"发送流量 (MB)": round(net.bytes_sent / (1024**2), 2),
"接收流量 (MB)": round(net.bytes_recv / (1024**2), 2)
}
# 6. 进程信息示例前5个高CPU进程
processes = []
for proc in psutil.process_iter(['pid', 'name', 'cpu_percent']):
if len(processes) >= 5:
break
if proc.info['cpu_percent'] > 0:
processes.append({
"PID": proc.info['pid'],
"进程名": proc.info['name'],
"CPU使用率 (%)": proc.info['cpu_percent']
})
info["top_processes"] = sorted(processes, key=lambda x: x["CPU使用率 (%)"], reverse=True)
return info
def get_server_json():
server_info = get_server_info()
return json.dumps(server_info, indent=2, ensure_ascii=False)

View File

@ -0,0 +1,37 @@
from fastapi import WebSocket
class SocketManager:
def __init__(self):
self.rooms = {}
async def add_to_room(self, room: str, websocket: WebSocket):
if room not in self.rooms:
self.rooms[room] = []
self.rooms[room].append(websocket)
async def remove_from_room(self, room: str, websocket: WebSocket):
if room in self.rooms:
self.rooms[room].remove(websocket)
if len(self.rooms[room]) == 0:
del self.rooms[room]
async def broadcast_to_room(self, room: str, message: str, exclude_websocket: WebSocket = None):
if room in self.rooms:
for ws in self.rooms[room]:
if ws != exclude_websocket:
try:
await ws.send_text(message)
except:
await self.remove_from_room(room, ws)
async def send_to_room(self, room: str, message: str):
if room in self.rooms:
for ws in self.rooms[room]:
try:
await ws.send_text(message)
except:
await self.remove_from_room(room, ws)
room_manager = SocketManager()

View File

@ -16,6 +16,7 @@ bcrypt==3.2.0
pymysql==1.0.2
pynvml==12.0.0
requests-toolbelt==1.0.0
python-socketio == 5.12.1
# YOLOV5 ----------------------------------------------------------------------