From 31302bcd17e1d6456e8e3ccd6a4c8805a4af7b98 Mon Sep 17 00:00:00 2001 From: sunyugang Date: Wed, 19 Feb 2025 11:47:33 +0800 Subject: [PATCH] =?UTF-8?q?=E9=A1=B9=E7=9B=AE=E5=9F=BA=E7=A1=80=E6=A8=A1?= =?UTF-8?q?=E5=9D=97=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/business/project_api.py | 86 ++++++++++++++++++++++ app/api/common/upload_file.py | 11 +-- app/api/sys/sys_user_api.py | 10 +-- app/application/app.py | 15 +++- app/application/exception_handler.py | 10 +++ app/application/token_middleware.py | 4 +- app/common/jwt_check.py | 9 +++ app/common/reponse_code.py | 3 +- app/config/application_config_dev.ini | 9 ++- app/config/application_config_pro.ini | 11 --- app/config/application_config_prod.ini | 18 +++++ app/config/config_reader.py | 9 ++- app/db/db_base.py | 4 - app/model/bussiness_model.py | 42 +++++++++++ app/model/crud/project_image_crud.py | 35 +++++++++ app/model/crud/project_info_crud.py | 42 +++++++++++ app/model/crud/project_label_crud.py | 72 ++++++++++++++++++ app/model/crud/project_type_crud.py | 13 ++++ app/model/crud/sys_user_crud.py | 14 ++-- app/model/model.py | 29 -------- app/model/schemas/project_image_schemas.py | 15 ++++ app/model/schemas/project_info_schemas.py | 32 ++++++++ app/model/schemas/project_label_schemas.py | 12 +++ app/model/schemas/project_type_schemas.py | 26 +++++++ app/model/schemas/sys_user_schemas.py | 7 +- app/model/sys_model.py | 15 ++++ app/service/project_service.py | 51 +++++++++++++ app/util/os_utils.py | 43 +++++++++++ app/util/random_utils.py | 12 +++ requirements.txt | 3 +- 30 files changed, 588 insertions(+), 74 deletions(-) create mode 100644 app/api/business/project_api.py create mode 100644 app/application/exception_handler.py delete mode 100644 app/config/application_config_pro.ini create mode 100644 app/config/application_config_prod.ini create mode 100644 app/model/bussiness_model.py create mode 100644 app/model/crud/project_image_crud.py create mode 100644 app/model/crud/project_info_crud.py create mode 100644 app/model/crud/project_label_crud.py create mode 100644 app/model/crud/project_type_crud.py delete mode 100644 app/model/model.py create mode 100644 app/model/schemas/project_image_schemas.py create mode 100644 app/model/schemas/project_info_schemas.py create mode 100644 app/model/schemas/project_label_schemas.py create mode 100644 app/model/schemas/project_type_schemas.py create mode 100644 app/model/sys_model.py create mode 100644 app/service/project_service.py create mode 100644 app/util/os_utils.py create mode 100644 app/util/random_utils.py diff --git a/app/api/business/project_api.py b/app/api/business/project_api.py new file mode 100644 index 0000000..5b55c2b --- /dev/null +++ b/app/api/business/project_api.py @@ -0,0 +1,86 @@ +from app.model.crud import project_type_crud as ptc +from app.model.crud import project_label_crud as plc +from app.model.crud import project_info_crud as pic +from app.service import project_service as ps +from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager +from app.model.schemas.project_label_schemas import ProjectLabel +from app.db.db_session import get_db +from app.common.jwt_check import get_user_id +from app.common import reponse_code as rc + +from fastapi import APIRouter, Depends, Request +from sqlalchemy.orm import Session + +"""项目管理API""" +project = APIRouter() + + +@project.get("/types") +def get_type_list(session: Session = Depends(get_db)): + """获取项目类别""" + type_list = ptc.get_list(session) + return rc.response_success(data=type_list) + + +@project.post("/list") +def project_pager(info: ProjectInfoPager, session: Session = Depends(get_db)): + pager = pic.get_project_pager(info, session) + return rc.response_success_pager(pager) + + +@project.post("/add") +def add_project(request: Request, info: ProjectInfoIn, session: Session = Depends(get_db)): + """新建项目""" + if pic.check_project_name(info.project_name, session): + return rc.response_error("已经存在相同名称的项目") + user_id = get_user_id(request) + ps.add_project(info, session, user_id) + return rc.response_success(msg="新建成功") + + +@project.post("/add_label") +def add_label(label: ProjectLabel, session: Session = Depends(get_db)): + """ + 新增标签 + :param label: + :param session: + :return: + """ + if plc.check_label_name(label.project_id, label.label_name, session): + return rc.response_error("标签名称已经存在,不能重复") + label_save = ProjectLabel(**label.dict()) + plc.add_label(label_save, session) + return rc.response_success(msg="保存成功") + + +@project.post("/up_label") +def up_label(label: ProjectLabel, session: Session = Depends(get_db)): + """ + 修改标签 + :param label: + :param session: + :return: + """ + if plc.check_label_name(label.project_id, label.label_name, session, label.id): + return rc.response_error("修改的标签名称已经存在,不能重复") + label_save = ProjectLabel(**label.dict()) + plc.update_label(label_save, session) + return rc.response_success(msg="修改成功") + + +@project.post("/del_label") +def del_label(label: ProjectLabel, session: Session = Depends(get_db)): + """ + 删除标签 + :param label: + :param session: + :return: + """ + row_del = plc.update_label(label.id, session) + if row_del > 0: + return rc.response_success(msg="删除成功") + else: + return rc.response_error("删除失败") + + + diff --git a/app/api/common/upload_file.py b/app/api/common/upload_file.py index f417836..e379b1b 100644 --- a/app/api/common/upload_file.py +++ b/app/api/common/upload_file.py @@ -1,17 +1,18 @@ from typing import List - from fastapi import APIRouter from fastapi import UploadFile import os -upload = APIRouter() +from app.config.config_reader import images_url + +upload_files = APIRouter() -@upload.post("/") -def upload_file(files: List[UploadFile]): +@upload_files.post("/") +def upload(files: List[UploadFile], project_no: str): paths = [] for file in files: - path = os.path.join("images", file.filename) + path = os.path.join(images_url, project_no, file.filename) with open(path, "wb") as f: for line in file.file: f.write(line) diff --git a/app/api/sys/sys_user_api.py b/app/api/sys/sys_user_api.py index 7f10082..5d84f4f 100644 --- a/app/api/sys/sys_user_api.py +++ b/app/api/sys/sys_user_api.py @@ -1,8 +1,8 @@ from fastapi import APIRouter, Depends -from app.model.schemas.sys_user_schemas import SysUserOut, SysUserIN, SysUserPager +from app.model.schemas.sys_user_schemas import SysUserOut, SysUserIn, SysUserPager from app.common import reponse_code as rc from app.model.crud import sys_user_crud as us -from app.model.model import SysUser +from app.model.sys_model import SysUser from app.common.redis_cli import redis_conn from sqlalchemy.orm import Session @@ -20,7 +20,7 @@ def user_pager(user: SysUserPager, session: Session = Depends(get_db)): @user.post("/") -def add_user(user: SysUserIN, session: Session = Depends(get_db)): +def add_user(user: SysUserIn, session: Session = Depends(get_db)): """ 新增用户 :param session: @@ -30,7 +30,7 @@ def add_user(user: SysUserIN, session: Session = Depends(get_db)): if us.check_username(user.username, session): return rc.response_error(msg="该用户名已存在!") else: - user_in= SysUser(**user.dict()) + user_in = SysUser(**user.dict()) user_in.user_status = '0' if us.add_user(user_in, session): return rc.response_success(msg="保存成功") @@ -49,7 +49,7 @@ def get_user(id: int, session: Session = Depends(get_db)): user = us.get_user_by_id(id, session) if user is None: return rc.response_success(data=None) - user_out = SysUserOut(**dict(user)) + user_out = SysUserOut.from_orm(user) return rc.response_success(data=user_out.dict()) diff --git a/app/application/app.py b/app/application/app.py index a57b47b..cac7628 100644 --- a/app/application/app.py +++ b/app/application/app.py @@ -4,6 +4,11 @@ from fastapi.middleware.cors import CORSMiddleware from app.application.token_middleware import TokenMiddleware from app.application.logger_middleware import LoggerMiddleware +from app.api.common.upload_file import upload_files +from app.api.sys.login_api import login +from app.api.sys.sys_user_api import user +from app.api.business.project_api import project + my_app = FastAPI() @@ -18,6 +23,14 @@ my_app.add_middleware( allow_headers=["*"], ) -#注意中间的顺序,这个地方是倒序执行的 +''' +注意中间的顺序,这个地方是倒序执行的 +''' my_app.add_middleware(LoggerMiddleware) my_app.add_middleware(TokenMiddleware) + +my_app.include_router(user, prefix="/user", tags=["用户管理API"]) +my_app.include_router(login, prefix="/login", tags=["用户登录接口"]) +my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"]) +my_app.include_router(project, prefix="/proj", tags=["项目管理API"]) + diff --git a/app/application/exception_handler.py b/app/application/exception_handler.py new file mode 100644 index 0000000..f5b394f --- /dev/null +++ b/app/application/exception_handler.py @@ -0,0 +1,10 @@ +from fastapi import HTTPException, Request +from app.common.reponse_code import response_error + +from app import my_app +"""全局异常处理""" + + +@my_app.exception_handlers(HTTPException) +async def http_exception(request: Request, he: HTTPException): + return response_error(request.url + "出现异常:" + he.detail) \ No newline at end of file diff --git a/app/application/token_middleware.py b/app/application/token_middleware.py index 922e850..15e88d5 100644 --- a/app/application/token_middleware.py +++ b/app/application/token_middleware.py @@ -2,8 +2,8 @@ from fastapi import status from starlette.middleware.base import BaseHTTPMiddleware from urllib.request import Request from jwt import PyJWTError -from common import reponse_code as rc -from common import jwt_check as jc +from app.common import reponse_code as rc +from app.common import jwt_check as jc diff --git a/app/common/jwt_check.py b/app/common/jwt_check.py index f48f912..5510357 100644 --- a/app/common/jwt_check.py +++ b/app/common/jwt_check.py @@ -2,6 +2,7 @@ import datetime import jwt from app.common.redis_cli import redis_conn +from fastapi import Request # 过期时间,单位S exp = 6000 @@ -48,3 +49,11 @@ def check_token(token: str): raise jwt.ExpiredSignatureError("Expired Token") except jwt.InvalidTokenError: raise jwt.InvalidTokenError("Invalid Token") + + +def get_user_id(request: Request): + """根据Request请求获取token""" + token = request.headers.get("Authorization") + decoded_payload = check_token(token) + user_id = decoded_payload['user_id'] + return user_id \ No newline at end of file diff --git a/app/common/reponse_code.py b/app/common/reponse_code.py index fe96f9b..ebaa3ee 100644 --- a/app/common/reponse_code.py +++ b/app/common/reponse_code.py @@ -12,6 +12,7 @@ def response_code_view(code: int,msg: str) -> Response: } ) + def response_success(msg: str = "查询成功", data: object = None): return JSONResponse( status_code=status.HTTP_200_OK, @@ -35,7 +36,7 @@ def response_success_pager(pager: Pager): ) -def response_error(msg:str): +def response_error(msg: str): return JSONResponse( status_code=status.HTTP_200_OK, content={ diff --git a/app/config/application_config_dev.ini b/app/config/application_config_dev.ini index 12c1c87..bcb9855 100644 --- a/app/config/application_config_dev.ini +++ b/app/config/application_config_dev.ini @@ -8,4 +8,11 @@ db = 0 password = sdust2020 [log] -dir = D:\syg\workspace\logs \ No newline at end of file +dir = D:\syg\workspace\logs + +[yolo] +datasets_url = D:\syg\yolov5\datasets +runs_url = D:\syg\yolov5\runs + +[images] +image_url = D:\syg\images \ No newline at end of file diff --git a/app/config/application_config_pro.ini b/app/config/application_config_pro.ini deleted file mode 100644 index 3a382ba..0000000 --- a/app/config/application_config_pro.ini +++ /dev/null @@ -1,11 +0,0 @@ -[mysql] -database_url = mysql+pymysql://root:root@localhost:3306/sun - -[redis] -host = localhost -port = 6379 -db = 0 -password = 123456 - -[log] -dir = /Users/macbookpro/sunyg/workspace/logs \ No newline at end of file diff --git a/app/config/application_config_prod.ini b/app/config/application_config_prod.ini new file mode 100644 index 0000000..7f0c086 --- /dev/null +++ b/app/config/application_config_prod.ini @@ -0,0 +1,18 @@ +[mysql] +database_url = mysql+pymysql://root:root@localhost:3306/sun + +[redis] +host = localhost +port = 6379 +db = 0 +password = 123456 + +[log] +dir = /Users/macbookpro/sunyg/workspace/logs + +[yolo] +datasets_url = /home/yolov5/datasets +runs_url = /home/yolov5/runs + +[images] +image_url = /home/images \ No newline at end of file diff --git a/app/config/config_reader.py b/app/config/config_reader.py index ca90e93..ea6c15c 100644 --- a/app/config/config_reader.py +++ b/app/config/config_reader.py @@ -3,6 +3,8 @@ import os env = "dev" +# env = "prod" + script_directory = os.path.dirname(os.path.abspath(__file__)) config_path = os.path.join(script_directory, f'application_config_{env}.ini') @@ -18,4 +20,9 @@ redis_port = config.get('redis', 'port') redis_db = config.get('redis', 'db') redis_password = config.get('redis', 'password') -log_dir = config.get('log', 'dir') \ No newline at end of file +log_dir = config.get('log', 'dir') + +datasets_url = config.get('yolo', 'datasets_url') +runs_url = config.get('yolo', 'runs_url') + +images_url = config.get('images', 'image_url') \ No newline at end of file diff --git a/app/db/db_base.py b/app/db/db_base.py index a141492..de04cb7 100644 --- a/app/db/db_base.py +++ b/app/db/db_base.py @@ -11,10 +11,6 @@ class DbCommon(Base): create_time: Mapped[datetime.datetime] = mapped_column(default=datetime.datetime.utcnow()) update_time: Mapped[datetime.datetime] = mapped_column(default=datetime.datetime.utcnow(), onupdate=datetime.datetime.utcnow()) - def keys(self): - return ["id", "create_time", "update_time"] - - @declared_attr def __tablename__(cls): return cls.__name__.lower() diff --git a/app/model/bussiness_model.py b/app/model/bussiness_model.py new file mode 100644 index 0000000..86787a1 --- /dev/null +++ b/app/model/bussiness_model.py @@ -0,0 +1,42 @@ +from app.db.db_base import DbCommon +from sqlalchemy import String, Integer +from sqlalchemy.orm import Mapped, mapped_column + + +class ProjectType(DbCommon): + """ + 项目类别表 - 标识项目的类型目前存在的(目标识别,OCR识别,瑕疵检测,图像分类) + """ + __tablename__ = "project_type" + type_code = Mapped[str] = mapped_column(String(20), unique=True, nullable=False) + type_name = Mapped[str] = mapped_column(String(20)) + icon_path = Mapped[str] = mapped_column(String(255)) + description = Mapped[str] = mapped_column(String(255)) + type_status = Mapped[str] = mapped_column(String(10)) + + +class ProjectInfo(DbCommon): + """项目信息表""" + __tablename__ = "project_info" + project_no = Mapped[str] = mapped_column(String(32), unique=True, nullable=False) + project_name = Mapped[str] = mapped_column(String(32), unique=True, nullable=False) + type_code = Mapped[str] = mapped_column(String(10)) + description = Mapped[str] = mapped_column(String(255)) + project_status = Mapped[str] = mapped_column(String(10)) + user_id = Mapped[int] = mapped_column(Integer) + train_version = Mapped[int] = mapped_column(Integer) + + +class ProjectLabel(DbCommon): + """项目标签表""" + __tablename__ = "project_label" + label_name = Mapped[str] = mapped_column(String(32), unique=True, nullable=False) + project_id = Mapped[int] = mapped_column(Integer, nullable=False) + + +class ProjectImage(DbCommon): + """项目图片表""" + __tablename__ = "project_image" + image_url = Mapped[str] = mapped_column(String(255), nullable=False) + thumb_image_url = Mapped[str] = mapped_column(String(255), nullable=False) + project_id = Mapped[int] = mapped_column(Integer) diff --git a/app/model/crud/project_image_crud.py b/app/model/crud/project_image_crud.py new file mode 100644 index 0000000..6573116 --- /dev/null +++ b/app/model/crud/project_image_crud.py @@ -0,0 +1,35 @@ +from sqlalchemy.orm import Session +from sqlalchemy import asc + +from app.model.bussiness_model import ProjectImage as piModel +from app.model.schemas.project_image_schemas import ProjectImage, ProjectImagePager +from app.db.page_util import get_pager + + +def get_image_pager(image: ProjectImagePager, session: Session): + query = session.query(piModel).filter(piModel.project_id == image.project_id).order_by(asc(piModel.id)) + pager = get_pager(query, image.pagerNum, image.pagerSize) + pager.data = [ProjectImage.from_orm(image) for image in pager.data] + return pager + + +def get_image_list(image: ProjectImage, session: Session): + query = session.query(piModel).filter(piModel.project_id == image.project_id).order_by(asc(piModel.id)) + image_list = [ProjectImage.from_orm(image) for image in query.all()] + return image_list + + +def add_image(image: ProjectImage, session: Session): + session.add(image) + session.commit() + return image + + +def del_image(id: str, session: Session): + row_del = session.query(piModel).filter_by(id=id).delete() + session.commit() + return row_del + + + + diff --git a/app/model/crud/project_info_crud.py b/app/model/crud/project_info_crud.py new file mode 100644 index 0000000..a3a9418 --- /dev/null +++ b/app/model/crud/project_info_crud.py @@ -0,0 +1,42 @@ +from sqlalchemy.orm import Session +from sqlalchemy import desc + +from app.model.bussiness_model import ProjectInfo +from app.model.schemas.project_info_schemas import ProjectInfoOut +from app.model.schemas.project_info_schemas import ProjectInfoPager +from app.db.page_util import get_pager + + +def get_project_pager(info: ProjectInfoPager, session: Session): + """分页查询项目信息""" + query = session.query(ProjectInfo).order_by(desc(ProjectInfo.id)) + filters = [] + if info.project_name is not None: + filters.append(ProjectInfo.project_name.ilike(f"%{info.project_name}%")) + if len(filters) > 0: + query.filter(*filters) + pager = get_pager(query, info.pagerNum, info.pagerSize) + pager.data = [ProjectInfoOut.from_orm(info).dict() for info in pager.data] + return pager + + +def get_project_by_id(id: str, session: Session): + info = session.query(ProjectInfo).filter_by(id=id).first() + info_out = ProjectInfoOut.from_orm(info) + return info_out + +def add_project(info: ProjectInfo, session: Session): + """新建项目,并在对应文件夹下面创建文件夹""" + session.add(info) + session.commit() + return info + + +def check_project_name(project_name: str, session: Session): + """检验是否存在重名的项目名称""" + count = session.query(ProjectInfo).filter(ProjectInfo.project_name == project_name).count() + if count > 0: + return True + else: + return False + diff --git a/app/model/crud/project_label_crud.py b/app/model/crud/project_label_crud.py new file mode 100644 index 0000000..9c0d674 --- /dev/null +++ b/app/model/crud/project_label_crud.py @@ -0,0 +1,72 @@ +from sqlalchemy.orm import Session + +from app.model.bussiness_model import ProjectLabel as plModel +from app.model.schemas.project_label_schemas import ProjectLabel + + +def get_label_list(project_id: int, session: Session): + """ + 根绝项目id获取标签列表 + :param project_id: 项目id + :param session: + :return: + """ + label_list = session.query(plModel).filter(plModel.project_id == project_id).all() + label_list = [ProjectLabel.from_orm(label) for label in label_list] + return label_list + + +def add_label(label: plModel, session: Session): + """ + 新增标签 + :param label: + :param session: + :return: + """ + session.add(label) + session.commit() + return label.id + + +def check_label_name(project_id: int, label_name: str, session: Session, label_id: int = None): + """ + 检验标签名称是否存在 + :param label_id: + :param project_id: 项目id + :param label_name: 标签名称 + :param session: + :return: + """ + query = session.query(plModel) + filters = [plModel.project_id == project_id, plModel.label_name == label_name] + if label_id is not None: + filters.append(plModel.id != label_id) + query.filter(*filters) + if query.count() > 0: + return False + else: + return True + + +def update_label(label: plModel, session: Session): + """ + 修改标签名称 + :param label: + :param session: + :return: + """ + session.query(plModel).filter_by(id=label.id).update({ + "label_name": label.label_name + }) + session.commit() + + +def del_label(id: str, session: Session): + """ + 根据标签id删除标签 + :param id: 标签id + :param session: + :return: + """ + row_del = session.query(plModel).filter_by(id=id).delete() + return row_del diff --git a/app/model/crud/project_type_crud.py b/app/model/crud/project_type_crud.py new file mode 100644 index 0000000..291687f --- /dev/null +++ b/app/model/crud/project_type_crud.py @@ -0,0 +1,13 @@ +from app.model.bussiness_model import ProjectType +from app.model.schemas.project_type_schemas import ProjectTypeOut + +from sqlalchemy import asc +from sqlalchemy.orm import Session + + +def get_list(session: Session): + """获取项目类型列表""" + query = session.query(ProjectType).order_by(asc(ProjectType.id)) + query.filter(ProjectType.type_status == "0") + result_list = [ProjectTypeOut.from_orm(project_type).dict() for project_type in query.all()] + return result_list diff --git a/app/model/crud/sys_user_crud.py b/app/model/crud/sys_user_crud.py index ad8545c..9583807 100644 --- a/app/model/crud/sys_user_crud.py +++ b/app/model/crud/sys_user_crud.py @@ -1,23 +1,21 @@ -from app.model.model import SysUser +from app.model.sys_model import SysUser from app.model.schemas.sys_user_schemas import SysUserPager, SysUserOut from app.common.bcrypt_pw import hash_password from app.db.page_util import get_pager -from sqlalchemy import and_, desc +from sqlalchemy import and_, asc from sqlalchemy.orm import Session def user_pager(user: SysUserPager, session: Session): - query = session.query(SysUser).order_by(desc(SysUser.id)) + query = session.query(SysUser).order_by(asc(SysUser.id)) filters = [] if user.username is not None: filters.append(SysUser.username.ilike(f"%{user.username}%")) - if user.dept_id is not None: - filters.append(SysUser.dept_id == user.dept_id) if len(filters) > 0: query.filter(and_(*filters)) pager = get_pager(query, user.pagerNum, user.pagerSize) - pager.data = [SysUserOut.from_orm(user) for user in pager.data] + pager.data = [SysUserOut.from_orm(user).dict() for user in pager.data] return pager @@ -35,13 +33,13 @@ def get_user_by_id(id: int, session: Session): def stop_user(user: SysUser, session: Session): user.user_status = "1" - session.commit(); + session.commit() return user def start_user(user: SysUser, session: Session): user.user_status = "0" - session.commit(); + session.commit() return user diff --git a/app/model/model.py b/app/model/model.py deleted file mode 100644 index 79ce233..0000000 --- a/app/model/model.py +++ /dev/null @@ -1,29 +0,0 @@ -from app.db.db_base import DbCommon -from sqlalchemy import String,Integer -from sqlalchemy.orm import Mapped, mapped_column - - -class SysUser(DbCommon): - - __tablename__ = "sys_user" - username: Mapped[str] = mapped_column(String(50), unique=True, nullable=False) - password: Mapped[str] = mapped_column(String(255)) - dept_id: Mapped[int] = mapped_column(Integer) - login_name: Mapped[str] = mapped_column(String(255)) - avatar: Mapped[str] = mapped_column(String(255)) - user_status: Mapped[str] = mapped_column(String(10)) - - - def keys(self): - keys = ["username", "password", "dept_id", "login_name", "avatar", "user_status"] - keys.extend(super().keys()) - return keys - - - def __getitem__(self, item): - ''' - 内置方法, 当使用obj['name']的形式的时候, 将调用这个方法, 这里返回的结果就是值 - :param item: - :return: - ''' - return getattr(self, item, None) diff --git a/app/model/schemas/project_image_schemas.py b/app/model/schemas/project_image_schemas.py new file mode 100644 index 0000000..206d809 --- /dev/null +++ b/app/model/schemas/project_image_schemas.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel, Field +from typing import Optional + + +class ProjectImage(BaseModel): + id: Optional[int] = Field(None, description="id") + project_id: Optional[int] = Field(..., description="项目id") + image_url: Optional[str] = Field(..., description="原图路径") + thumb_image_url: Optional[str] = Field(..., description="缩略图路径") + + +class ProjectImagePager(BaseModel): + project_id: Optional[int] = Field(..., description="项目id") + pagerNum: Optional[int] = Field(1, description="当前页码") + pagerSize: Optional[int] = Field(10, description="每页数量") diff --git a/app/model/schemas/project_info_schemas.py b/app/model/schemas/project_info_schemas.py new file mode 100644 index 0000000..9392ad7 --- /dev/null +++ b/app/model/schemas/project_info_schemas.py @@ -0,0 +1,32 @@ +from pydantic import BaseModel, Field +from typing import Optional + + +class ProjectInfoIn(BaseModel): + """项目信息输入""" + id: Optional[int] = Field(None, description="项目id") + project_name: Optional[str] = Field(..., description="项目名称") + type_code: Optional[str] = Field(..., description="项目类型编码") + description: Optional[str] = Field(None, description="项目描述") + + +class ProjectInfoOut(BaseModel): + """项目信息输出""" + id: Optional[int] = Field(None, description="项目id") + project_no: Optional[str] = Field(..., description="项目编号") + project_name: Optional[str] = Field(..., description="项目名称") + type_code: Optional[str] = Field(..., description="项目类型编码") + description: Optional[str] = Field(None, description="项目描述") + user_name: Optional[str] = Field(None, description="创建人") + train_version: Optional[int] = Field(None, description="训练版本号") + project_status: Optional[str] = Field(None, description="项目状态") + + class Config: + orm_mode = True + + +class ProjectInfoPager(BaseModel): + project_name: Optional[str] = Field(None, description="项目名称") + pagerNum: Optional[int] = Field(1, description="当前页码") + pagerSize: Optional[int] = Field(10, description="每页数量") + diff --git a/app/model/schemas/project_label_schemas.py b/app/model/schemas/project_label_schemas.py new file mode 100644 index 0000000..4ee4a0e --- /dev/null +++ b/app/model/schemas/project_label_schemas.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel, Field +from typing import Optional + + +class ProjectLabel(BaseModel): + """项目标签输入输出""" + id: Optional[int] = Field(None, description="id") + project_id: Optional[int] = Field(..., description="项目id") + label_name: Optional[str] = Field(..., description="标签名称") + + class Config: + orm_mode = True diff --git a/app/model/schemas/project_type_schemas.py b/app/model/schemas/project_type_schemas.py new file mode 100644 index 0000000..2f45081 --- /dev/null +++ b/app/model/schemas/project_type_schemas.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional + + +class ProjectTypeIn(BaseModel): + """ + 项目类型输入 + """ + type_code: Optional[str] = Field(..., description="类型code", max_length=20) + type_name: Optional[str] = Field(..., description="类型名称", max_length=20, min_length=4) + icon_path: Optional[str] = Field(None, description="iconPath", max_length=255) + description: Optional[str] = Field(None, description="类型描述", max_length=255) + + +class ProjectTypeOut(BaseModel): + """ + 项目类型输出 + """ + id: Optional[int] = Field(..., description="id") + type_code: Optional[str] = Field(..., description="类型code", max_length=20) + type_name: Optional[str] = Field(..., description="类型名称", max_length=20, min_length=4) + icon_path: Optional[str] = Field(None, description="iconPath", max_length=255) + description: Optional[str] = Field(None, description="类型描述", max_length=255) + + class Config: + orm_mode = True diff --git a/app/model/schemas/sys_user_schemas.py b/app/model/schemas/sys_user_schemas.py index 91b7a5b..ce2f61d 100644 --- a/app/model/schemas/sys_user_schemas.py +++ b/app/model/schemas/sys_user_schemas.py @@ -3,10 +3,9 @@ from typing import Optional # 用户相关的原型 -class SysUserIN(BaseModel): +class SysUserIn(BaseModel): username: Optional[str] = Field(..., description="用户名", max_length=50) password: Optional[str] = Field(..., description="密码", max_length=30, min_length=6) - dept_id: Optional[str] = Field(None, description="部门id") login_name: Optional[str] = Field(None, description="昵称", max_length=20) @@ -18,9 +17,8 @@ class SysUserLogin(BaseModel): class SysUserOut(BaseModel): id: Optional[int] = Field(..., description="id") username: Optional[str] = Field(..., description="用户名") - dept_id: Optional[str] = Field(None, description="部门id") - dept_name: Optional[str] = Field(None, description="部门名称") login_name: Optional[str] = Field(None, description="昵称") + user_status: Optional[str] = Field(None, description="用户状态") class Config: orm_mode = True @@ -34,7 +32,6 @@ class SysUserUpdatePw(BaseModel): class SysUserPager(BaseModel): username: Optional[str] = Field(None, description="用户名") - dept_id: Optional[str] = Field(None, description="部门id") login_name: Optional[str] = Field(None, description="昵称") pagerNum: Optional[int] = Field(1, description="当前页码") pagerSize: Optional[int] = Field(10, description="每页数量") diff --git a/app/model/sys_model.py b/app/model/sys_model.py new file mode 100644 index 0000000..cfa70cd --- /dev/null +++ b/app/model/sys_model.py @@ -0,0 +1,15 @@ +from app.db.db_base import DbCommon +from sqlalchemy import String +from sqlalchemy.orm import Mapped, mapped_column + + +class SysUser(DbCommon): + """ + 用户表 - 保存用户基本信息 + """ + __tablename__ = "sys_user" + username: Mapped[str] = mapped_column(String(50), unique=True, nullable=False) + password: Mapped[str] = mapped_column(String(255)) + login_name: Mapped[str] = mapped_column(String(255)) + avatar: Mapped[str] = mapped_column(String(255)) + user_status: Mapped[str] = mapped_column(String(10)) diff --git a/app/service/project_service.py b/app/service/project_service.py new file mode 100644 index 0000000..35c3ac7 --- /dev/null +++ b/app/service/project_service.py @@ -0,0 +1,51 @@ +from app.model.bussiness_model import ProjectImage +from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut +from app.model.bussiness_model import ProjectInfo +from app.util import random_utils as ru +from app.config.config_reader import datasets_url, runs_url, images_url +from app.model.crud import project_info_crud as pic +from app.util import os_utils as os + +from sqlalchemy.orm import Session +from typing import List +from fastapi import UploadFile + + +def add_project(info: ProjectInfoIn, session: Session, user_id: int): + """ + 新建项目,完善数据,并创建对应的文件夹 + :param info: 项目信息 + :param session: 数据库session + :param user_id: 用户id + :return: + """ + project_info = ProjectInfo(**info.dict()) + project_info.user_id = user_id + project_info.project_no = ru.random_str(6) + project_info.project_status = "0" + project_info.train_version = 0 + os.create_folder(datasets_url, project_info.project_no) + os.create_folder(runs_url, project_info.project_no) + pic.add_project(project_info, session) + return project_info + + +def upload_project_image(session: Session, project_info: ProjectInfoOut, files: List[UploadFile]): + """ + 上传项目的图片 + :param files: 上传的图片 + :param project_info: 项目信息 + :param image: + :param session: + :return: + """ + for file in files: + image = ProjectImage() + image.project_id = project_info.id + # 保存原图 + path = os.save_images(images_url, project_info.project_no, file=file) + image.image_url = path + + + + diff --git a/app/util/os_utils.py b/app/util/os_utils.py new file mode 100644 index 0000000..ef798c2 --- /dev/null +++ b/app/util/os_utils.py @@ -0,0 +1,43 @@ +import os +from fastapi import UploadFile +from PIL import Image + + +def create_folder(*path): + """根据路径创建文件夹""" + folder_path = os.path.join(*path) + try: + os.makedirs(path, exist_ok=True) + except Exception as e: + print(f"创建文件夹时错误: {e}") + + +def save_images(*path, file: UploadFile): + """ + 保存上传的图片 + :param path: 路径 + :param file: 文件 + :return: + """ + save_path = os.path.join(*path, file.filename) + with open(save_path, "wb") as f: + for line in file.file: + f.write(line) + return save_path + + +def create_thumbnail(input_image_path, out_image_path, size=(116, 70)): + """ + 给图片生成缩略图 + :param input_image_path: + :param out_image_path: + :param size: 缩略的尺寸 + :return: + """ + with Image.open(input_image_path) as image: + # 使用thumbnail方法生成缩略图,参数size指定缩略图的最大尺寸 + # 注意:thumbnail方法会保持图片的宽高比不变 + image.thumbnail(size) + + # 保存生成的缩略图 + image.save(out_image_path, 'JPEG') diff --git a/app/util/random_utils.py b/app/util/random_utils.py new file mode 100644 index 0000000..825d46c --- /dev/null +++ b/app/util/random_utils.py @@ -0,0 +1,12 @@ +import random +import string + + +def random_str(length=10): + """随机生成自定义长度的小写字母""" + letters = string.ascii_lowercase + # 使用 random.choices 从 letters 中随机选择 length 个字母,返回一个列表 + random_letters = random.choices(letters, k=length) + # 将列表中的字母连接成一个字符串 + return ''.join(random_letters) + diff --git a/requirements.txt b/requirements.txt index 740fcd7..5363578 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -# aicheck_base requirements +# aicheck_v2 requirements # Usage: pip install -r requirements.txt # API ------------------------------------------------------------------------- @@ -17,6 +17,7 @@ pymysql==1.0.2 pynvml==12.0.0 requests-toolbelt==1.0.0 + # YOLOV5 ---------------------------------------------------------------------- # BASE ------------------------------------------------------------------------ gitpython>=3.1.30