优化目前版本中的问题
This commit is contained in:
@ -132,8 +132,7 @@ def get_log_pager(detect_log_pager: ProjectDetectLogPager, session: Session = De
|
||||
:return:
|
||||
"""
|
||||
result = pdc.get_log_pager(detect_log_pager, session)
|
||||
result = jsonable_encoder(result)
|
||||
return rc.response_success(data=result)
|
||||
return rc.response_success_pager(result)
|
||||
|
||||
|
||||
@detect.get("/get_log_imgs/{log_id}")
|
||||
|
@ -141,9 +141,11 @@ def del_label(label_id: int, session: Session = Depends(get_db)):
|
||||
@project.post("/up_proj_img")
|
||||
def upload_project_image(project_id: int = Form(...),
|
||||
files: List[UploadFile] = File(...),
|
||||
img_type: str = Form(...),
|
||||
session: Session = Depends(get_db)):
|
||||
"""
|
||||
上传项目图片
|
||||
:param img_type:
|
||||
:param files: 文件图片
|
||||
:param project_id:
|
||||
:param session:
|
||||
@ -152,10 +154,10 @@ def upload_project_image(project_id: int = Form(...),
|
||||
project_info = pic.get_project_by_id(project_id, session)
|
||||
if project_info is None:
|
||||
return rc.response_error("项目查询错误,请刷新页面后再试")
|
||||
is_check, file_name = ps.check_image_name(project_id, files, session)
|
||||
is_check, file_name = ps.check_image_name(project_id, img_type, files, session)
|
||||
if not is_check:
|
||||
return rc.response_error(msg="存在重名的图片文件:" + file_name)
|
||||
ps.upload_project_image(project_info, files, session)
|
||||
ps.upload_project_image(project_info,img_type, files, session)
|
||||
return rc.response_success(msg="上传成功")
|
||||
|
||||
|
||||
@ -228,6 +230,16 @@ async def run_train(project_id: int, session: Session = Depends(get_db)):
|
||||
return rc.response_error("项目查询错误")
|
||||
if project_info.project_status == '1':
|
||||
return rc.response_error("项目当前存在训练进程,请稍后再试")
|
||||
train_img_count = pimc.get_image_count(project_id, 'train', session)
|
||||
if train_img_count == 0:
|
||||
return rc.response_error("请先上传训练图片")
|
||||
if train_img_count < 10:
|
||||
return rc.response_error("训练图片少于10张,请继续上传训练图片")
|
||||
val_img_count = pimc.get_image_count(project_id, 'val', session)
|
||||
if val_img_count == 0:
|
||||
return rc.response_error("请先上传验证图片")
|
||||
if val_img_count < 5:
|
||||
return rc.response_error("验证图片少于5张,请继续上传验证图片")
|
||||
data, project_name, name = ps.run_train_yolo(project_info, session)
|
||||
return StreamingResponse(
|
||||
ps.run_commend(data, project_name, name, 50, project_id, session),
|
||||
|
27
app/api/business/websocket_api.py
Normal file
27
app/api/business/websocket_api.py
Normal file
@ -0,0 +1,27 @@
|
||||
from fastapi import APIRouter, WebSocket
|
||||
|
||||
from app.websocket.web_socket_server import room_manager
|
||||
|
||||
|
||||
web_socket = APIRouter()
|
||||
|
||||
|
||||
@web_socket.websocket("/{room}")
|
||||
async def websocket_room(websocket: WebSocket, room: str):
|
||||
"""
|
||||
websocket 房间管理
|
||||
:param websocket:
|
||||
:param room:
|
||||
:return:
|
||||
"""
|
||||
await websocket.accept()
|
||||
await room_manager.add_to_room(room, websocket)
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
await room_manager.broadcast_to_room(room, data, exclude_websocket=websocket)
|
||||
except Exception as e:
|
||||
print(f"连接关闭: {e}")
|
||||
finally:
|
||||
await room_manager.remove_from_room(room, websocket)
|
||||
await websocket.close()
|
@ -1,5 +1,6 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import asyncio
|
||||
|
||||
from app.application.token_middleware import TokenMiddleware
|
||||
from app.application.logger_middleware import LoggerMiddleware
|
||||
@ -9,6 +10,9 @@ from app.api.sys.sys_user_api import user
|
||||
from app.api.business.project_train_api import project
|
||||
from app.api.common.view_img import view
|
||||
from app.api.business.project_detect_api import detect
|
||||
from app.api.business.websocket_api import web_socket
|
||||
from app.util.ps_util import get_server_json
|
||||
from app.websocket.web_socket_server import room_manager
|
||||
|
||||
my_app = FastAPI()
|
||||
|
||||
@ -30,10 +34,27 @@ my_app.add_middleware(
|
||||
my_app.add_middleware(LoggerMiddleware)
|
||||
my_app.add_middleware(TokenMiddleware)
|
||||
|
||||
my_app.include_router(view, tags=["查看图片"])
|
||||
my_app.include_router(login, prefix="/login", tags=["用户登录接口"])
|
||||
my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"])
|
||||
my_app.include_router(view, tags=["查看图片"])
|
||||
my_app.include_router(user, prefix="/user", tags=["用户管理API"])
|
||||
my_app.include_router(project, prefix="/proj", tags=["项目训练API"])
|
||||
my_app.include_router(detect, prefix="/detect", tags=["项目推理API"])
|
||||
my_app.include_router(web_socket, prefix="/ws", tags=["websocket管理"])
|
||||
|
||||
|
||||
# fastapi定时执行任务
|
||||
async def periodic_task():
|
||||
while True:
|
||||
server_json = get_server_json()
|
||||
await room_manager.send_to_room('server-info', server_json)
|
||||
await asyncio.sleep(2) # 每 10 秒执行
|
||||
|
||||
|
||||
@my_app.on_event("startup")
|
||||
async def start_periodic_task():
|
||||
# 在后台启动异步任务
|
||||
asyncio.create_task(periodic_task())
|
||||
|
||||
|
||||
|
||||
|
@ -45,6 +45,7 @@ class ProjectImage(DbCommon):
|
||||
项目图片表
|
||||
"""
|
||||
__tablename__ = "project_image"
|
||||
img_type: Mapped[str] = mapped_column(String(10))
|
||||
file_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
image_url: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
thumb_image_url: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
@ -33,7 +33,8 @@ def get_image_pager2(image: ProjectImage, session: Session):
|
||||
)
|
||||
.outerjoin(subquery, piModel.id == subquery.c.image_id)
|
||||
)
|
||||
query = query.filter(piModel.project_id == image.project_id).order_by(asc(piModel.id))
|
||||
query = query.filter(piModel.project_id == image.project_id)\
|
||||
.filter(piModel.img_type == image.img_type).order_by(asc(piModel.id))
|
||||
pager = get_pager(query, image.pagerNum, image.pagerSize)
|
||||
datas = []
|
||||
for result in pager.data:
|
||||
@ -44,16 +45,20 @@ def get_image_pager2(image: ProjectImage, session: Session):
|
||||
return pager
|
||||
|
||||
|
||||
def check_img_name(project_id: int, file_name: str, session: Session):
|
||||
def check_img_name(project_id: int, img_type: str, file_name: str, session: Session):
|
||||
"""
|
||||
根据项目id和文件名称进行查重
|
||||
:param project_id:
|
||||
:param img_type:
|
||||
:param file_name:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
image = session.query(piModel)\
|
||||
.filter(piModel.project_id == project_id).filter(piModel.file_name == file_name).first()
|
||||
.filter(piModel.project_id == project_id)\
|
||||
.filter(piModel.file_name == file_name)\
|
||||
.filter(piModel.img_type == img_type)\
|
||||
.first()
|
||||
if image is not None:
|
||||
return False
|
||||
return True
|
||||
@ -72,17 +77,26 @@ def get_img_url(image_id: int, session: Session):
|
||||
return sour_url, thumb_url
|
||||
|
||||
|
||||
def get_image_list(project_id: int, session: Session):
|
||||
query = session.query(piModel).filter(piModel.project_id == project_id).order_by(asc(piModel.id))
|
||||
def get_image_list(project_id: int, img_type: str, session: Session):
|
||||
query = session.query(piModel).filter(piModel.project_id == project_id)\
|
||||
.filter(piModel.img_type == img_type)\
|
||||
.order_by(asc(piModel.id))
|
||||
image_list = [ProjectImage.from_orm(image) for image in query.all()]
|
||||
return image_list
|
||||
|
||||
|
||||
def get_images(project_id: int, session: Session):
|
||||
query = session.query(piModel).filter(piModel.project_id == project_id).order_by(asc(piModel.id))
|
||||
def get_images(project_id: int, img_type: str, session: Session):
|
||||
query = session.query(piModel).filter(piModel.project_id == project_id)\
|
||||
.filter(piModel.img_type == img_type).order_by(asc(piModel.id))
|
||||
return query.all()
|
||||
|
||||
|
||||
def get_image_count(project_id: int, img_type: str, session: Session):
|
||||
query = session.query(piModel).filter(piModel.project_id == project_id)\
|
||||
.filter(piModel.img_type == img_type).order_by(asc(piModel.id))
|
||||
return query.count()
|
||||
|
||||
|
||||
def add_image(image: ProjectImage, session: Session):
|
||||
session.add(image)
|
||||
session.commit()
|
||||
|
@ -6,6 +6,7 @@ from typing import Optional, List
|
||||
class ProjectImage(BaseModel):
|
||||
id: Optional[int] = Field(None, description="id")
|
||||
project_id: Optional[int] = Field(..., description="项目id")
|
||||
img_type: Optional[str] = Field(None, description="图片类别")
|
||||
file_name: Optional[str] = Field(None, description="文件名称")
|
||||
create_time: Optional[datetime] = Field(None, description="上传时间")
|
||||
|
||||
@ -32,6 +33,7 @@ class ProjectImageOut(BaseModel):
|
||||
|
||||
class ProjectImagePager(BaseModel):
|
||||
project_id: Optional[int] = Field(..., description="项目id")
|
||||
img_type: Optional[str] = Field(None, description="图片类别")
|
||||
pagerNum: Optional[int] = Field(None, description="当前页码")
|
||||
pagerSize: Optional[int] = Field(None, description="每页数量")
|
||||
|
||||
|
@ -36,17 +36,18 @@ def add_project(info: ProjectInfoIn, session: Session, user_id: int):
|
||||
return project_info.id
|
||||
|
||||
|
||||
def check_image_name(project_id: int, files: List[UploadFile], session: Session):
|
||||
def check_image_name(project_id: int, img_type: str, files: List[UploadFile], session: Session):
|
||||
for file in files:
|
||||
if not pimc.check_img_name(project_id, file.filename, session):
|
||||
if not pimc.check_img_name(project_id, img_type, file.filename, session):
|
||||
return False, file.filename
|
||||
return True, None
|
||||
|
||||
|
||||
def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile], session: Session):
|
||||
def upload_project_image(project_info: ProjectInfoOut, img_type: str, files: List[UploadFile], session: Session):
|
||||
"""
|
||||
上传项目的图片
|
||||
:param files: 上传的图片
|
||||
:param img_type: 上传的图片类别
|
||||
:param project_info: 项目信息
|
||||
:param session:
|
||||
:return:
|
||||
@ -56,6 +57,7 @@ def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile],
|
||||
image = ProjectImage()
|
||||
image.project_id = project_info.id
|
||||
image.file_name = file.filename
|
||||
image.img_type = img_type
|
||||
# 保存原图
|
||||
path = os.save_images(images_url, project_info.project_no, file=file)
|
||||
image.image_url = path
|
||||
@ -124,11 +126,9 @@ def run_train_yolo(project_info: ProjectInfoOut, session: Session):
|
||||
:param session: 数据库session
|
||||
:return:
|
||||
"""
|
||||
# 先获取项目的所有图片
|
||||
project_images = pimc.get_images(project_info.id, session)
|
||||
|
||||
# 将图片根据,根据3:1的比例将图片分成train:val的两个数组
|
||||
project_images_train, project_images_val = split_array(project_images)
|
||||
# 先查询两个图片列表
|
||||
project_images_train = pimc.get_images(project_info.id, 'train', session)
|
||||
project_images_val = pimc.get_images(project_info.id, 'val', session)
|
||||
|
||||
# 得到训练版本
|
||||
version_path = 'v' + str(project_info.train_version + 1)
|
||||
|
77
app/util/ps_util.py
Normal file
77
app/util/ps_util.py
Normal file
@ -0,0 +1,77 @@
|
||||
import psutil
|
||||
import platform
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def get_server_info():
|
||||
|
||||
info = {}
|
||||
|
||||
# 1. 系统基本信息
|
||||
info["system"] = {
|
||||
"OS": platform.system(),
|
||||
"OS版本": platform.version(),
|
||||
"主机名": platform.node(),
|
||||
"架构": platform.machine(),
|
||||
"处理器型号": platform.processor(),
|
||||
"启动时间": datetime.fromtimestamp(psutil.boot_time()).strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
# 2. CPU 信息
|
||||
cpu_usage = psutil.cpu_percent(interval=1)
|
||||
info["cpu"] = {
|
||||
"物理核心数": psutil.cpu_count(logical=False),
|
||||
"逻辑核心数": psutil.cpu_count(logical=True),
|
||||
"当前使用率 (%)": cpu_usage,
|
||||
"每个核心使用率": psutil.cpu_percent(interval=1, percpu=True)
|
||||
}
|
||||
|
||||
# 3. 内存信息
|
||||
mem = psutil.virtual_memory()
|
||||
info["memory"] = {
|
||||
"总内存 (GB)": round(mem.total / (1024**3), 2),
|
||||
"可用内存 (GB)": round(mem.available / (1024**3), 2),
|
||||
"使用率 (%)": mem.percent
|
||||
}
|
||||
|
||||
# 4. 磁盘信息
|
||||
disks = []
|
||||
for partition in psutil.disk_partitions():
|
||||
usage = psutil.disk_usage(partition.mountpoint)
|
||||
disks.append({
|
||||
"设备": partition.device,
|
||||
"挂载点": partition.mountpoint,
|
||||
"文件系统": partition.fstype,
|
||||
"总空间 (GB)": round(usage.total / (1024**3), 2),
|
||||
"已用空间 (GB)": round(usage.used / (1024**3), 2),
|
||||
"使用率 (%)": usage.percent
|
||||
})
|
||||
info["disks"] = disks
|
||||
|
||||
# 5. 网络信息
|
||||
net = psutil.net_io_counters()
|
||||
info["network"] = {
|
||||
"发送流量 (MB)": round(net.bytes_sent / (1024**2), 2),
|
||||
"接收流量 (MB)": round(net.bytes_recv / (1024**2), 2)
|
||||
}
|
||||
|
||||
# 6. 进程信息(示例:前5个高CPU进程)
|
||||
processes = []
|
||||
for proc in psutil.process_iter(['pid', 'name', 'cpu_percent']):
|
||||
if len(processes) >= 5:
|
||||
break
|
||||
if proc.info['cpu_percent'] > 0:
|
||||
processes.append({
|
||||
"PID": proc.info['pid'],
|
||||
"进程名": proc.info['name'],
|
||||
"CPU使用率 (%)": proc.info['cpu_percent']
|
||||
})
|
||||
info["top_processes"] = sorted(processes, key=lambda x: x["CPU使用率 (%)"], reverse=True)
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def get_server_json():
|
||||
server_info = get_server_info()
|
||||
return json.dumps(server_info, indent=2, ensure_ascii=False)
|
37
app/websocket/web_socket_server.py
Normal file
37
app/websocket/web_socket_server.py
Normal file
@ -0,0 +1,37 @@
|
||||
from fastapi import WebSocket
|
||||
|
||||
|
||||
class SocketManager:
|
||||
def __init__(self):
|
||||
self.rooms = {}
|
||||
|
||||
async def add_to_room(self, room: str, websocket: WebSocket):
|
||||
if room not in self.rooms:
|
||||
self.rooms[room] = []
|
||||
self.rooms[room].append(websocket)
|
||||
|
||||
async def remove_from_room(self, room: str, websocket: WebSocket):
|
||||
if room in self.rooms:
|
||||
self.rooms[room].remove(websocket)
|
||||
if len(self.rooms[room]) == 0:
|
||||
del self.rooms[room]
|
||||
|
||||
async def broadcast_to_room(self, room: str, message: str, exclude_websocket: WebSocket = None):
|
||||
if room in self.rooms:
|
||||
for ws in self.rooms[room]:
|
||||
if ws != exclude_websocket:
|
||||
try:
|
||||
await ws.send_text(message)
|
||||
except:
|
||||
await self.remove_from_room(room, ws)
|
||||
|
||||
async def send_to_room(self, room: str, message: str):
|
||||
if room in self.rooms:
|
||||
for ws in self.rooms[room]:
|
||||
try:
|
||||
await ws.send_text(message)
|
||||
except:
|
||||
await self.remove_from_room(room, ws)
|
||||
|
||||
|
||||
room_manager = SocketManager()
|
@ -16,6 +16,7 @@ bcrypt==3.2.0
|
||||
pymysql==1.0.2
|
||||
pynvml==12.0.0
|
||||
requests-toolbelt==1.0.0
|
||||
python-socketio == 5.12.1
|
||||
|
||||
|
||||
# YOLOV5 ----------------------------------------------------------------------
|
||||
|
Reference in New Issue
Block a user