项目基础模块代码

This commit is contained in:
2025-02-19 16:57:49 +08:00
parent 31302bcd17
commit bed123c532
14 changed files with 159 additions and 57 deletions

View File

@ -1,14 +1,17 @@
from app.model.crud import project_type_crud as ptc from app.model.crud import project_type_crud as ptc
from app.model.crud import project_label_crud as plc from app.model.crud import project_label_crud as plc
from app.model.crud import project_info_crud as pic from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc
from app.service import project_service as ps from app.service import project_service as ps
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
from app.model.schemas.project_label_schemas import ProjectLabel from app.model.schemas.project_label_schemas import ProjectLabel
from app.model.bussiness_model import ProjectLabel as pl
from app.db.db_session import get_db from app.db.db_session import get_db
from app.common.jwt_check import get_user_id from app.common.jwt_check import get_user_id
from app.common import reponse_code as rc from app.common import reponse_code as rc
from fastapi import APIRouter, Depends, Request from typing import List
from fastapi import APIRouter, Depends, Request, UploadFile, File, Form
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
"""项目管理API""" """项目管理API"""
@ -38,6 +41,18 @@ def add_project(request: Request, info: ProjectInfoIn, session: Session = Depend
return rc.response_success(msg="新建成功") return rc.response_success(msg="新建成功")
@project.get("/label_list/{project_id}")
def get_label_list(project_id: int, session: Session = Depends(get_db)):
"""
根据项目id查询项目标签列表
:param project_id:
:param session:
:return:
"""
label_list = plc.get_label_list(project_id, session)
return rc.response_success(msg="查询成功", data=label_list)
@project.post("/add_label") @project.post("/add_label")
def add_label(label: ProjectLabel, session: Session = Depends(get_db)): def add_label(label: ProjectLabel, session: Session = Depends(get_db)):
""" """
@ -48,7 +63,7 @@ def add_label(label: ProjectLabel, session: Session = Depends(get_db)):
""" """
if plc.check_label_name(label.project_id, label.label_name, session): if plc.check_label_name(label.project_id, label.label_name, session):
return rc.response_error("标签名称已经存在,不能重复") return rc.response_error("标签名称已经存在,不能重复")
label_save = ProjectLabel(**label.dict()) label_save = pl(**label.dict())
plc.add_label(label_save, session) plc.add_label(label_save, session)
return rc.response_success(msg="保存成功") return rc.response_success(msg="保存成功")
@ -63,24 +78,51 @@ def up_label(label: ProjectLabel, session: Session = Depends(get_db)):
""" """
if plc.check_label_name(label.project_id, label.label_name, session, label.id): if plc.check_label_name(label.project_id, label.label_name, session, label.id):
return rc.response_error("修改的标签名称已经存在,不能重复") return rc.response_error("修改的标签名称已经存在,不能重复")
label_save = ProjectLabel(**label.dict()) label_save = pl(**label.dict())
plc.update_label(label_save, session) plc.update_label(label_save, session)
return rc.response_success(msg="修改成功") return rc.response_success(msg="修改成功")
@project.post("/del_label") @project.post("/del_label/{label_id}")
def del_label(label: ProjectLabel, session: Session = Depends(get_db)): def del_label(label_id: int, session: Session = Depends(get_db)):
""" """
删除标签 删除标签
:param label: :param label_id:
:param session: :param session:
:return: :return:
""" """
row_del = plc.update_label(label.id, session) row_del = plc.del_label(label_id, session)
if row_del > 0: if row_del > 0:
return rc.response_success(msg="删除成功") return rc.response_success(msg="删除成功")
else: else:
return rc.response_error("删除失败") return rc.response_error("删除失败")
@project.post("/up_proj_img")
def upload_project_image(project_id: int = Form(...), files: List[UploadFile] = File(...), session: Session = Depends(get_db)):
"""
上传项目图片
:param files: 文件图片
:param project_id:
:param session:
:return:
"""
project_info = pic.get_project_by_id(project_id, session)
if project_info is None:
return rc.response_error("项目查询错误,请刷新页面后再试")
ps.upload_project_image(project_info, files, session)
return rc.response_success(msg="上传成功")
@project.get("/img_list/{project_id}")
def get_image_list(project_id: int, session: Session = Depends(get_db)):
"""
获取项目图片列表
:param project_id: 项目id
:param session:
:return:
"""
image_list = pimc.get_image_list(project_id, session)
return rc.response_success(data=image_list)

View File

@ -0,0 +1,21 @@
import os
from fastapi import APIRouter, HTTPException
from starlette.responses import FileResponse
from app.config.config_reader import images_url
view = APIRouter()
@view.get("/{file_path:path}")
def view_img(file_path):
"""
查看图片
:param file_path: 图片路径
:return:
"""
image_path = os.path.join(images_url, file_path)
# 检查文件是否存在以及是否是文件
if not os.path.isfile(image_path):
raise HTTPException(status_code=404, detail="Image not found")
return FileResponse(image_path, media_type='image/jpeg')

View File

@ -19,7 +19,7 @@ def user_pager(user: SysUserPager, session: Session = Depends(get_db)):
return rc.response_success_pager(pager) return rc.response_success_pager(pager)
@user.post("/") @user.post("/add")
def add_user(user: SysUserIn, session: Session = Depends(get_db)): def add_user(user: SysUserIn, session: Session = Depends(get_db)):
""" """
新增用户 新增用户
@ -81,4 +81,4 @@ def start_user(id: int, session: Session = Depends(get_db)):
if user is None: if user is None:
return rc.response_error("用户查询错误,请稍后再试") return rc.response_error("用户查询错误,请稍后再试")
us.start_user(user) us.start_user(user)
return rc.response_success("启用用户成功") return rc.response_success("启用用户成功")

View File

@ -8,6 +8,7 @@ from app.api.common.upload_file import upload_files
from app.api.sys.login_api import login from app.api.sys.login_api import login
from app.api.sys.sys_user_api import user from app.api.sys.sys_user_api import user
from app.api.business.project_api import project from app.api.business.project_api import project
from app.api.common.view_img import view
my_app = FastAPI() my_app = FastAPI()
@ -29,8 +30,10 @@ my_app.add_middleware(
my_app.add_middleware(LoggerMiddleware) my_app.add_middleware(LoggerMiddleware)
my_app.add_middleware(TokenMiddleware) my_app.add_middleware(TokenMiddleware)
my_app.include_router(user, prefix="/user", tags=["用户管理API"])
my_app.include_router(login, prefix="/login", tags=["用户登录接口"]) my_app.include_router(login, prefix="/login", tags=["用户登录接口"])
my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"]) my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"])
my_app.include_router(view, prefix="/view_img", tags=["查看图片"])
my_app.include_router(user, prefix="/user", tags=["用户管理API"])
my_app.include_router(project, prefix="/proj", tags=["项目管理API"]) my_app.include_router(project, prefix="/proj", tags=["项目管理API"])

