diff --git a/app/api/business/project_detect_api.py b/app/api/business/project_detect_api.py index 15cf362..e911db2 100644 --- a/app/api/business/project_detect_api.py +++ b/app/api/business/project_detect_api.py @@ -132,8 +132,7 @@ def get_log_pager(detect_log_pager: ProjectDetectLogPager, session: Session = De :return: """ result = pdc.get_log_pager(detect_log_pager, session) - result = jsonable_encoder(result) - return rc.response_success(data=result) + return rc.response_success_pager(result) @detect.get("/get_log_imgs/{log_id}") diff --git a/app/api/business/project_train_api.py b/app/api/business/project_train_api.py index 65ecc52..9ebc0b2 100644 --- a/app/api/business/project_train_api.py +++ b/app/api/business/project_train_api.py @@ -141,9 +141,11 @@ def del_label(label_id: int, session: Session = Depends(get_db)): @project.post("/up_proj_img") def upload_project_image(project_id: int = Form(...), files: List[UploadFile] = File(...), + img_type: str = Form(...), session: Session = Depends(get_db)): """ 上传项目图片 + :param img_type: :param files: 文件图片 :param project_id: :param session: @@ -152,10 +154,10 @@ def upload_project_image(project_id: int = Form(...), project_info = pic.get_project_by_id(project_id, session) if project_info is None: return rc.response_error("项目查询错误,请刷新页面后再试") - is_check, file_name = ps.check_image_name(project_id, files, session) + is_check, file_name = ps.check_image_name(project_id, img_type, files, session) if not is_check: return rc.response_error(msg="存在重名的图片文件:" + file_name) - ps.upload_project_image(project_info, files, session) + ps.upload_project_image(project_info,img_type, files, session) return rc.response_success(msg="上传成功") @@ -228,6 +230,16 @@ async def run_train(project_id: int, session: Session = Depends(get_db)): return rc.response_error("项目查询错误") if project_info.project_status == '1': return rc.response_error("项目当前存在训练进程,请稍后再试") + train_img_count = pimc.get_image_count(project_id, 'train', session) + if train_img_count == 0: + return rc.response_error("请先上传训练图片") + if train_img_count < 10: + return rc.response_error("训练图片少于10张,请继续上传训练图片") + val_img_count = pimc.get_image_count(project_id, 'val', session) + if val_img_count == 0: + return rc.response_error("请先上传验证图片") + if val_img_count < 5: + return rc.response_error("验证图片少于5张,请继续上传验证图片") data, project_name, name = ps.run_train_yolo(project_info, session) return StreamingResponse( ps.run_commend(data, project_name, name, 50, project_id, session), diff --git a/app/api/business/websocket_api.py b/app/api/business/websocket_api.py new file mode 100644 index 0000000..006ad0d --- /dev/null +++ b/app/api/business/websocket_api.py @@ -0,0 +1,27 @@ +from fastapi import APIRouter, WebSocket + +from app.websocket.web_socket_server import room_manager + + +web_socket = APIRouter() + + +@web_socket.websocket("/{room}") +async def websocket_room(websocket: WebSocket, room: str): + """ + websocket 房间管理 + :param websocket: + :param room: + :return: + """ + await websocket.accept() + await room_manager.add_to_room(room, websocket) + try: + while True: + data = await websocket.receive_text() + await room_manager.broadcast_to_room(room, data, exclude_websocket=websocket) + except Exception as e: + print(f"连接关闭: {e}") + finally: + await room_manager.remove_from_room(room, websocket) + await websocket.close() diff --git a/app/application/app.py b/app/application/app.py index dc5af44..2cd6cf9 100644 --- a/app/application/app.py +++ b/app/application/app.py @@ -1,5 +1,6 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +import asyncio from app.application.token_middleware import TokenMiddleware from app.application.logger_middleware import LoggerMiddleware @@ -9,6 +10,9 @@ from app.api.sys.sys_user_api import user from app.api.business.project_train_api import project from app.api.common.view_img import view from app.api.business.project_detect_api import detect +from app.api.business.websocket_api import web_socket +from app.util.ps_util import get_server_json +from app.websocket.web_socket_server import room_manager my_app = FastAPI() @@ -30,10 +34,27 @@ my_app.add_middleware( my_app.add_middleware(LoggerMiddleware) my_app.add_middleware(TokenMiddleware) +my_app.include_router(view, tags=["查看图片"]) my_app.include_router(login, prefix="/login", tags=["用户登录接口"]) my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"]) -my_app.include_router(view, tags=["查看图片"]) my_app.include_router(user, prefix="/user", tags=["用户管理API"]) my_app.include_router(project, prefix="/proj", tags=["项目训练API"]) my_app.include_router(detect, prefix="/detect", tags=["项目推理API"]) +my_app.include_router(web_socket, prefix="/ws", tags=["websocket管理"]) + + +# fastapi定时执行任务 +async def periodic_task(): + while True: + server_json = get_server_json() + await room_manager.send_to_room('server-info', server_json) + await asyncio.sleep(2) # 每 10 秒执行 + + +@my_app.on_event("startup") +async def start_periodic_task(): + # 在后台启动异步任务 + asyncio.create_task(periodic_task()) + + diff --git a/app/model/bussiness_model.py b/app/model/bussiness_model.py index 16195be..512d2fa 100644 --- a/app/model/bussiness_model.py +++ b/app/model/bussiness_model.py @@ -45,6 +45,7 @@ class ProjectImage(DbCommon): 项目图片表 """ __tablename__ = "project_image" + img_type: Mapped[str] = mapped_column(String(10)) file_name: Mapped[str] = mapped_column(String(64), nullable=False) image_url: Mapped[str] = mapped_column(String(255), nullable=False) thumb_image_url: Mapped[str] = mapped_column(String(255), nullable=False) diff --git a/app/model/crud/project_image_crud.py b/app/model/crud/project_image_crud.py index aeb8316..e8ffbfd 100644 --- a/app/model/crud/project_image_crud.py +++ b/app/model/crud/project_image_crud.py @@ -33,7 +33,8 @@ def get_image_pager2(image: ProjectImage, session: Session): ) .outerjoin(subquery, piModel.id == subquery.c.image_id) ) - query = query.filter(piModel.project_id == image.project_id).order_by(asc(piModel.id)) + query = query.filter(piModel.project_id == image.project_id)\ + .filter(piModel.img_type == image.img_type).order_by(asc(piModel.id)) pager = get_pager(query, image.pagerNum, image.pagerSize) datas = [] for result in pager.data: @@ -44,16 +45,20 @@ def get_image_pager2(image: ProjectImage, session: Session): return pager -def check_img_name(project_id: int, file_name: str, session: Session): +def check_img_name(project_id: int, img_type: str, file_name: str, session: Session): """ 根据项目id和文件名称进行查重 :param project_id: + :param img_type: :param file_name: :param session: :return: """ image = session.query(piModel)\ - .filter(piModel.project_id == project_id).filter(piModel.file_name == file_name).first() + .filter(piModel.project_id == project_id)\ + .filter(piModel.file_name == file_name)\ + .filter(piModel.img_type == img_type)\ + .first() if image is not None: return False return True @@ -72,17 +77,26 @@ def get_img_url(image_id: int, session: Session): return sour_url, thumb_url -def get_image_list(project_id: int, session: Session): - query = session.query(piModel).filter(piModel.project_id == project_id).order_by(asc(piModel.id)) +def get_image_list(project_id: int, img_type: str, session: Session): + query = session.query(piModel).filter(piModel.project_id == project_id)\ + .filter(piModel.img_type == img_type)\ + .order_by(asc(piModel.id)) image_list = [ProjectImage.from_orm(image) for image in query.all()] return image_list -def get_images(project_id: int, session: Session): - query = session.query(piModel).filter(piModel.project_id == project_id).order_by(asc(piModel.id)) +def get_images(project_id: int, img_type: str, session: Session): + query = session.query(piModel).filter(piModel.project_id == project_id)\ + .filter(piModel.img_type == img_type).order_by(asc(piModel.id)) return query.all() +def get_image_count(project_id: int, img_type: str, session: Session): + query = session.query(piModel).filter(piModel.project_id == project_id)\ + .filter(piModel.img_type == img_type).order_by(asc(piModel.id)) + return query.count() + + def add_image(image: ProjectImage, session: Session): session.add(image) session.commit() diff --git a/app/model/schemas/project_image_schemas.py b/app/model/schemas/project_image_schemas.py index bcc8e79..b8d6b5e 100644 --- a/app/model/schemas/project_image_schemas.py +++ b/app/model/schemas/project_image_schemas.py @@ -6,6 +6,7 @@ from typing import Optional, List class ProjectImage(BaseModel): id: Optional[int] = Field(None, description="id") project_id: Optional[int] = Field(..., description="项目id") + img_type: Optional[str] = Field(None, description="图片类别") file_name: Optional[str] = Field(None, description="文件名称") create_time: Optional[datetime] = Field(None, description="上传时间") @@ -32,6 +33,7 @@ class ProjectImageOut(BaseModel): class ProjectImagePager(BaseModel): project_id: Optional[int] = Field(..., description="项目id") + img_type: Optional[str] = Field(None, description="图片类别") pagerNum: Optional[int] = Field(None, description="当前页码") pagerSize: Optional[int] = Field(None, description="每页数量") diff --git a/app/service/project_train_service.py b/app/service/project_train_service.py index 6429497..9acaf02 100644 --- a/app/service/project_train_service.py +++ b/app/service/project_train_service.py @@ -36,17 +36,18 @@ def add_project(info: ProjectInfoIn, session: Session, user_id: int): return project_info.id -def check_image_name(project_id: int, files: List[UploadFile], session: Session): +def check_image_name(project_id: int, img_type: str, files: List[UploadFile], session: Session): for file in files: - if not pimc.check_img_name(project_id, file.filename, session): + if not pimc.check_img_name(project_id, img_type, file.filename, session): return False, file.filename return True, None -def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile], session: Session): +def upload_project_image(project_info: ProjectInfoOut, img_type: str, files: List[UploadFile], session: Session): """ 上传项目的图片 :param files: 上传的图片 + :param img_type: 上传的图片类别 :param project_info: 项目信息 :param session: :return: @@ -56,6 +57,7 @@ def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile], image = ProjectImage() image.project_id = project_info.id image.file_name = file.filename + image.img_type = img_type # 保存原图 path = os.save_images(images_url, project_info.project_no, file=file) image.image_url = path @@ -124,11 +126,9 @@ def run_train_yolo(project_info: ProjectInfoOut, session: Session): :param session: 数据库session :return: """ - # 先获取项目的所有图片 - project_images = pimc.get_images(project_info.id, session) - - # 将图片根据,根据3:1的比例将图片分成train:val的两个数组 - project_images_train, project_images_val = split_array(project_images) + # 先查询两个图片列表 + project_images_train = pimc.get_images(project_info.id, 'train', session) + project_images_val = pimc.get_images(project_info.id, 'val', session) # 得到训练版本 version_path = 'v' + str(project_info.train_version + 1) diff --git a/app/util/ps_util.py b/app/util/ps_util.py new file mode 100644 index 0000000..64a2651 --- /dev/null +++ b/app/util/ps_util.py @@ -0,0 +1,77 @@ +import psutil +import platform +import json +from datetime import datetime + + +def get_server_info(): + + info = {} + + # 1. 系统基本信息 + info["system"] = { + "OS": platform.system(), + "OS版本": platform.version(), + "主机名": platform.node(), + "架构": platform.machine(), + "处理器型号": platform.processor(), + "启动时间": datetime.fromtimestamp(psutil.boot_time()).strftime("%Y-%m-%d %H:%M:%S") + } + + # 2. CPU 信息 + cpu_usage = psutil.cpu_percent(interval=1) + info["cpu"] = { + "物理核心数": psutil.cpu_count(logical=False), + "逻辑核心数": psutil.cpu_count(logical=True), + "当前使用率 (%)": cpu_usage, + "每个核心使用率": psutil.cpu_percent(interval=1, percpu=True) + } + + # 3. 内存信息 + mem = psutil.virtual_memory() + info["memory"] = { + "总内存 (GB)": round(mem.total / (1024**3), 2), + "可用内存 (GB)": round(mem.available / (1024**3), 2), + "使用率 (%)": mem.percent + } + + # 4. 磁盘信息 + disks = [] + for partition in psutil.disk_partitions(): + usage = psutil.disk_usage(partition.mountpoint) + disks.append({ + "设备": partition.device, + "挂载点": partition.mountpoint, + "文件系统": partition.fstype, + "总空间 (GB)": round(usage.total / (1024**3), 2), + "已用空间 (GB)": round(usage.used / (1024**3), 2), + "使用率 (%)": usage.percent + }) + info["disks"] = disks + + # 5. 网络信息 + net = psutil.net_io_counters() + info["network"] = { + "发送流量 (MB)": round(net.bytes_sent / (1024**2), 2), + "接收流量 (MB)": round(net.bytes_recv / (1024**2), 2) + } + + # 6. 进程信息(示例:前5个高CPU进程) + processes = [] + for proc in psutil.process_iter(['pid', 'name', 'cpu_percent']): + if len(processes) >= 5: + break + if proc.info['cpu_percent'] > 0: + processes.append({ + "PID": proc.info['pid'], + "进程名": proc.info['name'], + "CPU使用率 (%)": proc.info['cpu_percent'] + }) + info["top_processes"] = sorted(processes, key=lambda x: x["CPU使用率 (%)"], reverse=True) + + return info + + +def get_server_json(): + server_info = get_server_info() + return json.dumps(server_info, indent=2, ensure_ascii=False) \ No newline at end of file diff --git a/app/websocket/web_socket_server.py b/app/websocket/web_socket_server.py new file mode 100644 index 0000000..24b29fe --- /dev/null +++ b/app/websocket/web_socket_server.py @@ -0,0 +1,37 @@ +from fastapi import WebSocket + + +class SocketManager: + def __init__(self): + self.rooms = {} + + async def add_to_room(self, room: str, websocket: WebSocket): + if room not in self.rooms: + self.rooms[room] = [] + self.rooms[room].append(websocket) + + async def remove_from_room(self, room: str, websocket: WebSocket): + if room in self.rooms: + self.rooms[room].remove(websocket) + if len(self.rooms[room]) == 0: + del self.rooms[room] + + async def broadcast_to_room(self, room: str, message: str, exclude_websocket: WebSocket = None): + if room in self.rooms: + for ws in self.rooms[room]: + if ws != exclude_websocket: + try: + await ws.send_text(message) + except: + await self.remove_from_room(room, ws) + + async def send_to_room(self, room: str, message: str): + if room in self.rooms: + for ws in self.rooms[room]: + try: + await ws.send_text(message) + except: + await self.remove_from_room(room, ws) + + +room_manager = SocketManager() diff --git a/requirements.txt b/requirements.txt index 5363578..dca5744 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,6 +16,7 @@ bcrypt==3.2.0 pymysql==1.0.2 pynvml==12.0.0 requests-toolbelt==1.0.0 +python-socketio == 5.12.1 # YOLOV5 ----------------------------------------------------------------------