项目基础模块代码
This commit is contained in:
86
app/api/business/project_api.py
Normal file
86
app/api/business/project_api.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
from app.model.crud import project_type_crud as ptc
|
||||||
|
from app.model.crud import project_label_crud as plc
|
||||||
|
from app.model.crud import project_info_crud as pic
|
||||||
|
from app.service import project_service as ps
|
||||||
|
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
|
||||||
|
from app.model.schemas.project_label_schemas import ProjectLabel
|
||||||
|
from app.db.db_session import get_db
|
||||||
|
from app.common.jwt_check import get_user_id
|
||||||
|
from app.common import reponse_code as rc
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
"""项目管理API"""
|
||||||
|
project = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@project.get("/types")
|
||||||
|
def get_type_list(session: Session = Depends(get_db)):
|
||||||
|
"""获取项目类别"""
|
||||||
|
type_list = ptc.get_list(session)
|
||||||
|
return rc.response_success(data=type_list)
|
||||||
|
|
||||||
|
|
||||||
|
@project.post("/list")
|
||||||
|
def project_pager(info: ProjectInfoPager, session: Session = Depends(get_db)):
|
||||||
|
pager = pic.get_project_pager(info, session)
|
||||||
|
return rc.response_success_pager(pager)
|
||||||
|
|
||||||
|
|
||||||
|
@project.post("/add")
|
||||||
|
def add_project(request: Request, info: ProjectInfoIn, session: Session = Depends(get_db)):
|
||||||
|
"""新建项目"""
|
||||||
|
if pic.check_project_name(info.project_name, session):
|
||||||
|
return rc.response_error("已经存在相同名称的项目")
|
||||||
|
user_id = get_user_id(request)
|
||||||
|
ps.add_project(info, session, user_id)
|
||||||
|
return rc.response_success(msg="新建成功")
|
||||||
|
|
||||||
|
|
||||||
|
@project.post("/add_label")
|
||||||
|
def add_label(label: ProjectLabel, session: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
新增标签
|
||||||
|
:param label:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if plc.check_label_name(label.project_id, label.label_name, session):
|
||||||
|
return rc.response_error("标签名称已经存在,不能重复")
|
||||||
|
label_save = ProjectLabel(**label.dict())
|
||||||
|
plc.add_label(label_save, session)
|
||||||
|
return rc.response_success(msg="保存成功")
|
||||||
|
|
||||||
|
|
||||||
|
@project.post("/up_label")
|
||||||
|
def up_label(label: ProjectLabel, session: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
修改标签
|
||||||
|
:param label:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if plc.check_label_name(label.project_id, label.label_name, session, label.id):
|
||||||
|
return rc.response_error("修改的标签名称已经存在,不能重复")
|
||||||
|
label_save = ProjectLabel(**label.dict())
|
||||||
|
plc.update_label(label_save, session)
|
||||||
|
return rc.response_success(msg="修改成功")
|
||||||
|
|
||||||
|
|
||||||
|
@project.post("/del_label")
|
||||||
|
def del_label(label: ProjectLabel, session: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
删除标签
|
||||||
|
:param label:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
row_del = plc.update_label(label.id, session)
|
||||||
|
if row_del > 0:
|
||||||
|
return rc.response_success(msg="删除成功")
|
||||||
|
else:
|
||||||
|
return rc.response_error("删除失败")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,17 +1,18 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi import UploadFile
|
from fastapi import UploadFile
|
||||||
import os
|
import os
|
||||||
|
|
||||||
upload = APIRouter()
|
from app.config.config_reader import images_url
|
||||||
|
|
||||||
|
upload_files = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@upload.post("/")
|
@upload_files.post("/")
|
||||||
def upload_file(files: List[UploadFile]):
|
def upload(files: List[UploadFile], project_no: str):
|
||||||
paths = []
|
paths = []
|
||||||
for file in files:
|
for file in files:
|
||||||
path = os.path.join("images", file.filename)
|
path = os.path.join(images_url, project_no, file.filename)
|
||||||
with open(path, "wb") as f:
|
with open(path, "wb") as f:
|
||||||
for line in file.file:
|
for line in file.file:
|
||||||
f.write(line)
|
f.write(line)
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from app.model.schemas.sys_user_schemas import SysUserOut, SysUserIN, SysUserPager
|
from app.model.schemas.sys_user_schemas import SysUserOut, SysUserIn, SysUserPager
|
||||||
from app.common import reponse_code as rc
|
from app.common import reponse_code as rc
|
||||||
from app.model.crud import sys_user_crud as us
|
from app.model.crud import sys_user_crud as us
|
||||||
from app.model.model import SysUser
|
from app.model.sys_model import SysUser
|
||||||
from app.common.redis_cli import redis_conn
|
from app.common.redis_cli import redis_conn
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@ -20,7 +20,7 @@ def user_pager(user: SysUserPager, session: Session = Depends(get_db)):
|
|||||||
|
|
||||||
|
|
||||||
@user.post("/")
|
@user.post("/")
|
||||||
def add_user(user: SysUserIN, session: Session = Depends(get_db)):
|
def add_user(user: SysUserIn, session: Session = Depends(get_db)):
|
||||||
"""
|
"""
|
||||||
新增用户
|
新增用户
|
||||||
:param session:
|
:param session:
|
||||||
@ -30,7 +30,7 @@ def add_user(user: SysUserIN, session: Session = Depends(get_db)):
|
|||||||
if us.check_username(user.username, session):
|
if us.check_username(user.username, session):
|
||||||
return rc.response_error(msg="该用户名已存在!")
|
return rc.response_error(msg="该用户名已存在!")
|
||||||
else:
|
else:
|
||||||
user_in= SysUser(**user.dict())
|
user_in = SysUser(**user.dict())
|
||||||
user_in.user_status = '0'
|
user_in.user_status = '0'
|
||||||
if us.add_user(user_in, session):
|
if us.add_user(user_in, session):
|
||||||
return rc.response_success(msg="保存成功")
|
return rc.response_success(msg="保存成功")
|
||||||
@ -49,7 +49,7 @@ def get_user(id: int, session: Session = Depends(get_db)):
|
|||||||
user = us.get_user_by_id(id, session)
|
user = us.get_user_by_id(id, session)
|
||||||
if user is None:
|
if user is None:
|
||||||
return rc.response_success(data=None)
|
return rc.response_success(data=None)
|
||||||
user_out = SysUserOut(**dict(user))
|
user_out = SysUserOut.from_orm(user)
|
||||||
return rc.response_success(data=user_out.dict())
|
return rc.response_success(data=user_out.dict())
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,6 +4,11 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from app.application.token_middleware import TokenMiddleware
|
from app.application.token_middleware import TokenMiddleware
|
||||||
from app.application.logger_middleware import LoggerMiddleware
|
from app.application.logger_middleware import LoggerMiddleware
|
||||||
|
|
||||||
|
from app.api.common.upload_file import upload_files
|
||||||
|
from app.api.sys.login_api import login
|
||||||
|
from app.api.sys.sys_user_api import user
|
||||||
|
from app.api.business.project_api import project
|
||||||
|
|
||||||
my_app = FastAPI()
|
my_app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
@ -18,6 +23,14 @@ my_app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
#注意中间的顺序,这个地方是倒序执行的
|
'''
|
||||||
|
注意中间的顺序,这个地方是倒序执行的
|
||||||
|
'''
|
||||||
my_app.add_middleware(LoggerMiddleware)
|
my_app.add_middleware(LoggerMiddleware)
|
||||||
my_app.add_middleware(TokenMiddleware)
|
my_app.add_middleware(TokenMiddleware)
|
||||||
|
|
||||||
|
my_app.include_router(user, prefix="/user", tags=["用户管理API"])
|
||||||
|
my_app.include_router(login, prefix="/login", tags=["用户登录接口"])
|
||||||
|
my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"])
|
||||||
|
my_app.include_router(project, prefix="/proj", tags=["项目管理API"])
|
||||||
|
|
||||||
|
10
app/application/exception_handler.py
Normal file
10
app/application/exception_handler.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from fastapi import HTTPException, Request
|
||||||
|
from app.common.reponse_code import response_error
|
||||||
|
|
||||||
|
from app import my_app
|
||||||
|
"""全局异常处理"""
|
||||||
|
|
||||||
|
|
||||||
|
@my_app.exception_handlers(HTTPException)
|
||||||
|
async def http_exception(request: Request, he: HTTPException):
|
||||||
|
return response_error(request.url + "出现异常:" + he.detail)
|
@ -2,8 +2,8 @@ from fastapi import status
|
|||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from urllib.request import Request
|
from urllib.request import Request
|
||||||
from jwt import PyJWTError
|
from jwt import PyJWTError
|
||||||
from common import reponse_code as rc
|
from app.common import reponse_code as rc
|
||||||
from common import jwt_check as jc
|
from app.common import jwt_check as jc
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ import datetime
|
|||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
from app.common.redis_cli import redis_conn
|
from app.common.redis_cli import redis_conn
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
# 过期时间,单位S
|
# 过期时间,单位S
|
||||||
exp = 6000
|
exp = 6000
|
||||||
@ -48,3 +49,11 @@ def check_token(token: str):
|
|||||||
raise jwt.ExpiredSignatureError("Expired Token")
|
raise jwt.ExpiredSignatureError("Expired Token")
|
||||||
except jwt.InvalidTokenError:
|
except jwt.InvalidTokenError:
|
||||||
raise jwt.InvalidTokenError("Invalid Token")
|
raise jwt.InvalidTokenError("Invalid Token")
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_id(request: Request):
|
||||||
|
"""根据Request请求获取token"""
|
||||||
|
token = request.headers.get("Authorization")
|
||||||
|
decoded_payload = check_token(token)
|
||||||
|
user_id = decoded_payload['user_id']
|
||||||
|
return user_id
|
@ -12,6 +12,7 @@ def response_code_view(code: int,msg: str) -> Response:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def response_success(msg: str = "查询成功", data: object = None):
|
def response_success(msg: str = "查询成功", data: object = None):
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_200_OK,
|
status_code=status.HTTP_200_OK,
|
||||||
@ -35,7 +36,7 @@ def response_success_pager(pager: Pager):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def response_error(msg:str):
|
def response_error(msg: str):
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_200_OK,
|
status_code=status.HTTP_200_OK,
|
||||||
content={
|
content={
|
||||||
|
@ -9,3 +9,10 @@ password = sdust2020
|
|||||||
|
|
||||||
[log]
|
[log]
|
||||||
dir = D:\syg\workspace\logs
|
dir = D:\syg\workspace\logs
|
||||||
|
|
||||||
|
[yolo]
|
||||||
|
datasets_url = D:\syg\yolov5\datasets
|
||||||
|
runs_url = D:\syg\yolov5\runs
|
||||||
|
|
||||||
|
[images]
|
||||||
|
image_url = D:\syg\images
|
@ -1,11 +0,0 @@
|
|||||||
[mysql]
|
|
||||||
database_url = mysql+pymysql://root:root@localhost:3306/sun
|
|
||||||
|
|
||||||
[redis]
|
|
||||||
host = localhost
|
|
||||||
port = 6379
|
|
||||||
db = 0
|
|
||||||
password = 123456
|
|
||||||
|
|
||||||
[log]
|
|
||||||
dir = /Users/macbookpro/sunyg/workspace/logs
|
|
18
app/config/application_config_prod.ini
Normal file
18
app/config/application_config_prod.ini
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
[mysql]
|
||||||
|
database_url = mysql+pymysql://root:root@localhost:3306/sun
|
||||||
|
|
||||||
|
[redis]
|
||||||
|
host = localhost
|
||||||
|
port = 6379
|
||||||
|
db = 0
|
||||||
|
password = 123456
|
||||||
|
|
||||||
|
[log]
|
||||||
|
dir = /Users/macbookpro/sunyg/workspace/logs
|
||||||
|
|
||||||
|
[yolo]
|
||||||
|
datasets_url = /home/yolov5/datasets
|
||||||
|
runs_url = /home/yolov5/runs
|
||||||
|
|
||||||
|
[images]
|
||||||
|
image_url = /home/images
|
@ -3,6 +3,8 @@ import os
|
|||||||
|
|
||||||
env = "dev"
|
env = "dev"
|
||||||
|
|
||||||
|
# env = "prod"
|
||||||
|
|
||||||
|
|
||||||
script_directory = os.path.dirname(os.path.abspath(__file__))
|
script_directory = os.path.dirname(os.path.abspath(__file__))
|
||||||
config_path = os.path.join(script_directory, f'application_config_{env}.ini')
|
config_path = os.path.join(script_directory, f'application_config_{env}.ini')
|
||||||
@ -19,3 +21,8 @@ redis_db = config.get('redis', 'db')
|
|||||||
redis_password = config.get('redis', 'password')
|
redis_password = config.get('redis', 'password')
|
||||||
|
|
||||||
log_dir = config.get('log', 'dir')
|
log_dir = config.get('log', 'dir')
|
||||||
|
|
||||||
|
datasets_url = config.get('yolo', 'datasets_url')
|
||||||
|
runs_url = config.get('yolo', 'runs_url')
|
||||||
|
|
||||||
|
images_url = config.get('images', 'image_url')
|
@ -11,10 +11,6 @@ class DbCommon(Base):
|
|||||||
create_time: Mapped[datetime.datetime] = mapped_column(default=datetime.datetime.utcnow())
|
create_time: Mapped[datetime.datetime] = mapped_column(default=datetime.datetime.utcnow())
|
||||||
update_time: Mapped[datetime.datetime] = mapped_column(default=datetime.datetime.utcnow(), onupdate=datetime.datetime.utcnow())
|
update_time: Mapped[datetime.datetime] = mapped_column(default=datetime.datetime.utcnow(), onupdate=datetime.datetime.utcnow())
|
||||||
|
|
||||||
def keys(self):
|
|
||||||
return ["id", "create_time", "update_time"]
|
|
||||||
|
|
||||||
|
|
||||||
@declared_attr
|
@declared_attr
|
||||||
def __tablename__(cls):
|
def __tablename__(cls):
|
||||||
return cls.__name__.lower()
|
return cls.__name__.lower()
|
||||||
|
42
app/model/bussiness_model.py
Normal file
42
app/model/bussiness_model.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from app.db.db_base import DbCommon
|
||||||
|
from sqlalchemy import String, Integer
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectType(DbCommon):
|
||||||
|
"""
|
||||||
|
项目类别表 - 标识项目的类型目前存在的(目标识别,OCR识别,瑕疵检测,图像分类)
|
||||||
|
"""
|
||||||
|
__tablename__ = "project_type"
|
||||||
|
type_code = Mapped[str] = mapped_column(String(20), unique=True, nullable=False)
|
||||||
|
type_name = Mapped[str] = mapped_column(String(20))
|
||||||
|
icon_path = Mapped[str] = mapped_column(String(255))
|
||||||
|
description = Mapped[str] = mapped_column(String(255))
|
||||||
|
type_status = Mapped[str] = mapped_column(String(10))
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectInfo(DbCommon):
|
||||||
|
"""项目信息表"""
|
||||||
|
__tablename__ = "project_info"
|
||||||
|
project_no = Mapped[str] = mapped_column(String(32), unique=True, nullable=False)
|
||||||
|
project_name = Mapped[str] = mapped_column(String(32), unique=True, nullable=False)
|
||||||
|
type_code = Mapped[str] = mapped_column(String(10))
|
||||||
|
description = Mapped[str] = mapped_column(String(255))
|
||||||
|
project_status = Mapped[str] = mapped_column(String(10))
|
||||||
|
user_id = Mapped[int] = mapped_column(Integer)
|
||||||
|
train_version = Mapped[int] = mapped_column(Integer)
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectLabel(DbCommon):
|
||||||
|
"""项目标签表"""
|
||||||
|
__tablename__ = "project_label"
|
||||||
|
label_name = Mapped[str] = mapped_column(String(32), unique=True, nullable=False)
|
||||||
|
project_id = Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectImage(DbCommon):
|
||||||
|
"""项目图片表"""
|
||||||
|
__tablename__ = "project_image"
|
||||||
|
image_url = Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
thumb_image_url = Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
project_id = Mapped[int] = mapped_column(Integer)
|
35
app/model/crud/project_image_crud.py
Normal file
35
app/model/crud/project_image_crud.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy import asc
|
||||||
|
|
||||||
|
from app.model.bussiness_model import ProjectImage as piModel
|
||||||
|
from app.model.schemas.project_image_schemas import ProjectImage, ProjectImagePager
|
||||||
|
from app.db.page_util import get_pager
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_pager(image: ProjectImagePager, session: Session):
|
||||||
|
query = session.query(piModel).filter(piModel.project_id == image.project_id).order_by(asc(piModel.id))
|
||||||
|
pager = get_pager(query, image.pagerNum, image.pagerSize)
|
||||||
|
pager.data = [ProjectImage.from_orm(image) for image in pager.data]
|
||||||
|
return pager
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_list(image: ProjectImage, session: Session):
|
||||||
|
query = session.query(piModel).filter(piModel.project_id == image.project_id).order_by(asc(piModel.id))
|
||||||
|
image_list = [ProjectImage.from_orm(image) for image in query.all()]
|
||||||
|
return image_list
|
||||||
|
|
||||||
|
|
||||||
|
def add_image(image: ProjectImage, session: Session):
|
||||||
|
session.add(image)
|
||||||
|
session.commit()
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def del_image(id: str, session: Session):
|
||||||
|
row_del = session.query(piModel).filter_by(id=id).delete()
|
||||||
|
session.commit()
|
||||||
|
return row_del
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
42
app/model/crud/project_info_crud.py
Normal file
42
app/model/crud/project_info_crud.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy import desc
|
||||||
|
|
||||||
|
from app.model.bussiness_model import ProjectInfo
|
||||||
|
from app.model.schemas.project_info_schemas import ProjectInfoOut
|
||||||
|
from app.model.schemas.project_info_schemas import ProjectInfoPager
|
||||||
|
from app.db.page_util import get_pager
|
||||||
|
|
||||||
|
|
||||||
|
def get_project_pager(info: ProjectInfoPager, session: Session):
|
||||||
|
"""分页查询项目信息"""
|
||||||
|
query = session.query(ProjectInfo).order_by(desc(ProjectInfo.id))
|
||||||
|
filters = []
|
||||||
|
if info.project_name is not None:
|
||||||
|
filters.append(ProjectInfo.project_name.ilike(f"%{info.project_name}%"))
|
||||||
|
if len(filters) > 0:
|
||||||
|
query.filter(*filters)
|
||||||
|
pager = get_pager(query, info.pagerNum, info.pagerSize)
|
||||||
|
pager.data = [ProjectInfoOut.from_orm(info).dict() for info in pager.data]
|
||||||
|
return pager
|
||||||
|
|
||||||
|
|
||||||
|
def get_project_by_id(id: str, session: Session):
|
||||||
|
info = session.query(ProjectInfo).filter_by(id=id).first()
|
||||||
|
info_out = ProjectInfoOut.from_orm(info)
|
||||||
|
return info_out
|
||||||
|
|
||||||
|
def add_project(info: ProjectInfo, session: Session):
|
||||||
|
"""新建项目,并在对应文件夹下面创建文件夹"""
|
||||||
|
session.add(info)
|
||||||
|
session.commit()
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
def check_project_name(project_name: str, session: Session):
|
||||||
|
"""检验是否存在重名的项目名称"""
|
||||||
|
count = session.query(ProjectInfo).filter(ProjectInfo.project_name == project_name).count()
|
||||||
|
if count > 0:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
72
app/model/crud/project_label_crud.py
Normal file
72
app/model/crud/project_label_crud.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.model.bussiness_model import ProjectLabel as plModel
|
||||||
|
from app.model.schemas.project_label_schemas import ProjectLabel
|
||||||
|
|
||||||
|
|
||||||
|
def get_label_list(project_id: int, session: Session):
|
||||||
|
"""
|
||||||
|
根绝项目id获取标签列表
|
||||||
|
:param project_id: 项目id
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
label_list = session.query(plModel).filter(plModel.project_id == project_id).all()
|
||||||
|
label_list = [ProjectLabel.from_orm(label) for label in label_list]
|
||||||
|
return label_list
|
||||||
|
|
||||||
|
|
||||||
|
def add_label(label: plModel, session: Session):
|
||||||
|
"""
|
||||||
|
新增标签
|
||||||
|
:param label:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
session.add(label)
|
||||||
|
session.commit()
|
||||||
|
return label.id
|
||||||
|
|
||||||
|
|
||||||
|
def check_label_name(project_id: int, label_name: str, session: Session, label_id: int = None):
|
||||||
|
"""
|
||||||
|
检验标签名称是否存在
|
||||||
|
:param label_id:
|
||||||
|
:param project_id: 项目id
|
||||||
|
:param label_name: 标签名称
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
query = session.query(plModel)
|
||||||
|
filters = [plModel.project_id == project_id, plModel.label_name == label_name]
|
||||||
|
if label_id is not None:
|
||||||
|
filters.append(plModel.id != label_id)
|
||||||
|
query.filter(*filters)
|
||||||
|
if query.count() > 0:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def update_label(label: plModel, session: Session):
|
||||||
|
"""
|
||||||
|
修改标签名称
|
||||||
|
:param label:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
session.query(plModel).filter_by(id=label.id).update({
|
||||||
|
"label_name": label.label_name
|
||||||
|
})
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def del_label(id: str, session: Session):
|
||||||
|
"""
|
||||||
|
根据标签id删除标签
|
||||||
|
:param id: 标签id
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
row_del = session.query(plModel).filter_by(id=id).delete()
|
||||||
|
return row_del
|
13
app/model/crud/project_type_crud.py
Normal file
13
app/model/crud/project_type_crud.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from app.model.bussiness_model import ProjectType
|
||||||
|
from app.model.schemas.project_type_schemas import ProjectTypeOut
|
||||||
|
|
||||||
|
from sqlalchemy import asc
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
|
def get_list(session: Session):
|
||||||
|
"""获取项目类型列表"""
|
||||||
|
query = session.query(ProjectType).order_by(asc(ProjectType.id))
|
||||||
|
query.filter(ProjectType.type_status == "0")
|
||||||
|
result_list = [ProjectTypeOut.from_orm(project_type).dict() for project_type in query.all()]
|
||||||
|
return result_list
|
@ -1,23 +1,21 @@
|
|||||||
from app.model.model import SysUser
|
from app.model.sys_model import SysUser
|
||||||
from app.model.schemas.sys_user_schemas import SysUserPager, SysUserOut
|
from app.model.schemas.sys_user_schemas import SysUserPager, SysUserOut
|
||||||
from app.common.bcrypt_pw import hash_password
|
from app.common.bcrypt_pw import hash_password
|
||||||
from app.db.page_util import get_pager
|
from app.db.page_util import get_pager
|
||||||
|
|
||||||
from sqlalchemy import and_, desc
|
from sqlalchemy import and_, asc
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
def user_pager(user: SysUserPager, session: Session):
|
def user_pager(user: SysUserPager, session: Session):
|
||||||
query = session.query(SysUser).order_by(desc(SysUser.id))
|
query = session.query(SysUser).order_by(asc(SysUser.id))
|
||||||
filters = []
|
filters = []
|
||||||
if user.username is not None:
|
if user.username is not None:
|
||||||
filters.append(SysUser.username.ilike(f"%{user.username}%"))
|
filters.append(SysUser.username.ilike(f"%{user.username}%"))
|
||||||
if user.dept_id is not None:
|
|
||||||
filters.append(SysUser.dept_id == user.dept_id)
|
|
||||||
if len(filters) > 0:
|
if len(filters) > 0:
|
||||||
query.filter(and_(*filters))
|
query.filter(and_(*filters))
|
||||||
pager = get_pager(query, user.pagerNum, user.pagerSize)
|
pager = get_pager(query, user.pagerNum, user.pagerSize)
|
||||||
pager.data = [SysUserOut.from_orm(user) for user in pager.data]
|
pager.data = [SysUserOut.from_orm(user).dict() for user in pager.data]
|
||||||
return pager
|
return pager
|
||||||
|
|
||||||
|
|
||||||
@ -35,13 +33,13 @@ def get_user_by_id(id: int, session: Session):
|
|||||||
|
|
||||||
def stop_user(user: SysUser, session: Session):
|
def stop_user(user: SysUser, session: Session):
|
||||||
user.user_status = "1"
|
user.user_status = "1"
|
||||||
session.commit();
|
session.commit()
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def start_user(user: SysUser, session: Session):
|
def start_user(user: SysUser, session: Session):
|
||||||
user.user_status = "0"
|
user.user_status = "0"
|
||||||
session.commit();
|
session.commit()
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,29 +0,0 @@
|
|||||||
from app.db.db_base import DbCommon
|
|
||||||
from sqlalchemy import String,Integer
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
|
||||||
|
|
||||||
|
|
||||||
class SysUser(DbCommon):
|
|
||||||
|
|
||||||
__tablename__ = "sys_user"
|
|
||||||
username: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
|
|
||||||
password: Mapped[str] = mapped_column(String(255))
|
|
||||||
dept_id: Mapped[int] = mapped_column(Integer)
|
|
||||||
login_name: Mapped[str] = mapped_column(String(255))
|
|
||||||
avatar: Mapped[str] = mapped_column(String(255))
|
|
||||||
user_status: Mapped[str] = mapped_column(String(10))
|
|
||||||
|
|
||||||
|
|
||||||
def keys(self):
|
|
||||||
keys = ["username", "password", "dept_id", "login_name", "avatar", "user_status"]
|
|
||||||
keys.extend(super().keys())
|
|
||||||
return keys
|
|
||||||
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
'''
|
|
||||||
内置方法, 当使用obj['name']的形式的时候, 将调用这个方法, 这里返回的结果就是值
|
|
||||||
:param item:
|
|
||||||
:return:
|
|
||||||
'''
|
|
||||||
return getattr(self, item, None)
|
|
15
app/model/schemas/project_image_schemas.py
Normal file
15
app/model/schemas/project_image_schemas.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectImage(BaseModel):
|
||||||
|
id: Optional[int] = Field(None, description="id")
|
||||||
|
project_id: Optional[int] = Field(..., description="项目id")
|
||||||
|
image_url: Optional[str] = Field(..., description="原图路径")
|
||||||
|
thumb_image_url: Optional[str] = Field(..., description="缩略图路径")
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectImagePager(BaseModel):
|
||||||
|
project_id: Optional[int] = Field(..., description="项目id")
|
||||||
|
pagerNum: Optional[int] = Field(1, description="当前页码")
|
||||||
|
pagerSize: Optional[int] = Field(10, description="每页数量")
|
32
app/model/schemas/project_info_schemas.py
Normal file
32
app/model/schemas/project_info_schemas.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectInfoIn(BaseModel):
|
||||||
|
"""项目信息输入"""
|
||||||
|
id: Optional[int] = Field(None, description="项目id")
|
||||||
|
project_name: Optional[str] = Field(..., description="项目名称")
|
||||||
|
type_code: Optional[str] = Field(..., description="项目类型编码")
|
||||||
|
description: Optional[str] = Field(None, description="项目描述")
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectInfoOut(BaseModel):
|
||||||
|
"""项目信息输出"""
|
||||||
|
id: Optional[int] = Field(None, description="项目id")
|
||||||
|
project_no: Optional[str] = Field(..., description="项目编号")
|
||||||
|
project_name: Optional[str] = Field(..., description="项目名称")
|
||||||
|
type_code: Optional[str] = Field(..., description="项目类型编码")
|
||||||
|
description: Optional[str] = Field(None, description="项目描述")
|
||||||
|
user_name: Optional[str] = Field(None, description="创建人")
|
||||||
|
train_version: Optional[int] = Field(None, description="训练版本号")
|
||||||
|
project_status: Optional[str] = Field(None, description="项目状态")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectInfoPager(BaseModel):
|
||||||
|
project_name: Optional[str] = Field(None, description="项目名称")
|
||||||
|
pagerNum: Optional[int] = Field(1, description="当前页码")
|
||||||
|
pagerSize: Optional[int] = Field(10, description="每页数量")
|
||||||
|
|
12
app/model/schemas/project_label_schemas.py
Normal file
12
app/model/schemas/project_label_schemas.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectLabel(BaseModel):
|
||||||
|
"""项目标签输入输出"""
|
||||||
|
id: Optional[int] = Field(None, description="id")
|
||||||
|
project_id: Optional[int] = Field(..., description="项目id")
|
||||||
|
label_name: Optional[str] = Field(..., description="标签名称")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True
|
26
app/model/schemas/project_type_schemas.py
Normal file
26
app/model/schemas/project_type_schemas.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectTypeIn(BaseModel):
|
||||||
|
"""
|
||||||
|
项目类型输入
|
||||||
|
"""
|
||||||
|
type_code: Optional[str] = Field(..., description="类型code", max_length=20)
|
||||||
|
type_name: Optional[str] = Field(..., description="类型名称", max_length=20, min_length=4)
|
||||||
|
icon_path: Optional[str] = Field(None, description="iconPath", max_length=255)
|
||||||
|
description: Optional[str] = Field(None, description="类型描述", max_length=255)
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectTypeOut(BaseModel):
|
||||||
|
"""
|
||||||
|
项目类型输出
|
||||||
|
"""
|
||||||
|
id: Optional[int] = Field(..., description="id")
|
||||||
|
type_code: Optional[str] = Field(..., description="类型code", max_length=20)
|
||||||
|
type_name: Optional[str] = Field(..., description="类型名称", max_length=20, min_length=4)
|
||||||
|
icon_path: Optional[str] = Field(None, description="iconPath", max_length=255)
|
||||||
|
description: Optional[str] = Field(None, description="类型描述", max_length=255)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True
|
@ -3,10 +3,9 @@ from typing import Optional
|
|||||||
|
|
||||||
|
|
||||||
# 用户相关的原型
|
# 用户相关的原型
|
||||||
class SysUserIN(BaseModel):
|
class SysUserIn(BaseModel):
|
||||||
username: Optional[str] = Field(..., description="用户名", max_length=50)
|
username: Optional[str] = Field(..., description="用户名", max_length=50)
|
||||||
password: Optional[str] = Field(..., description="密码", max_length=30, min_length=6)
|
password: Optional[str] = Field(..., description="密码", max_length=30, min_length=6)
|
||||||
dept_id: Optional[str] = Field(None, description="部门id")
|
|
||||||
login_name: Optional[str] = Field(None, description="昵称", max_length=20)
|
login_name: Optional[str] = Field(None, description="昵称", max_length=20)
|
||||||
|
|
||||||
|
|
||||||
@ -18,9 +17,8 @@ class SysUserLogin(BaseModel):
|
|||||||
class SysUserOut(BaseModel):
|
class SysUserOut(BaseModel):
|
||||||
id: Optional[int] = Field(..., description="id")
|
id: Optional[int] = Field(..., description="id")
|
||||||
username: Optional[str] = Field(..., description="用户名")
|
username: Optional[str] = Field(..., description="用户名")
|
||||||
dept_id: Optional[str] = Field(None, description="部门id")
|
|
||||||
dept_name: Optional[str] = Field(None, description="部门名称")
|
|
||||||
login_name: Optional[str] = Field(None, description="昵称")
|
login_name: Optional[str] = Field(None, description="昵称")
|
||||||
|
user_status: Optional[str] = Field(None, description="用户状态")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
@ -34,7 +32,6 @@ class SysUserUpdatePw(BaseModel):
|
|||||||
|
|
||||||
class SysUserPager(BaseModel):
|
class SysUserPager(BaseModel):
|
||||||
username: Optional[str] = Field(None, description="用户名")
|
username: Optional[str] = Field(None, description="用户名")
|
||||||
dept_id: Optional[str] = Field(None, description="部门id")
|
|
||||||
login_name: Optional[str] = Field(None, description="昵称")
|
login_name: Optional[str] = Field(None, description="昵称")
|
||||||
pagerNum: Optional[int] = Field(1, description="当前页码")
|
pagerNum: Optional[int] = Field(1, description="当前页码")
|
||||||
pagerSize: Optional[int] = Field(10, description="每页数量")
|
pagerSize: Optional[int] = Field(10, description="每页数量")
|
||||||
|
15
app/model/sys_model.py
Normal file
15
app/model/sys_model.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from app.db.db_base import DbCommon
|
||||||
|
from sqlalchemy import String
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
|
||||||
|
class SysUser(DbCommon):
|
||||||
|
"""
|
||||||
|
用户表 - 保存用户基本信息
|
||||||
|
"""
|
||||||
|
__tablename__ = "sys_user"
|
||||||
|
username: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
|
||||||
|
password: Mapped[str] = mapped_column(String(255))
|
||||||
|
login_name: Mapped[str] = mapped_column(String(255))
|
||||||
|
avatar: Mapped[str] = mapped_column(String(255))
|
||||||
|
user_status: Mapped[str] = mapped_column(String(10))
|
51
app/service/project_service.py
Normal file
51
app/service/project_service.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
from app.model.bussiness_model import ProjectImage
|
||||||
|
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
|
||||||
|
from app.model.bussiness_model import ProjectInfo
|
||||||
|
from app.util import random_utils as ru
|
||||||
|
from app.config.config_reader import datasets_url, runs_url, images_url
|
||||||
|
from app.model.crud import project_info_crud as pic
|
||||||
|
from app.util import os_utils as os
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import List
|
||||||
|
from fastapi import UploadFile
|
||||||
|
|
||||||
|
|
||||||
|
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
|
||||||
|
"""
|
||||||
|
新建项目,完善数据,并创建对应的文件夹
|
||||||
|
:param info: 项目信息
|
||||||
|
:param session: 数据库session
|
||||||
|
:param user_id: 用户id
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
project_info = ProjectInfo(**info.dict())
|
||||||
|
project_info.user_id = user_id
|
||||||
|
project_info.project_no = ru.random_str(6)
|
||||||
|
project_info.project_status = "0"
|
||||||
|
project_info.train_version = 0
|
||||||
|
os.create_folder(datasets_url, project_info.project_no)
|
||||||
|
os.create_folder(runs_url, project_info.project_no)
|
||||||
|
pic.add_project(project_info, session)
|
||||||
|
return project_info
|
||||||
|
|
||||||
|
|
||||||
|
def upload_project_image(session: Session, project_info: ProjectInfoOut, files: List[UploadFile]):
|
||||||
|
"""
|
||||||
|
上传项目的图片
|
||||||
|
:param files: 上传的图片
|
||||||
|
:param project_info: 项目信息
|
||||||
|
:param image:
|
||||||
|
:param session:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for file in files:
|
||||||
|
image = ProjectImage()
|
||||||
|
image.project_id = project_info.id
|
||||||
|
# 保存原图
|
||||||
|
path = os.save_images(images_url, project_info.project_no, file=file)
|
||||||
|
image.image_url = path
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
43
app/util/os_utils.py
Normal file
43
app/util/os_utils.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import os
|
||||||
|
from fastapi import UploadFile
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def create_folder(*path):
|
||||||
|
"""根据路径创建文件夹"""
|
||||||
|
folder_path = os.path.join(*path)
|
||||||
|
try:
|
||||||
|
os.makedirs(path, exist_ok=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"创建文件夹时错误: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def save_images(*path, file: UploadFile):
|
||||||
|
"""
|
||||||
|
保存上传的图片
|
||||||
|
:param path: 路径
|
||||||
|
:param file: 文件
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
save_path = os.path.join(*path, file.filename)
|
||||||
|
with open(save_path, "wb") as f:
|
||||||
|
for line in file.file:
|
||||||
|
f.write(line)
|
||||||
|
return save_path
|
||||||
|
|
||||||
|
|
||||||
|
def create_thumbnail(input_image_path, out_image_path, size=(116, 70)):
|
||||||
|
"""
|
||||||
|
给图片生成缩略图
|
||||||
|
:param input_image_path:
|
||||||
|
:param out_image_path:
|
||||||
|
:param size: 缩略的尺寸
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
with Image.open(input_image_path) as image:
|
||||||
|
# 使用thumbnail方法生成缩略图,参数size指定缩略图的最大尺寸
|
||||||
|
# 注意:thumbnail方法会保持图片的宽高比不变
|
||||||
|
image.thumbnail(size)
|
||||||
|
|
||||||
|
# 保存生成的缩略图
|
||||||
|
image.save(out_image_path, 'JPEG')
|
12
app/util/random_utils.py
Normal file
12
app/util/random_utils.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
import random
|
||||||
|
import string
|
||||||
|
|
||||||
|
|
||||||
|
def random_str(length=10):
|
||||||
|
"""随机生成自定义长度的小写字母"""
|
||||||
|
letters = string.ascii_lowercase
|
||||||
|
# 使用 random.choices 从 letters 中随机选择 length 个字母,返回一个列表
|
||||||
|
random_letters = random.choices(letters, k=length)
|
||||||
|
# 将列表中的字母连接成一个字符串
|
||||||
|
return ''.join(random_letters)
|
||||||
|
|
@ -1,4 +1,4 @@
|
|||||||
# aicheck_base requirements
|
# aicheck_v2 requirements
|
||||||
# Usage: pip install -r requirements.txt
|
# Usage: pip install -r requirements.txt
|
||||||
|
|
||||||
# API -------------------------------------------------------------------------
|
# API -------------------------------------------------------------------------
|
||||||
@ -17,6 +17,7 @@ pymysql==1.0.2
|
|||||||
pynvml==12.0.0
|
pynvml==12.0.0
|
||||||
requests-toolbelt==1.0.0
|
requests-toolbelt==1.0.0
|
||||||
|
|
||||||
|
|
||||||
# YOLOV5 ----------------------------------------------------------------------
|
# YOLOV5 ----------------------------------------------------------------------
|
||||||
# BASE ------------------------------------------------------------------------
|
# BASE ------------------------------------------------------------------------
|
||||||
gitpython>=3.1.30
|
gitpython>=3.1.30
|
||||||
|
Reference in New Issue
Block a user