View File

@ -6,7 +6,6 @@ from app.common import reponse_code as rc
from app.common import jwt_check as jc from app.common import jwt_check as jc
class TokenMiddleware(BaseHTTPMiddleware): class TokenMiddleware(BaseHTTPMiddleware):
def __init__(self, app): def __init__(self, app):
@ -21,7 +20,7 @@ class TokenMiddleware(BaseHTTPMiddleware):
""" """
token = request.headers.get('Authorization') token = request.headers.get('Authorization')
path = request.url.path path = request.url.path
if '/login' in path: if check_green(path):
response = await call_next(request) response = await call_next(request)
return response return response
if not token: if not token:
@ -30,4 +29,15 @@ class TokenMiddleware(BaseHTTPMiddleware):
jc.check_token(token) jc.check_token(token)
return await call_next(request) return await call_next(request)
except PyJWTError as error: except PyJWTError as error:
print(error)
return rc.response_code_view(status.HTTP_401_UNAUTHORIZED, "Token错误或失效请重新验证") return rc.response_code_view(status.HTTP_401_UNAUTHORIZED, "Token错误或失效请重新验证")
green = ['/login', '/view_img']
def check_green(s: str):
for url in green:
if url in s:
return True
return False

View File

@ -1,5 +1,5 @@
from app.db.db_base import DbCommon from app.db.db_base import DbCommon
from sqlalchemy import String, Integer from sqlalchemy import String, Integer, JSON
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
@ -8,35 +8,36 @@ class ProjectType(DbCommon):
项目类别表 - 标识项目的类型目前存在的目标识别OCR识别瑕疵检测图像分类 项目类别表 - 标识项目的类型目前存在的目标识别OCR识别瑕疵检测图像分类
""" """
__tablename__ = "project_type" __tablename__ = "project_type"
type_code = Mapped[str] = mapped_column(String(20), unique=True, nullable=False) type_code: Mapped[str] = mapped_column(String(20), unique=True, nullable=False)
type_name = Mapped[str] = mapped_column(String(20)) type_name: Mapped[str] = mapped_column(String(20))
icon_path = Mapped[str] = mapped_column(String(255)) icon_path: Mapped[str] = mapped_column(String(255))
description = Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(String(255))
type_status = Mapped[str] = mapped_column(String(10)) type_status: Mapped[str] = mapped_column(String(10))
class ProjectInfo(DbCommon): class ProjectInfo(DbCommon):
"""项目信息表""" """项目信息表"""
__tablename__ = "project_info" __tablename__ = "project_info"
project_no = Mapped[str] = mapped_column(String(32), unique=True, nullable=False) project_no: Mapped[str] = mapped_column(String(32), unique=True, nullable=False)
project_name = Mapped[str] = mapped_column(String(32), unique=True, nullable=False) project_name: Mapped[str] = mapped_column(String(32), unique=True, nullable=False)
type_code = Mapped[str] = mapped_column(String(10)) type_code: Mapped[str] = mapped_column(String(10))
description = Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(String(255))
project_status = Mapped[str] = mapped_column(String(10)) project_status: Mapped[str] = mapped_column(String(10))
user_id = Mapped[int] = mapped_column(Integer) user_id: Mapped[int] = mapped_column(Integer)
train_version = Mapped[int] = mapped_column(Integer) train_version: Mapped[int] = mapped_column(Integer)
class ProjectLabel(DbCommon): class ProjectLabel(DbCommon):
"""项目标签表""" """项目标签表"""
__tablename__ = "project_label" __tablename__ = "project_label"
label_name = Mapped[str] = mapped_column(String(32), unique=True, nullable=False) label_name: Mapped[str] = mapped_column(String(32), unique=True, nullable=False)
project_id = Mapped[int] = mapped_column(Integer, nullable=False) project_id: Mapped[int] = mapped_column(Integer, nullable=False)
meta: Mapped[dict] = mapped_column(JSON)
class ProjectImage(DbCommon): class ProjectImage(DbCommon):
"""项目图片表""" """项目图片表"""
__tablename__ = "project_image" __tablename__ = "project_image"
image_url = Mapped[str] = mapped_column(String(255), nullable=False) image_url: Mapped[str] = mapped_column(String(255), nullable=False)
thumb_image_url = Mapped[str] = mapped_column(String(255), nullable=False) thumb_image_url: Mapped[str] = mapped_column(String(255), nullable=False)
project_id = Mapped[int] = mapped_column(Integer) project_id: Mapped[int] = mapped_column(Integer)

