项目基础模块代码

This commit is contained in:
2025-02-19 11:47:33 +08:00
parent 3cb2a4c507
commit 31302bcd17
30 changed files with 588 additions and 74 deletions

View File

@ -0,0 +1,86 @@
from app.model.crud import project_type_crud as ptc
from app.model.crud import project_label_crud as plc
from app.model.crud import project_info_crud as pic
from app.service import project_service as ps
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoPager
from app.model.schemas.project_label_schemas import ProjectLabel
from app.db.db_session import get_db
from app.common.jwt_check import get_user_id
from app.common import reponse_code as rc
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
"""项目管理API"""
project = APIRouter()
@project.get("/types")
def get_type_list(session: Session = Depends(get_db)):
"""获取项目类别"""
type_list = ptc.get_list(session)
return rc.response_success(data=type_list)
@project.post("/list")
def project_pager(info: ProjectInfoPager, session: Session = Depends(get_db)):
pager = pic.get_project_pager(info, session)
return rc.response_success_pager(pager)
@project.post("/add")
def add_project(request: Request, info: ProjectInfoIn, session: Session = Depends(get_db)):
"""新建项目"""
if pic.check_project_name(info.project_name, session):
return rc.response_error("已经存在相同名称的项目")
user_id = get_user_id(request)
ps.add_project(info, session, user_id)
return rc.response_success(msg="新建成功")
@project.post("/add_label")
def add_label(label: ProjectLabel, session: Session = Depends(get_db)):
"""
新增标签
:param label:
:param session:
:return:
"""
if plc.check_label_name(label.project_id, label.label_name, session):
return rc.response_error("标签名称已经存在,不能重复")
label_save = ProjectLabel(**label.dict())
plc.add_label(label_save, session)
return rc.response_success(msg="保存成功")
@project.post("/up_label")
def up_label(label: ProjectLabel, session: Session = Depends(get_db)):
"""
修改标签
:param label:
:param session:
:return:
"""
if plc.check_label_name(label.project_id, label.label_name, session, label.id):
return rc.response_error("修改的标签名称已经存在,不能重复")
label_save = ProjectLabel(**label.dict())
plc.update_label(label_save, session)
return rc.response_success(msg="修改成功")
@project.post("/del_label")
def del_label(label: ProjectLabel, session: Session = Depends(get_db)):
"""
删除标签
:param label:
:param session:
:return:
"""
row_del = plc.update_label(label.id, session)
if row_del > 0:
return rc.response_success(msg="删除成功")
else:
return rc.response_error("删除失败")

View File

@ -1,17 +1,18 @@
from typing import List from typing import List
from fastapi import APIRouter from fastapi import APIRouter
from fastapi import UploadFile from fastapi import UploadFile
import os import os
upload = APIRouter() from app.config.config_reader import images_url
upload_files = APIRouter()
@upload.post("/") @upload_files.post("/")
def upload_file(files: List[UploadFile]): def upload(files: List[UploadFile], project_no: str):
paths = [] paths = []
for file in files: for file in files:
path = os.path.join("images", file.filename) path = os.path.join(images_url, project_no, file.filename)
with open(path, "wb") as f: with open(path, "wb") as f:
for line in file.file: for line in file.file:
f.write(line) f.write(line)

View File

@ -1,8 +1,8 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from app.model.schemas.sys_user_schemas import SysUserOut, SysUserIN, SysUserPager from app.model.schemas.sys_user_schemas import SysUserOut, SysUserIn, SysUserPager
from app.common import reponse_code as rc from app.common import reponse_code as rc
from app.model.crud import sys_user_crud as us from app.model.crud import sys_user_crud as us
from app.model.model import SysUser from app.model.sys_model import SysUser
from app.common.redis_cli import redis_conn from app.common.redis_cli import redis_conn
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -20,7 +20,7 @@ def user_pager(user: SysUserPager, session: Session = Depends(get_db)):
@user.post("/") @user.post("/")
def add_user(user: SysUserIN, session: Session = Depends(get_db)): def add_user(user: SysUserIn, session: Session = Depends(get_db)):
""" """
新增用户 新增用户
:param session: :param session:
@ -49,7 +49,7 @@ def get_user(id: int, session: Session = Depends(get_db)):
user = us.get_user_by_id(id, session) user = us.get_user_by_id(id, session)
if user is None: if user is None:
return rc.response_success(data=None) return rc.response_success(data=None)
user_out = SysUserOut(**dict(user)) user_out = SysUserOut.from_orm(user)
return rc.response_success(data=user_out.dict()) return rc.response_success(data=user_out.dict())

