From bed123c532e104d1cb3b0ddca92a2620f44fbd3f Mon Sep 17 00:00:00 2001 From: sunyugang Date: Wed, 19 Feb 2025 16:57:49 +0800 Subject: [PATCH] =?UTF-8?q?=E9=A1=B9=E7=9B=AE=E5=9F=BA=E7=A1=80=E6=A8=A1?= =?UTF-8?q?=E5=9D=97=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/business/project_api.py | 56 +++++++++++++++++++--- app/api/common/view_img.py | 21 ++++++++ app/api/sys/sys_user_api.py | 4 +- app/application/app.py | 5 +- app/application/token_middleware.py | 14 +++++- app/model/bussiness_model.py | 37 +++++++------- app/model/crud/project_image_crud.py | 17 +++++-- app/model/crud/project_info_crud.py | 5 +- app/model/crud/project_label_crud.py | 22 +++++---- app/model/crud/sys_user_crud.py | 6 +-- app/model/schemas/project_image_schemas.py | 3 ++ app/model/schemas/project_label_schemas.py | 3 +- app/service/project_service.py | 17 +++++-- app/util/os_utils.py | 6 ++- 14 files changed, 159 insertions(+), 57 deletions(-) create mode 100644 app/api/common/view_img.py diff --git a/app/api/business/project_api.py b/app/api/business/project_api.py index 5b55c2b..9df5ade 100644 --- a/app/api/business/project_api.py +++ b/app/api/business/project_api.py @@ -1,14 +1,17 @@ from app.model.crud import project_type_crud as ptc from app.model.crud import project_label_crud as plc from app.model.crud import project_info_crud as pic +from app.model.crud import project_image_crud as pimc from app.service import project_service as ps from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager from app.model.schemas.project_label_schemas import ProjectLabel +from app.model.bussiness_model import ProjectLabel as pl from app.db.db_session import get_db from app.common.jwt_check import get_user_id from app.common import reponse_code as rc -from fastapi import APIRouter, Depends, Request +from typing import List +from fastapi import APIRouter, Depends, Request, UploadFile, File, Form from sqlalchemy.orm import Session """项目管理API""" @@ -38,6 +41,18 @@ def add_project(request: Request, info: ProjectInfoIn, session: Session = Depend return rc.response_success(msg="新建成功") +@project.get("/label_list/{project_id}") +def get_label_list(project_id: int, session: Session = Depends(get_db)): + """ + 根据项目id查询项目标签列表 + :param project_id: + :param session: + :return: + """ + label_list = plc.get_label_list(project_id, session) + return rc.response_success(msg="查询成功", data=label_list) + + @project.post("/add_label") def add_label(label: ProjectLabel, session: Session = Depends(get_db)): """ @@ -48,7 +63,7 @@ def add_label(label: ProjectLabel, session: Session = Depends(get_db)): """ if plc.check_label_name(label.project_id, label.label_name, session): return rc.response_error("标签名称已经存在,不能重复") - label_save = ProjectLabel(**label.dict()) + label_save = pl(**label.dict()) plc.add_label(label_save, session) return rc.response_success(msg="保存成功") @@ -63,24 +78,51 @@ def up_label(label: ProjectLabel, session: Session = Depends(get_db)): """ if plc.check_label_name(label.project_id, label.label_name, session, label.id): return rc.response_error("修改的标签名称已经存在,不能重复") - label_save = ProjectLabel(**label.dict()) + label_save = pl(**label.dict()) plc.update_label(label_save, session) return rc.response_success(msg="修改成功") -@project.post("/del_label") -def del_label(label: ProjectLabel, session: Session = Depends(get_db)): +@project.post("/del_label/{label_id}") +def del_label(label_id: int, session: Session = Depends(get_db)): """ 删除标签 - :param label: + :param label_id: :param session: :return: """ - row_del = plc.update_label(label.id, session) + row_del = plc.del_label(label_id, session) if row_del > 0: return rc.response_success(msg="删除成功") else: return rc.response_error("删除失败") +@project.post("/up_proj_img") +def upload_project_image(project_id: int = Form(...), files: List[UploadFile] = File(...), session: Session = Depends(get_db)): + """ + 上传项目图片 + :param files: 文件图片 + :param project_id: + :param session: + :return: + """ + project_info = pic.get_project_by_id(project_id, session) + if project_info is None: + return rc.response_error("项目查询错误,请刷新页面后再试") + ps.upload_project_image(project_info, files, session) + return rc.response_success(msg="上传成功") + + +@project.get("/img_list/{project_id}") +def get_image_list(project_id: int, session: Session = Depends(get_db)): + """ + 获取项目图片列表 + :param project_id: 项目id + :param session: + :return: + """ + image_list = pimc.get_image_list(project_id, session) + return rc.response_success(data=image_list) + diff --git a/app/api/common/view_img.py b/app/api/common/view_img.py new file mode 100644 index 0000000..3e0ccd4 --- /dev/null +++ b/app/api/common/view_img.py @@ -0,0 +1,21 @@ +import os +from fastapi import APIRouter, HTTPException +from starlette.responses import FileResponse + +from app.config.config_reader import images_url + +view = APIRouter() + + +@view.get("/{file_path:path}") +def view_img(file_path): + """ + 查看图片 + :param file_path: 图片路径 + :return: + """ + image_path = os.path.join(images_url, file_path) + # 检查文件是否存在以及是否是文件 + if not os.path.isfile(image_path): + raise HTTPException(status_code=404, detail="Image not found") + return FileResponse(image_path, media_type='image/jpeg') diff --git a/app/api/sys/sys_user_api.py b/app/api/sys/sys_user_api.py index 5d84f4f..e3142c7 100644 --- a/app/api/sys/sys_user_api.py +++ b/app/api/sys/sys_user_api.py @@ -19,7 +19,7 @@ def user_pager(user: SysUserPager, session: Session = Depends(get_db)): return rc.response_success_pager(pager) -@user.post("/") +@user.post("/add") def add_user(user: SysUserIn, session: Session = Depends(get_db)): """ 新增用户 @@ -81,4 +81,4 @@ def start_user(id: int, session: Session = Depends(get_db)): if user is None: return rc.response_error("用户查询错误,请稍后再试") us.start_user(user) - return rc.response_success("启用用户成功") \ No newline at end of file + return rc.response_success("启用用户成功") diff --git a/app/application/app.py b/app/application/app.py index cac7628..2922e5c 100644 --- a/app/application/app.py +++ b/app/application/app.py @@ -8,6 +8,7 @@ from app.api.common.upload_file import upload_files from app.api.sys.login_api import login from app.api.sys.sys_user_api import user from app.api.business.project_api import project +from app.api.common.view_img import view my_app = FastAPI() @@ -29,8 +30,10 @@ my_app.add_middleware( my_app.add_middleware(LoggerMiddleware) my_app.add_middleware(TokenMiddleware) -my_app.include_router(user, prefix="/user", tags=["用户管理API"]) my_app.include_router(login, prefix="/login", tags=["用户登录接口"]) my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"]) +my_app.include_router(view, prefix="/view_img", tags=["查看图片"]) +my_app.include_router(user, prefix="/user", tags=["用户管理API"]) my_app.include_router(project, prefix="/proj", tags=["项目管理API"]) + diff --git a/app/application/token_middleware.py b/app/application/token_middleware.py index 15e88d5..92f2cf9 100644 --- a/app/application/token_middleware.py +++ b/app/application/token_middleware.py @@ -6,7 +6,6 @@ from app.common import reponse_code as rc from app.common import jwt_check as jc - class TokenMiddleware(BaseHTTPMiddleware): def __init__(self, app): @@ -21,7 +20,7 @@ class TokenMiddleware(BaseHTTPMiddleware): """ token = request.headers.get('Authorization') path = request.url.path - if '/login' in path: + if check_green(path): response = await call_next(request) return response if not token: @@ -30,4 +29,15 @@ class TokenMiddleware(BaseHTTPMiddleware): jc.check_token(token) return await call_next(request) except PyJWTError as error: + print(error) return rc.response_code_view(status.HTTP_401_UNAUTHORIZED, "Token错误或失效,请重新验证") + + +green = ['/login', '/view_img'] + + +def check_green(s: str): + for url in green: + if url in s: + return True + return False diff --git a/app/model/bussiness_model.py b/app/model/bussiness_model.py index 86787a1..dc63ae2 100644 --- a/app/model/bussiness_model.py +++ b/app/model/bussiness_model.py @@ -1,5 +1,5 @@ from app.db.db_base import DbCommon -from sqlalchemy import String, Integer +from sqlalchemy import String, Integer, JSON from sqlalchemy.orm import Mapped, mapped_column @@ -8,35 +8,36 @@ class ProjectType(DbCommon): 项目类别表 - 标识项目的类型目前存在的(目标识别,OCR识别,瑕疵检测,图像分类) """ __tablename__ = "project_type" - type_code = Mapped[str] = mapped_column(String(20), unique=True, nullable=False) - type_name = Mapped[str] = mapped_column(String(20)) - icon_path = Mapped[str] = mapped_column(String(255)) - description = Mapped[str] = mapped_column(String(255)) - type_status = Mapped[str] = mapped_column(String(10)) + type_code: Mapped[str] = mapped_column(String(20), unique=True, nullable=False) + type_name: Mapped[str] = mapped_column(String(20)) + icon_path: Mapped[str] = mapped_column(String(255)) + description: Mapped[str] = mapped_column(String(255)) + type_status: Mapped[str] = mapped_column(String(10)) class ProjectInfo(DbCommon): """项目信息表""" __tablename__ = "project_info" - project_no = Mapped[str] = mapped_column(String(32), unique=True, nullable=False) - project_name = Mapped[str] = mapped_column(String(32), unique=True, nullable=False) - type_code = Mapped[str] = mapped_column(String(10)) - description = Mapped[str] = mapped_column(String(255)) - project_status = Mapped[str] = mapped_column(String(10)) - user_id = Mapped[int] = mapped_column(Integer) - train_version = Mapped[int] = mapped_column(Integer) + project_no: Mapped[str] = mapped_column(String(32), unique=True, nullable=False) + project_name: Mapped[str] = mapped_column(String(32), unique=True, nullable=False) + type_code: Mapped[str] = mapped_column(String(10)) + description: Mapped[str] = mapped_column(String(255)) + project_status: Mapped[str] = mapped_column(String(10)) + user_id: Mapped[int] = mapped_column(Integer) + train_version: Mapped[int] = mapped_column(Integer) class ProjectLabel(DbCommon): """项目标签表""" __tablename__ = "project_label" - label_name = Mapped[str] = mapped_column(String(32), unique=True, nullable=False) - project_id = Mapped[int] = mapped_column(Integer, nullable=False) + label_name: Mapped[str] = mapped_column(String(32), unique=True, nullable=False) + project_id: Mapped[int] = mapped_column(Integer, nullable=False) + meta: Mapped[dict] = mapped_column(JSON) class ProjectImage(DbCommon): """项目图片表""" __tablename__ = "project_image" - image_url = Mapped[str] = mapped_column(String(255), nullable=False) - thumb_image_url = Mapped[str] = mapped_column(String(255), nullable=False) - project_id = Mapped[int] = mapped_column(Integer) + image_url: Mapped[str] = mapped_column(String(255), nullable=False) + thumb_image_url: Mapped[str] = mapped_column(String(255), nullable=False) + project_id: Mapped[int] = mapped_column(Integer) diff --git a/app/model/crud/project_image_crud.py b/app/model/crud/project_image_crud.py index 6573116..a0ee5f6 100644 --- a/app/model/crud/project_image_crud.py +++ b/app/model/crud/project_image_crud.py @@ -1,5 +1,6 @@ from sqlalchemy.orm import Session from sqlalchemy import asc +from typing import List from app.model.bussiness_model import ProjectImage as piModel from app.model.schemas.project_image_schemas import ProjectImage, ProjectImagePager @@ -13,9 +14,9 @@ def get_image_pager(image: ProjectImagePager, session: Session): return pager -def get_image_list(image: ProjectImage, session: Session): - query = session.query(piModel).filter(piModel.project_id == image.project_id).order_by(asc(piModel.id)) - image_list = [ProjectImage.from_orm(image) for image in query.all()] +def get_image_list(project_id: int, session: Session): + query = session.query(piModel).filter(piModel.project_id == project_id).order_by(asc(piModel.id)) + image_list = [ProjectImage.from_orm(image).dict() for image in query.all()] return image_list @@ -25,8 +26,14 @@ def add_image(image: ProjectImage, session: Session): return image -def del_image(id: str, session: Session): - row_del = session.query(piModel).filter_by(id=id).delete() +def add_image_batch(images: List[ProjectImage], session: Session): + for image in images: + session.add(image) + session.commit() + + +def del_image(image_id: str, session: Session): + row_del = session.query(piModel).filter_by(id=image_id).delete() session.commit() return row_del diff --git a/app/model/crud/project_info_crud.py b/app/model/crud/project_info_crud.py index a3a9418..c18a9f3 100644 --- a/app/model/crud/project_info_crud.py +++ b/app/model/crud/project_info_crud.py @@ -20,11 +20,12 @@ def get_project_pager(info: ProjectInfoPager, session: Session): return pager -def get_project_by_id(id: str, session: Session): - info = session.query(ProjectInfo).filter_by(id=id).first() +def get_project_by_id(project_id: str, session: Session): + info = session.query(ProjectInfo).filter_by(id=project_id).first() info_out = ProjectInfoOut.from_orm(info) return info_out + def add_project(info: ProjectInfo, session: Session): """新建项目,并在对应文件夹下面创建文件夹""" session.add(info) diff --git a/app/model/crud/project_label_crud.py b/app/model/crud/project_label_crud.py index 9c0d674..08bcad8 100644 --- a/app/model/crud/project_label_crud.py +++ b/app/model/crud/project_label_crud.py @@ -1,4 +1,5 @@ from sqlalchemy.orm import Session +from sqlalchemy import and_ from app.model.bussiness_model import ProjectLabel as plModel from app.model.schemas.project_label_schemas import ProjectLabel @@ -12,7 +13,7 @@ def get_label_list(project_id: int, session: Session): :return: """ label_list = session.query(plModel).filter(plModel.project_id == project_id).all() - label_list = [ProjectLabel.from_orm(label) for label in label_list] + label_list = [ProjectLabel.from_orm(label).dict() for label in label_list] return label_list @@ -41,11 +42,12 @@ def check_label_name(project_id: int, label_name: str, session: Session, label_i filters = [plModel.project_id == project_id, plModel.label_name == label_name] if label_id is not None: filters.append(plModel.id != label_id) - query.filter(*filters) - if query.count() > 0: - return False - else: + query = query.filter(and_(*filters)) + count = query.count() + if count > 0: return True + else: + return False def update_label(label: plModel, session: Session): @@ -56,17 +58,19 @@ def update_label(label: plModel, session: Session): :return: """ session.query(plModel).filter_by(id=label.id).update({ - "label_name": label.label_name + "label_name": label.label_name, + "meta": label.meta }) session.commit() -def del_label(id: str, session: Session): +def del_label(label_id: str, session: Session): """ 根据标签id删除标签 - :param id: 标签id + :param label_id: 标签id :param session: :return: """ - row_del = session.query(plModel).filter_by(id=id).delete() + row_del = session.query(plModel).filter_by(id=label_id).delete() + session.commit() return row_del diff --git a/app/model/crud/sys_user_crud.py b/app/model/crud/sys_user_crud.py index 9583807..f02595e 100644 --- a/app/model/crud/sys_user_crud.py +++ b/app/model/crud/sys_user_crud.py @@ -13,7 +13,7 @@ def user_pager(user: SysUserPager, session: Session): if user.username is not None: filters.append(SysUser.username.ilike(f"%{user.username}%")) if len(filters) > 0: - query.filter(and_(*filters)) + query = query.filter(and_(*filters)) pager = get_pager(query, user.pagerNum, user.pagerSize) pager.data = [SysUserOut.from_orm(user).dict() for user in pager.data] return pager @@ -26,8 +26,8 @@ def add_user(user: SysUser, session: Session): return user -def get_user_by_id(id: int, session: Session): - user = session.query(SysUser).filter(SysUser.id == id).first() +def get_user_by_id(user_id: int, session: Session): + user = session.query(SysUser).filter(SysUser.id == user_id).first() return user diff --git a/app/model/schemas/project_image_schemas.py b/app/model/schemas/project_image_schemas.py index 206d809..230d1cd 100644 --- a/app/model/schemas/project_image_schemas.py +++ b/app/model/schemas/project_image_schemas.py @@ -8,6 +8,9 @@ class ProjectImage(BaseModel): image_url: Optional[str] = Field(..., description="原图路径") thumb_image_url: Optional[str] = Field(..., description="缩略图路径") + class Config: + orm_mode = True + class ProjectImagePager(BaseModel): project_id: Optional[int] = Field(..., description="项目id") diff --git a/app/model/schemas/project_label_schemas.py b/app/model/schemas/project_label_schemas.py index 4ee4a0e..91ff16d 100644 --- a/app/model/schemas/project_label_schemas.py +++ b/app/model/schemas/project_label_schemas.py @@ -5,8 +5,9 @@ from typing import Optional class ProjectLabel(BaseModel): """项目标签输入输出""" id: Optional[int] = Field(None, description="id") - project_id: Optional[int] = Field(..., description="项目id") + project_id: Optional[int] = Field(None, description="项目id") label_name: Optional[str] = Field(..., description="标签名称") + meta: Optional[dict] = Field(None, description="label属性") class Config: orm_mode = True diff --git a/app/service/project_service.py b/app/service/project_service.py index 35c3ac7..ca3bea8 100644 --- a/app/service/project_service.py +++ b/app/service/project_service.py @@ -1,10 +1,11 @@ from app.model.bussiness_model import ProjectImage -from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut from app.model.bussiness_model import ProjectInfo +from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut +from app.model.crud import project_info_crud as pic +from app.model.crud import project_image_crud as pimc +from app.util import os_utils as os from app.util import random_utils as ru from app.config.config_reader import datasets_url, runs_url, images_url -from app.model.crud import project_info_crud as pic -from app.util import os_utils as os from sqlalchemy.orm import Session from typing import List @@ -30,21 +31,27 @@ def add_project(info: ProjectInfoIn, session: Session, user_id: int): return project_info -def upload_project_image(session: Session, project_info: ProjectInfoOut, files: List[UploadFile]): +def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile], session: Session): """ 上传项目的图片 :param files: 上传的图片 :param project_info: 项目信息 - :param image: :param session: :return: """ + images = [] for file in files: image = ProjectImage() image.project_id = project_info.id # 保存原图 path = os.save_images(images_url, project_info.project_no, file=file) image.image_url = path + # 生成缩略图 + thumb_image_url = images_url + "\\thumb\\" + project_info.project_no + "\\" + ru.random_str(10) + ".jpg" + os.create_thumbnail(path, thumb_image_url) + image.thumb_image_url = thumb_image_url + images.append(image) + pimc.add_image_batch(images, session) diff --git a/app/util/os_utils.py b/app/util/os_utils.py index ef798c2..3d5956c 100644 --- a/app/util/os_utils.py +++ b/app/util/os_utils.py @@ -7,7 +7,7 @@ def create_folder(*path): """根据路径创建文件夹""" folder_path = os.path.join(*path) try: - os.makedirs(path, exist_ok=True) + os.makedirs(folder_path, exist_ok=True) except Exception as e: print(f"创建文件夹时错误: {e}") @@ -20,6 +20,8 @@ def save_images(*path, file: UploadFile): :return: """ save_path = os.path.join(*path, file.filename) + + os.makedirs(os.path.dirname(save_path), exist_ok=True) with open(save_path, "wb") as f: for line in file.file: f.write(line) @@ -38,6 +40,6 @@ def create_thumbnail(input_image_path, out_image_path, size=(116, 70)): # 使用thumbnail方法生成缩略图,参数size指定缩略图的最大尺寸 # 注意:thumbnail方法会保持图片的宽高比不变 image.thumbnail(size) - + os.makedirs(os.path.dirname(out_image_path), exist_ok=True) # 保存生成的缩略图 image.save(out_image_path, 'JPEG')