View File

@ -1,5 +1,6 @@
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import asc from sqlalchemy import asc
from typing import List
from app.model.bussiness_model import ProjectImage as piModel from app.model.bussiness_model import ProjectImage as piModel
from app.model.schemas.project_image_schemas import ProjectImage, ProjectImagePager from app.model.schemas.project_image_schemas import ProjectImage, ProjectImagePager
@ -13,9 +14,9 @@ def get_image_pager(image: ProjectImagePager, session: Session):
return pager return pager
def get_image_list(image: ProjectImage, session: Session): def get_image_list(project_id: int, session: Session):
query = session.query(piModel).filter(piModel.project_id == image.project_id).order_by(asc(piModel.id)) query = session.query(piModel).filter(piModel.project_id == project_id).order_by(asc(piModel.id))
image_list = [ProjectImage.from_orm(image) for image in query.all()] image_list = [ProjectImage.from_orm(image).dict() for image in query.all()]
return image_list return image_list
@ -25,8 +26,14 @@ def add_image(image: ProjectImage, session: Session):
return image return image
def del_image(id: str, session: Session): def add_image_batch(images: List[ProjectImage], session: Session):
row_del = session.query(piModel).filter_by(id=id).delete() for image in images:
session.add(image)
session.commit()
def del_image(image_id: str, session: Session):
row_del = session.query(piModel).filter_by(id=image_id).delete()
session.commit() session.commit()
return row_del return row_del