View File

@ -4,6 +4,11 @@ from fastapi.middleware.cors import CORSMiddleware
from app.application.token_middleware import TokenMiddleware from app.application.token_middleware import TokenMiddleware
from app.application.logger_middleware import LoggerMiddleware from app.application.logger_middleware import LoggerMiddleware
from app.api.common.upload_file import upload_files
from app.api.sys.login_api import login
from app.api.sys.sys_user_api import user
from app.api.business.project_api import project
my_app = FastAPI() my_app = FastAPI()
@ -18,6 +23,14 @@ my_app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
#注意中间的顺序,这个地方是倒序执行的 '''
注意中间的顺序,这个地方是倒序执行的
'''
my_app.add_middleware(LoggerMiddleware) my_app.add_middleware(LoggerMiddleware)
my_app.add_middleware(TokenMiddleware) my_app.add_middleware(TokenMiddleware)
my_app.include_router(user, prefix="/user", tags=["用户管理API"])
my_app.include_router(login, prefix="/login", tags=["用户登录接口"])
my_app.include_router(upload_files, prefix="/upload", tags=["文件上传API"])
my_app.include_router(project, prefix="/proj", tags=["项目管理API"])

View File

@ -0,0 +1,10 @@
from fastapi import HTTPException, Request
from app.common.reponse_code import response_error
from app import my_app
"""全局异常处理"""
@my_app.exception_handlers(HTTPException)
async def http_exception(request: Request, he: HTTPException):
return response_error(request.url + "出现异常:" + he.detail)

View File

@ -2,8 +2,8 @@ from fastapi import status
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from urllib.request import Request from urllib.request import Request
from jwt import PyJWTError from jwt import PyJWTError
from common import reponse_code as rc from app.common import reponse_code as rc
from common import jwt_check as jc from app.common import jwt_check as jc

View File

@ -2,6 +2,7 @@ import datetime
import jwt import jwt
from app.common.redis_cli import redis_conn from app.common.redis_cli import redis_conn
from fastapi import Request
# 过期时间单位S # 过期时间单位S
exp = 6000 exp = 6000
@ -48,3 +49,11 @@ def check_token(token: str):
raise jwt.ExpiredSignatureError("Expired Token") raise jwt.ExpiredSignatureError("Expired Token")
except jwt.InvalidTokenError: except jwt.InvalidTokenError:
raise jwt.InvalidTokenError("Invalid Token") raise jwt.InvalidTokenError("Invalid Token")
def get_user_id(request: Request):
"""根据Request请求获取token"""
token = request.headers.get("Authorization")
decoded_payload = check_token(token)
user_id = decoded_payload['user_id']
return user_id

View File

@ -12,6 +12,7 @@ def response_code_view(code: int,msg: str) -> Response:
} }
) )
def response_success(msg: str = "查询成功", data: object = None): def response_success(msg: str = "查询成功", data: object = None):
return JSONResponse( return JSONResponse(
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,

View File

@ -9,3 +9,10 @@ password = sdust2020
[log] [log]
dir = D:\syg\workspace\logs dir = D:\syg\workspace\logs
[yolo]
datasets_url = D:\syg\yolov5\datasets
runs_url = D:\syg\yolov5\runs
[images]
image_url = D:\syg\images

View File

@ -1,11 +0,0 @@
[mysql]
database_url = mysql+pymysql://root:root@localhost:3306/sun
[redis]
host = localhost
port = 6379
db = 0
password = 123456
[log]
dir = /Users/macbookpro/sunyg/workspace/logs

View File

@ -0,0 +1,18 @@
[mysql]
database_url = mysql+pymysql://root:root@localhost:3306/sun
[redis]
host = localhost
port = 6379
db = 0
password = 123456
[log]
dir = /Users/macbookpro/sunyg/workspace/logs
[yolo]
datasets_url = /home/yolov5/datasets
runs_url = /home/yolov5/runs
[images]
image_url = /home/images

View File

@ -3,6 +3,8 @@ import os
env = "dev" env = "dev"
# env = "prod"
script_directory = os.path.dirname(os.path.abspath(__file__)) script_directory = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(script_directory, f'application_config_{env}.ini') config_path = os.path.join(script_directory, f'application_config_{env}.ini')
@ -19,3 +21,8 @@ redis_db = config.get('redis', 'db')
redis_password = config.get('redis', 'password') redis_password = config.get('redis', 'password')
log_dir = config.get('log', 'dir') log_dir = config.get('log', 'dir')
datasets_url = config.get('yolo', 'datasets_url')
runs_url = config.get('yolo', 'runs_url')
images_url = config.get('images', 'image_url')

View File

@ -11,10 +11,6 @@ class DbCommon(Base):
create_time: Mapped[datetime.datetime] = mapped_column(default=datetime.datetime.utcnow()) create_time: Mapped[datetime.datetime] = mapped_column(default=datetime.datetime.utcnow())
update_time: Mapped[datetime.datetime] = mapped_column(default=datetime.datetime.utcnow(), onupdate=datetime.datetime.utcnow()) update_time: Mapped[datetime.datetime] = mapped_column(default=datetime.datetime.utcnow(), onupdate=datetime.datetime.utcnow())
def keys(self):
return ["id", "create_time", "update_time"]
@declared_attr @declared_attr
def __tablename__(cls): def __tablename__(cls):
return cls.__name__.lower() return cls.__name__.lower()

View File

@ -0,0 +1,42 @@
from app.db.db_base import DbCommon
from sqlalchemy import String, Integer
from sqlalchemy.orm import Mapped, mapped_column
class ProjectType(DbCommon):
"""
项目类别表 - 标识项目的类型目前存在的目标识别OCR识别瑕疵检测图像分类
"""
__tablename__ = "project_type"
type_code = Mapped[str] = mapped_column(String(20), unique=True, nullable=False)
type_name = Mapped[str] = mapped_column(String(20))
icon_path = Mapped[str] = mapped_column(String(255))
description = Mapped[str] = mapped_column(String(255))
type_status = Mapped[str] = mapped_column(String(10))
class ProjectInfo(DbCommon):
"""项目信息表"""
__tablename__ = "project_info"
project_no = Mapped[str] = mapped_column(String(32), unique=True, nullable=False)
project_name = Mapped[str] = mapped_column(String(32), unique=True, nullable=False)
type_code = Mapped[str] = mapped_column(String(10))
description = Mapped[str] = mapped_column(String(255))
project_status = Mapped[str] = mapped_column(String(10))
user_id = Mapped[int] = mapped_column(Integer)
train_version = Mapped[int] = mapped_column(Integer)
class ProjectLabel(DbCommon):
"""项目标签表"""
__tablename__ = "project_label"
label_name = Mapped[str] = mapped_column(String(32), unique=True, nullable=False)
project_id = Mapped[int] = mapped_column(Integer, nullable=False)
class ProjectImage(DbCommon):
"""项目图片表"""
__tablename__ = "project_image"
image_url = Mapped[str] = mapped_column(String(255), nullable=False)
thumb_image_url = Mapped[str] = mapped_column(String(255), nullable=False)
project_id = Mapped[int] = mapped_column(Integer)

View File

@ -0,0 +1,35 @@
from sqlalchemy.orm import Session
from sqlalchemy import asc
from app.model.bussiness_model import ProjectImage as piModel
from app.model.schemas.project_image_schemas import ProjectImage, ProjectImagePager
from app.db.page_util import get_pager
def get_image_pager(image: ProjectImagePager, session: Session):
query = session.query(piModel).filter(piModel.project_id == image.project_id).order_by(asc(piModel.id))
pager = get_pager(query, image.pagerNum, image.pagerSize)
pager.data = [ProjectImage.from_orm(image) for image in pager.data]
return pager
def get_image_list(image: ProjectImage, session: Session):
query = session.query(piModel).filter(piModel.project_id == image.project_id).order_by(asc(piModel.id))
image_list = [ProjectImage.from_orm(image) for image in query.all()]
return image_list
def add_image(image: ProjectImage, session: Session):
session.add(image)
session.commit()
return image
def del_image(id: str, session: Session):
row_del = session.query(piModel).filter_by(id=id).delete()
session.commit()
return row_del

View File

@ -0,0 +1,42 @@
from sqlalchemy.orm import Session
from sqlalchemy import desc
from app.model.bussiness_model import ProjectInfo
from app.model.schemas.project_info_schemas import ProjectInfoOut
from app.model.schemas.project_info_schemas import ProjectInfoPager
from app.db.page_util import get_pager
def get_project_pager(info: ProjectInfoPager, session: Session):
"""分页查询项目信息"""
query = session.query(ProjectInfo).order_by(desc(ProjectInfo.id))
filters = []
if info.project_name is not None:
filters.append(ProjectInfo.project_name.ilike(f"%{info.project_name}%"))
if len(filters) > 0:
query.filter(*filters)
pager = get_pager(query, info.pagerNum, info.pagerSize)
pager.data = [ProjectInfoOut.from_orm(info).dict() for info in pager.data]
return pager
def get_project_by_id(id: str, session: Session):
info = session.query(ProjectInfo).filter_by(id=id).first()
info_out = ProjectInfoOut.from_orm(info)
return info_out
def add_project(info: ProjectInfo, session: Session):
"""新建项目,并在对应文件夹下面创建文件夹"""
session.add(info)
session.commit()
return info
def check_project_name(project_name: str, session: Session):
"""检验是否存在重名的项目名称"""
count = session.query(ProjectInfo).filter(ProjectInfo.project_name == project_name).count()
if count > 0:
return True
else:
return False

View File

@ -0,0 +1,72 @@
from sqlalchemy.orm import Session
from app.model.bussiness_model import ProjectLabel as plModel
from app.model.schemas.project_label_schemas import ProjectLabel
def get_label_list(project_id: int, session: Session):
"""
根绝项目id获取标签列表
:param project_id: 项目id
:param session:
:return:
"""
label_list = session.query(plModel).filter(plModel.project_id == project_id).all()
label_list = [ProjectLabel.from_orm(label) for label in label_list]
return label_list
def add_label(label: plModel, session: Session):
"""
新增标签
:param label:
:param session:
:return:
"""
session.add(label)
session.commit()
return label.id
def check_label_name(project_id: int, label_name: str, session: Session, label_id: int = None):
"""
检验标签名称是否存在
:param label_id:
:param project_id: 项目id
:param label_name: 标签名称
:param session:
:return:
"""
query = session.query(plModel)
filters = [plModel.project_id == project_id, plModel.label_name == label_name]
if label_id is not None:
filters.append(plModel.id != label_id)
query.filter(*filters)
if query.count() > 0:
return False
else:
return True
def update_label(label: plModel, session: Session):
"""
修改标签名称
:param label:
:param session:
:return:
"""
session.query(plModel).filter_by(id=label.id).update({
"label_name": label.label_name
})
session.commit()
def del_label(id: str, session: Session):
"""
根据标签id删除标签
:param id: 标签id
:param session:
:return:
"""
row_del = session.query(plModel).filter_by(id=id).delete()
return row_del

View File

@ -0,0 +1,13 @@
from app.model.bussiness_model import ProjectType
from app.model.schemas.project_type_schemas import ProjectTypeOut
from sqlalchemy import asc
from sqlalchemy.orm import Session
def get_list(session: Session):
"""获取项目类型列表"""
query = session.query(ProjectType).order_by(asc(ProjectType.id))
query.filter(ProjectType.type_status == "0")
result_list = [ProjectTypeOut.from_orm(project_type).dict() for project_type in query.all()]
return result_list

View File

@ -1,23 +1,21 @@
from app.model.model import SysUser from app.model.sys_model import SysUser
from app.model.schemas.sys_user_schemas import SysUserPager, SysUserOut from app.model.schemas.sys_user_schemas import SysUserPager, SysUserOut
from app.common.bcrypt_pw import hash_password from app.common.bcrypt_pw import hash_password
from app.db.page_util import get_pager from app.db.page_util import get_pager
from sqlalchemy import and_, desc from sqlalchemy import and_, asc
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
def user_pager(user: SysUserPager, session: Session): def user_pager(user: SysUserPager, session: Session):
query = session.query(SysUser).order_by(desc(SysUser.id)) query = session.query(SysUser).order_by(asc(SysUser.id))
filters = [] filters = []
if user.username is not None: if user.username is not None:
filters.append(SysUser.username.ilike(f"%{user.username}%")) filters.append(SysUser.username.ilike(f"%{user.username}%"))
if user.dept_id is not None:
filters.append(SysUser.dept_id == user.dept_id)
if len(filters) > 0: if len(filters) > 0:
query.filter(and_(*filters)) query.filter(and_(*filters))
pager = get_pager(query, user.pagerNum, user.pagerSize) pager = get_pager(query, user.pagerNum, user.pagerSize)
pager.data = [SysUserOut.from_orm(user) for user in pager.data] pager.data = [SysUserOut.from_orm(user).dict() for user in pager.data]
return pager return pager
@ -35,13 +33,13 @@ def get_user_by_id(id: int, session: Session):
def stop_user(user: SysUser, session: Session): def stop_user(user: SysUser, session: Session):
user.user_status = "1" user.user_status = "1"
session.commit(); session.commit()
return user return user
def start_user(user: SysUser, session: Session): def start_user(user: SysUser, session: Session):
user.user_status = "0" user.user_status = "0"
session.commit(); session.commit()
return user return user

View File

@ -1,29 +0,0 @@
from app.db.db_base import DbCommon
from sqlalchemy import String,Integer
from sqlalchemy.orm import Mapped, mapped_column
class SysUser(DbCommon):
__tablename__ = "sys_user"
username: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
password: Mapped[str] = mapped_column(String(255))
dept_id: Mapped[int] = mapped_column(Integer)
login_name: Mapped[str] = mapped_column(String(255))
avatar: Mapped[str] = mapped_column(String(255))
user_status: Mapped[str] = mapped_column(String(10))
def keys(self):
keys = ["username", "password", "dept_id", "login_name", "avatar", "user_status"]
keys.extend(super().keys())
return keys
def __getitem__(self, item):
'''
内置方法, 当使用obj['name']的形式的时候, 将调用这个方法, 这里返回的结果就是值
:param item:
:return:
'''
return getattr(self, item, None)

View File

@ -0,0 +1,15 @@
from pydantic import BaseModel, Field
from typing import Optional
class ProjectImage(BaseModel):
id: Optional[int] = Field(None, description="id")
project_id: Optional[int] = Field(..., description="项目id")
image_url: Optional[str] = Field(..., description="原图路径")
thumb_image_url: Optional[str] = Field(..., description="缩略图路径")
class ProjectImagePager(BaseModel):
project_id: Optional[int] = Field(..., description="项目id")
pagerNum: Optional[int] = Field(1, description="当前页码")
pagerSize: Optional[int] = Field(10, description="每页数量")

View File

@ -0,0 +1,32 @@
from pydantic import BaseModel, Field
from typing import Optional
class ProjectInfoIn(BaseModel):
"""项目信息输入"""
id: Optional[int] = Field(None, description="项目id")
project_name: Optional[str] = Field(..., description="项目名称")
type_code: Optional[str] = Field(..., description="项目类型编码")
description: Optional[str] = Field(None, description="项目描述")
class ProjectInfoOut(BaseModel):
"""项目信息输出"""
id: Optional[int] = Field(None, description="项目id")
project_no: Optional[str] = Field(..., description="项目编号")
project_name: Optional[str] = Field(..., description="项目名称")
type_code: Optional[str] = Field(..., description="项目类型编码")
description: Optional[str] = Field(None, description="项目描述")
user_name: Optional[str] = Field(None, description="创建人")
train_version: Optional[int] = Field(None, description="训练版本号")
project_status: Optional[str] = Field(None, description="项目状态")
class Config:
orm_mode = True
class ProjectInfoPager(BaseModel):
project_name: Optional[str] = Field(None, description="项目名称")
pagerNum: Optional[int] = Field(1, description="当前页码")
pagerSize: Optional[int] = Field(10, description="每页数量")

View File

@ -0,0 +1,12 @@
from pydantic import BaseModel, Field
from typing import Optional
class ProjectLabel(BaseModel):
"""项目标签输入输出"""
id: Optional[int] = Field(None, description="id")
project_id: Optional[int] = Field(..., description="项目id")
label_name: Optional[str] = Field(..., description="标签名称")
class Config:
orm_mode = True

View File

@ -0,0 +1,26 @@
from pydantic import BaseModel, Field
from typing import Optional
class ProjectTypeIn(BaseModel):
"""
项目类型输入
"""
type_code: Optional[str] = Field(..., description="类型code", max_length=20)
type_name: Optional[str] = Field(..., description="类型名称", max_length=20, min_length=4)
icon_path: Optional[str] = Field(None, description="iconPath", max_length=255)
description: Optional[str] = Field(None, description="类型描述", max_length=255)
class ProjectTypeOut(BaseModel):
"""
项目类型输出
"""
id: Optional[int] = Field(..., description="id")
type_code: Optional[str] = Field(..., description="类型code", max_length=20)
type_name: Optional[str] = Field(..., description="类型名称", max_length=20, min_length=4)
icon_path: Optional[str] = Field(None, description="iconPath", max_length=255)
description: Optional[str] = Field(None, description="类型描述", max_length=255)
class Config:
orm_mode = True

View File

@ -3,10 +3,9 @@ from typing import Optional
# 用户相关的原型 # 用户相关的原型
class SysUserIN(BaseModel): class SysUserIn(BaseModel):
username: Optional[str] = Field(..., description="用户名", max_length=50) username: Optional[str] = Field(..., description="用户名", max_length=50)
password: Optional[str] = Field(..., description="密码", max_length=30, min_length=6) password: Optional[str] = Field(..., description="密码", max_length=30, min_length=6)
dept_id: Optional[str] = Field(None, description="部门id")
login_name: Optional[str] = Field(None, description="昵称", max_length=20) login_name: Optional[str] = Field(None, description="昵称", max_length=20)
@ -18,9 +17,8 @@ class SysUserLogin(BaseModel):
class SysUserOut(BaseModel): class SysUserOut(BaseModel):
id: Optional[int] = Field(..., description="id") id: Optional[int] = Field(..., description="id")
username: Optional[str] = Field(..., description="用户名") username: Optional[str] = Field(..., description="用户名")
dept_id: Optional[str] = Field(None, description="部门id")
dept_name: Optional[str] = Field(None, description="部门名称")
login_name: Optional[str] = Field(None, description="昵称") login_name: Optional[str] = Field(None, description="昵称")
user_status: Optional[str] = Field(None, description="用户状态")
class Config: class Config:
orm_mode = True orm_mode = True
@ -34,7 +32,6 @@ class SysUserUpdatePw(BaseModel):
class SysUserPager(BaseModel): class SysUserPager(BaseModel):
username: Optional[str] = Field(None, description="用户名") username: Optional[str] = Field(None, description="用户名")
dept_id: Optional[str] = Field(None, description="部门id")
login_name: Optional[str] = Field(None, description="昵称") login_name: Optional[str] = Field(None, description="昵称")
pagerNum: Optional[int] = Field(1, description="当前页码") pagerNum: Optional[int] = Field(1, description="当前页码")
pagerSize: Optional[int] = Field(10, description="每页数量") pagerSize: Optional[int] = Field(10, description="每页数量")

15
app/model/sys_model.py Normal file
View File

@ -0,0 +1,15 @@
from app.db.db_base import DbCommon
from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column
class SysUser(DbCommon):
"""
用户表 - 保存用户基本信息
"""
__tablename__ = "sys_user"
username: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
password: Mapped[str] = mapped_column(String(255))
login_name: Mapped[str] = mapped_column(String(255))
avatar: Mapped[str] = mapped_column(String(255))
user_status: Mapped[str] = mapped_column(String(10))

View File

@ -0,0 +1,51 @@
from app.model.bussiness_model import ProjectImage
from app.model.schemas.project_info_schemas import ProjectInfoIn, ProjectInfoOut
from app.model.bussiness_model import ProjectInfo
from app.util import random_utils as ru
from app.config.config_reader import datasets_url, runs_url, images_url
from app.model.crud import project_info_crud as pic
from app.util import os_utils as os
from sqlalchemy.orm import Session
from typing import List
from fastapi import UploadFile
def add_project(info: ProjectInfoIn, session: Session, user_id: int):
"""
新建项目,完善数据,并创建对应的文件夹
:param info: 项目信息
:param session: 数据库session
:param user_id: 用户id
:return:
"""
project_info = ProjectInfo(**info.dict())
project_info.user_id = user_id
project_info.project_no = ru.random_str(6)
project_info.project_status = "0"
project_info.train_version = 0
os.create_folder(datasets_url, project_info.project_no)
os.create_folder(runs_url, project_info.project_no)
pic.add_project(project_info, session)
return project_info
def upload_project_image(session: Session, project_info: ProjectInfoOut, files: List[UploadFile]):
"""
上传项目的图片
:param files: 上传的图片
:param project_info: 项目信息
:param image:
:param session:
:return:
"""
for file in files:
image = ProjectImage()
image.project_id = project_info.id
# 保存原图
path = os.save_images(images_url, project_info.project_no, file=file)
image.image_url = path

43
app/util/os_utils.py Normal file
View File

@ -0,0 +1,43 @@
import os
from fastapi import UploadFile
from PIL import Image
def create_folder(*path):
"""根据路径创建文件夹"""
folder_path = os.path.join(*path)
try:
os.makedirs(path, exist_ok=True)
except Exception as e:
print(f"创建文件夹时错误: {e}")
def save_images(*path, file: UploadFile):
"""
保存上传的图片
:param path: 路径
:param file: 文件
:return:
"""
save_path = os.path.join(*path, file.filename)
with open(save_path, "wb") as f:
for line in file.file:
f.write(line)
return save_path
def create_thumbnail(input_image_path, out_image_path, size=(116, 70)):
"""
给图片生成缩略图
:param input_image_path:
:param out_image_path:
:param size: 缩略的尺寸
:return:
"""
with Image.open(input_image_path) as image:
# 使用thumbnail方法生成缩略图参数size指定缩略图的最大尺寸
# 注意thumbnail方法会保持图片的宽高比不变
image.thumbnail(size)
# 保存生成的缩略图
image.save(out_image_path, 'JPEG')

12
app/util/random_utils.py Normal file
View File

@ -0,0 +1,12 @@
import random
import string
def random_str(length=10):
"""随机生成自定义长度的小写字母"""
letters = string.ascii_lowercase
# 使用 random.choices 从 letters 中随机选择 length 个字母,返回一个列表
random_letters = random.choices(letters, k=length)
# 将列表中的字母连接成一个字符串
return ''.join(random_letters)

View File

@ -1,4 +1,4 @@
# aicheck_base requirements # aicheck_v2 requirements
# Usage: pip install -r requirements.txt # Usage: pip install -r requirements.txt
# API ------------------------------------------------------------------------- # API -------------------------------------------------------------------------
@ -17,6 +17,7 @@ pymysql==1.0.2
pynvml==12.0.0 pynvml==12.0.0
requests-toolbelt==1.0.0 requests-toolbelt==1.0.0
# YOLOV5 ---------------------------------------------------------------------- # YOLOV5 ----------------------------------------------------------------------
# BASE ------------------------------------------------------------------------ # BASE ------------------------------------------------------------------------
gitpython>=3.1.30 gitpython>=3.1.30