View File

@ -20,11 +20,12 @@ def get_project_pager(info: ProjectInfoPager, session: Session):
return pager return pager
def get_project_by_id(id: str, session: Session): def get_project_by_id(project_id: str, session: Session):
info = session.query(ProjectInfo).filter_by(id=id).first() info = session.query(ProjectInfo).filter_by(id=project_id).first()
info_out = ProjectInfoOut.from_orm(info) info_out = ProjectInfoOut.from_orm(info)
return info_out return info_out
def add_project(info: ProjectInfo, session: Session): def add_project(info: ProjectInfo, session: Session):
"""新建项目,并在对应文件夹下面创建文件夹""" """新建项目,并在对应文件夹下面创建文件夹"""
session.add(info) session.add(info)

View File

@ -1,4 +1,5 @@
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import and_
from app.model.bussiness_model import ProjectLabel as plModel from app.model.bussiness_model import ProjectLabel as plModel
from app.model.schemas.project_label_schemas import ProjectLabel from app.model.schemas.project_label_schemas import ProjectLabel
@ -12,7 +13,7 @@ def get_label_list(project_id: int, session: Session):
:return: :return:
""" """
label_list = session.query(plModel).filter(plModel.project_id == project_id).all() label_list = session.query(plModel).filter(plModel.project_id == project_id).all()
label_list = [ProjectLabel.from_orm(label) for label in label_list] label_list = [ProjectLabel.from_orm(label).dict() for label in label_list]
return label_list return label_list
@ -41,11 +42,12 @@ def check_label_name(project_id: int, label_name: str, session: Session, label_i
filters = [plModel.project_id == project_id, plModel.label_name == label_name] filters = [plModel.project_id == project_id, plModel.label_name == label_name]
if label_id is not None: if label_id is not None:
filters.append(plModel.id != label_id) filters.append(plModel.id != label_id)
query.filter(*filters) query = query.filter(and_(*filters))
if query.count() > 0: count = query.count()
return False if count > 0:
else:
return True return True
else:
return False
def update_label(label: plModel, session: Session): def update_label(label: plModel, session: Session):
@ -56,17 +58,19 @@ def update_label(label: plModel, session: Session):
:return: :return:
""" """
session.query(plModel).filter_by(id=label.id).update({ session.query(plModel).filter_by(id=label.id).update({
"label_name": label.label_name "label_name": label.label_name,
"meta": label.meta
}) })
session.commit() session.commit()
def del_label(id: str, session: Session): def del_label(label_id: str, session: Session):
""" """
根据标签id删除标签 根据标签id删除标签
:param id: 标签id :param label_id: 标签id
:param session: :param session:
:return: :return:
""" """
row_del = session.query(plModel).filter_by(id=id).delete() row_del = session.query(plModel).filter_by(id=label_id).delete()
session.commit()
return row_del return row_del

View File

@ -13,7 +13,7 @@ def user_pager(user: SysUserPager, session: Session):
if user.username is not None: if user.username is not None:
filters.append(SysUser.username.ilike(f"%{user.username}%")) filters.append(SysUser.username.ilike(f"%{user.username}%"))
if len(filters) > 0: if len(filters) > 0:
query.filter(and_(*filters)) query = query.filter(and_(*filters))
pager = get_pager(query, user.pagerNum, user.pagerSize) pager = get_pager(query, user.pagerNum, user.pagerSize)
pager.data = [SysUserOut.from_orm(user).dict() for user in pager.data] pager.data = [SysUserOut.from_orm(user).dict() for user in pager.data]
return pager return pager
@ -26,8 +26,8 @@ def add_user(user: SysUser, session: Session):
return user return user
def get_user_by_id(id: int, session: Session): def get_user_by_id(user_id: int, session: Session):
user = session.query(SysUser).filter(SysUser.id == id).first() user = session.query(SysUser).filter(SysUser.id == user_id).first()
return user return user

View File

@ -8,6 +8,9 @@ class ProjectImage(BaseModel):
image_url: Optional[str] = Field(..., description="原图路径") image_url: Optional[str] = Field(..., description="原图路径")
thumb_image_url: Optional[str] = Field(..., description="缩略图路径") thumb_image_url: Optional[str] = Field(..., description="缩略图路径")
class Config:
orm_mode = True
class ProjectImagePager(BaseModel): class ProjectImagePager(BaseModel):
project_id: Optional[int] = Field(..., description="项目id") project_id: Optional[int] = Field(..., description="项目id")

View File

@ -5,8 +5,9 @@ from typing import Optional
class ProjectLabel(BaseModel): class ProjectLabel(BaseModel):
"""项目标签输入输出""" """项目标签输入输出"""
id: Optional[int] = Field(None, description="id") id: Optional[int] = Field(None, description="id")
project_id: Optional[int] = Field(..., description="项目id") project_id: Optional[int] = Field(None, description="项目id")
label_name: Optional[str] = Field(..., description="标签名称") label_name: Optional[str] = Field(..., description="标签名称")
meta: Optional[dict] = Field(None, description="label属性")
class Config: class Config:
orm_mode = True orm_mode = True

View File

@ -1,10 +1,11 @@
from app.model.bussiness_model import ProjectImage from app.model.bussiness_model import ProjectImage
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
from app.model.bussiness_model import ProjectInfo from app.model.bussiness_model import ProjectInfo
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
from app.model.crud import project_info_crud as pic
from app.model.crud import project_image_crud as pimc
from app.util import os_utils as os
from app.util import random_utils as ru from app.util import random_utils as ru
from app.config.config_reader import datasets_url, runs_url, images_url from app.config.config_reader import datasets_url, runs_url, images_url
from app.model.crud import project_info_crud as pic
from app.util import os_utils as os
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List from typing import List
@ -30,21 +31,27 @@ def add_project(info: ProjectInfoIn, session: Session, user_id: int):
return project_info return project_info
def upload_project_image(session: Session, project_info: ProjectInfoOut, files: List[UploadFile]): def upload_project_image(project_info: ProjectInfoOut, files: List[UploadFile], session: Session):
""" """
上传项目的图片 上传项目的图片
:param files: 上传的图片 :param files: 上传的图片
:param project_info: 项目信息 :param project_info: 项目信息
:param image:
:param session: :param session:
:return: :return:
""" """
images = []
for file in files: for file in files:
image = ProjectImage() image = ProjectImage()
image.project_id = project_info.id image.project_id = project_info.id
# 保存原图 # 保存原图
path = os.save_images(images_url, project_info.project_no, file=file) path = os.save_images(images_url, project_info.project_no, file=file)
image.image_url = path image.image_url = path
# 生成缩略图
thumb_image_url = images_url + "\\thumb\\" + project_info.project_no + "\\" + ru.random_str(10) + ".jpg"
os.create_thumbnail(path, thumb_image_url)
image.thumb_image_url = thumb_image_url
images.append(image)
pimc.add_image_batch(images, session)

View File

@ -7,7 +7,7 @@ def create_folder(*path):
"""根据路径创建文件夹""" """根据路径创建文件夹"""
folder_path = os.path.join(*path) folder_path = os.path.join(*path)
try: try:
os.makedirs(path, exist_ok=True) os.makedirs(folder_path, exist_ok=True)
except Exception as e: except Exception as e:
print(f"创建文件夹时错误: {e}") print(f"创建文件夹时错误: {e}")
@ -20,6 +20,8 @@ def save_images(*path, file: UploadFile):
:return: :return:
""" """
save_path = os.path.join(*path, file.filename) save_path = os.path.join(*path, file.filename)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, "wb") as f: with open(save_path, "wb") as f:
for line in file.file: for line in file.file:
f.write(line) f.write(line)
@ -38,6 +40,6 @@ def create_thumbnail(input_image_path, out_image_path, size=(116, 70)):
# 使用thumbnail方法生成缩略图参数size指定缩略图的最大尺寸 # 使用thumbnail方法生成缩略图参数size指定缩略图的最大尺寸
# 注意thumbnail方法会保持图片的宽高比不变 # 注意thumbnail方法会保持图片的宽高比不变
image.thumbnail(size) image.thumbnail(size)
os.makedirs(os.path.dirname(out_image_path), exist_ok=True)
# 保存生成的缩略图 # 保存生成的缩略图
image.save(out_image_path, 'JPEG') image.save(out_image_path, 'JPEG')