model对象关联信息检查封装,抽出输入参数解析方法

This commit is contained in:
chenzuoqing 2021-12-08 20:30:07 +08:00
parent 4e565ed4d8
commit 34511da090
7 changed files with 250 additions and 99 deletions

View File

@ -1,6 +1,6 @@
from typing import List, Dict from typing import List, Dict
from flask import request from flask import request, current_app as app
from flask_restful import abort, Resource, marshal, fields, reqparse from flask_restful import abort, Resource, marshal, fields, reqparse
from common.utils import abort_response from common.utils import abort_response
@ -44,8 +44,11 @@ class ModelViewBase(Resource):
# 是否分页 # 是否分页
paging = True paging = True
# 需要检查重复的字段 # 需要检查重复的字段
# TODO 可能还需要多个字段组合的判断并设默认值dict可能比较合适 # ["field_name", ("filed_name1", "filed_name2"), {"f1": "default_val", "f2": "default_val"}]
uniq_fields = [] uniq_fields = []
# 需要检查是否存在的关联字段, required为false时不判断空值
# [("host_id", models.Host, required), ("xx_id", models.xx, False)]
relation_fields = []
# method_decorators 可以用作权限校验,参考 `flask_restful.Resource` # method_decorators 可以用作权限校验,参考 `flask_restful.Resource`
# 参数是装饰器列表(未指定方法将会对所有请求方法生效) # 参数是装饰器列表(未指定方法将会对所有请求方法生效)
@ -116,6 +119,42 @@ class ModelViewBase(Resource):
def validate_fields(self, args: dict, create=True) -> dict: def validate_fields(self, args: dict, create=True) -> dict:
return args return args
@staticmethod
def validate_relation_pk(val, model, key, pk="id", need_hex=True):
"""
校验关联主键适合非mongo外键字段关联但保存有关联对象的信息的
:param val: 解析过的参数字典
:param model: mongo 模型
:param key: 检查的字段
:param pk: 当前模型保存的key对应关联模型中的字段名默认 `id`
:param need_hex: 是否检查值为16进制字符串如果是 mongo `id` 需要是 16 进制字符串
:return:
"""
try:
if need_hex:
assert int(val, 16), f"{key} 不合法"
assert model.objects(**{pk: val}).first(), f"{key} 不存在"
except (AssertionError, ValueError, TypeError):
app.logger.exception(f"{key} 异常")
abort_response(400, 1002, msg=f"{key} 不存在")
def validate_relation_fields(self, args):
"""
注意请在解析 args 时对必须的字段设置为 required=True否则为 None 不判断
校验关联字段请设置 `relation_fields`(("host_id", models.Host, required), )
:param args: 请求参数
:return:
"""
if not self.relation_fields:
return
for field, model, required in self.relation_fields:
val = args.get(field)
# 判断 required: 参数有提交值才校验,未在参数中或零值不校验
if not required and not val:
continue
self.validate_relation_pk(val, model, field)
class ListMixin(ModelViewBase): class ListMixin(ModelViewBase):
@ -148,6 +187,9 @@ class CreateMixin(ModelViewBase):
args = self.request_parse.parse_args() args = self.request_parse.parse_args()
validated_data = self.validate_fields(args, create=True) validated_data = self.validate_fields(args, create=True)
# 校验关联字段
self.validate_relation_fields(validated_data)
# 创建前钩子, # 创建前钩子,
validated_data = self.pre_create(validated_data) validated_data = self.pre_create(validated_data)
@ -155,31 +197,17 @@ class CreateMixin(ModelViewBase):
self.validate_uniq_fields(args) self.validate_uniq_fields(args)
# 保存对象 # 保存对象
try:
obj = self.model(**validated_data) obj = self.model(**validated_data)
obj.save() obj.save()
except Exception as e:
app.logger.exception(f"{self.model} 创建对象失败! data={args}")
abort_response(500, 1500, msg=f"保存对象失败!{str(e)}")
return
# 返回创建信息 # 返回创建信息
return marshal(obj, self.fields) return marshal(obj, self.fields)
# TODO 模型校验
# class CreateFormMixin(ModelViewBase):
#
# def pre_create(self, args):
# """创建前钩子,接收参数,可以对参数进行处理,最后保存此方法返回的数据"""
# return args
#
# def post(self):
# """创建对象"""
# self.form = self.form if self.form else model_form(self.model)
# print("===> json", request.json)
# form = self.form(request.json, meta={'csrf': False})
# if not form.validate():
# print(form.errors)
# return {"msg": "error"}
# return {"msg": "ok"}
class RetrieveMixin(ModelViewBase): class RetrieveMixin(ModelViewBase):
def get(self, pk): def get(self, pk):
@ -205,16 +233,23 @@ class UpdateMixin(ModelViewBase):
# 解析参数 # 解析参数
args = self.request_parse.parse_args() args = self.request_parse.parse_args()
validated_data = self.validate_fields(args, create=False) validated_data = self.validate_fields(args, create=False)
# 校验关联字段
self.validate_relation_fields(validated_data)
validated_data = self.pre_update(obj, validated_data) validated_data = self.pre_update(obj, validated_data)
# 检查重复检查提交的唯一字段值是否有存在的对象与当前对象id不同认为异常 # 检查重复检查提交的唯一字段值是否有存在的对象与当前对象id不同认为异常
self.validate_uniq_fields(args, obj=obj) self.validate_uniq_fields(args, obj=obj)
# 更新对象、保存 # 更新对象、保存
try:
obj.update(**validated_data) obj.update(**validated_data)
obj.save() obj.save()
# 重新读取数据 # 重新读取数据
obj.reload() obj.reload()
except Exception as e:
app.logger.exception(f"{self.model} 保存对象失败pk={pk} data={args}")
abort_response(500, 1500, msg=f"保存对象失败!{str(e)}")
return
return marshal(obj, self.fields) return marshal(obj, self.fields)

View File

@ -0,0 +1,57 @@
import datetime
from flask_restful import reqparse
class HostParse:
model = None
request_parse = None
uniq_fields = ("public_ip", "minion_id")
def init_parse(self):
self.request_parse = reqparse.RequestParser()
# 创建时必须,修改时可选
# self.request_parse.add_argument("public_ip", type=str, required=True,
# help='not public_ip provided', location='json')
self.request_parse.add_argument("private_ip", required=False, type=str, location='json')
self.request_parse.add_argument("minion_id", required=False, type=str, location='json')
self.request_parse.add_argument("weights", required=False, type=int, location='json')
self.request_parse.add_argument("cpu_num", required=False, type=int, location='json')
self.request_parse.add_argument("cpu_core", required=False, type=int, location='json')
self.request_parse.add_argument("memory", required=False, type=int, location='json')
self.request_parse.add_argument("tags", required=False, type=list, location='json')
self.request_parse.add_argument("labels", required=False, type=dict, location='json')
def validate_fields(self, args: dict, create=True) -> dict:
if create:
args["created"] = datetime.datetime.now()
else:
if "created" in args:
args.pop("created")
return args
class DatabaseServerParse:
request_parse = None
# 给子类用的唯一字段,用于校验
uniq_fields = ("name", "host", "domain")
def init_parse(self):
self.request_parse = reqparse.RequestParser()
self.request_parse.add_argument("domain", required=False, type=str, location='json')
self.request_parse.add_argument("data", required=False, type=dict, location='json')
self.request_parse.add_argument("tags", required=False, type=list, location='json')
self.request_parse.add_argument("labels", required=False, type=dict, location='json')
class MiddlewareParse:
request_parse = None
# 给子类用的唯一字段,用于校验
uniq_fields = ("name", "host")
def init_parse(self):
self.request_parse = reqparse.RequestParser()
# self.request_parse.add_argument("domain", required=False, type=str, location='json')
self.request_parse.add_argument("data", required=False, type=dict, location='json')
self.request_parse.add_argument("tags", required=False, type=list, location='json')
self.request_parse.add_argument("labels", required=False, type=dict, location='json')

View File

@ -5,39 +5,10 @@ from flask_restful import reqparse
from models.asset import fields as assetField from models.asset import fields as assetField
from models.asset import models as assetModel from models.asset import models as assetModel
from common.views import ListCreateViewSet, DetailViewSet from common.views import ListCreateViewSet, DetailViewSet
from controller.asset import parsers
class HostParse: class HostViews(parsers.HostParse, ListCreateViewSet):
model = None
request_parse = None
uniq_fields = ("public_ip", "minion_id")
def init_parse(self):
self.request_parse = reqparse.RequestParser()
# 创建时必须,修改时可选
# self.request_parse.add_argument("public_ip", type=str, required=True,
# help='not public_ip provided', location='json')
self.request_parse.add_argument("private_ip", required=False, type=str, location='json')
self.request_parse.add_argument("minion_id", required=False, type=str, location='json')
self.request_parse.add_argument("weights", required=False, type=int, location='json')
self.request_parse.add_argument("cpu_num", required=False, type=int, location='json')
self.request_parse.add_argument("cpu_core", required=False, type=int, location='json')
self.request_parse.add_argument("memory", required=False, type=int, location='json')
self.request_parse.add_argument("tags", required=False, type=list, location='json')
self.request_parse.add_argument("labels", required=False, type=dict, location='json')
def pre_create(self, args):
# 标记创建时间
args["created"] = datetime.datetime.now()
return args
def pre_update(self, obj, args):
if "created" in args:
args.pop("created")
return args
class HostViews(HostParse, ListCreateViewSet):
model = assetModel.Host model = assetModel.Host
fields = assetField.HostFields fields = assetField.HostFields
@ -48,7 +19,7 @@ class HostViews(HostParse, ListCreateViewSet):
super(HostViews, self).__init__() super(HostViews, self).__init__()
class HostDetailViews(HostParse, DetailViewSet): class HostDetailViews(parsers.HostParse, DetailViewSet):
model = assetModel.Host model = assetModel.Host
fields = assetField.HostFields fields = assetField.HostFields
@ -58,20 +29,7 @@ class HostDetailViews(HostParse, DetailViewSet):
super(HostDetailViews, self).__init__() super(HostDetailViews, self).__init__()
class DatabaseServerParse: class MySQLInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet):
request_parse = None
# 给子类用的唯一字段,用于校验
uniq_fields = ("name", "host", "domain")
def init_parse(self):
self.request_parse = reqparse.RequestParser()
self.request_parse.add_argument("domain", required=False, type=str, location='json')
self.request_parse.add_argument("data", required=False, type=dict, location='json')
self.request_parse.add_argument("tags", required=False, type=list, location='json')
self.request_parse.add_argument("labels", required=False, type=dict, location='json')
class MySQLInstanceViews(DatabaseServerParse, ListCreateViewSet):
model = assetModel.MySQLInstance model = assetModel.MySQLInstance
fields = assetField.MySQLInstanceFields fields = assetField.MySQLInstanceFields
@ -90,7 +48,7 @@ class MySQLInstanceViews(DatabaseServerParse, ListCreateViewSet):
super(MySQLInstanceViews, self).__init__() super(MySQLInstanceViews, self).__init__()
class MySQLInstanceDetail(DatabaseServerParse, DetailViewSet): class MySQLInstanceDetail(parsers.DatabaseServerParse, DetailViewSet):
model = assetModel.MySQLInstance model = assetModel.MySQLInstance
fields = assetField.MySQLInstanceFields fields = assetField.MySQLInstanceFields
@ -110,7 +68,7 @@ class MySQLInstanceDetail(DatabaseServerParse, DetailViewSet):
super(MySQLInstanceDetail, self).__init__() super(MySQLInstanceDetail, self).__init__()
class RedisInstanceViews(DatabaseServerParse, ListCreateViewSet): class RedisInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet):
model = assetModel.RedisInstance model = assetModel.RedisInstance
fields = assetField.RedisInstanceFields fields = assetField.RedisInstanceFields
@ -127,7 +85,7 @@ class RedisInstanceViews(DatabaseServerParse, ListCreateViewSet):
super(RedisInstanceViews, self).__init__() super(RedisInstanceViews, self).__init__()
class RedisInstanceDetail(DatabaseServerParse, DetailViewSet): class RedisInstanceDetail(parsers.DatabaseServerParse, DetailViewSet):
model = assetModel.RedisInstance model = assetModel.RedisInstance
fields = assetField.RedisInstanceFields fields = assetField.RedisInstanceFields
@ -144,20 +102,7 @@ class RedisInstanceDetail(DatabaseServerParse, DetailViewSet):
super().__init__() super().__init__()
class MiddlewareParse: class NginxInstanceViews(parsers.MiddlewareParse, ListCreateViewSet):
request_parse = None
# 给子类用的唯一字段,用于校验
uniq_fields = ("name", "host")
def init_parse(self):
self.request_parse = reqparse.RequestParser()
# self.request_parse.add_argument("domain", required=False, type=str, location='json')
self.request_parse.add_argument("data", required=False, type=dict, location='json')
self.request_parse.add_argument("tags", required=False, type=list, location='json')
self.request_parse.add_argument("labels", required=False, type=dict, location='json')
class NginxInstanceViews(MiddlewareParse, ListCreateViewSet):
model = assetModel.NginxInstance model = assetModel.NginxInstance
fields = assetField.NginxInstanceFields fields = assetField.NginxInstanceFields
@ -171,7 +116,7 @@ class NginxInstanceViews(MiddlewareParse, ListCreateViewSet):
super().__init__() super().__init__()
class NginxInstanceDetail(MiddlewareParse, DetailViewSet): class NginxInstanceDetail(parsers.MiddlewareParse, DetailViewSet):
model = assetModel.NginxInstance model = assetModel.NginxInstance
fields = assetField.NginxInstanceFields fields = assetField.NginxInstanceFields

View File

@ -0,0 +1,77 @@
"""
视图类继承的解析扩展子类须继承自 `common.views.ModelViewBase`
有使用到该父类的部分方法主要抽出部分请求参数定义解析后的进一步校验 `post` `put 请求方法使用
"""
from flask import current_app as app
from flask_restful import reqparse
from models.asset import models as assetModel
from models.project import models as projectModel
from common.utils import abort_response
class ChannelParse:
request_parse = None
def init_parse(self):
self.request_parse = reqparse.RequestParser()
self.request_parse.add_argument("name", type=str, location='json', default="")
self.request_parse.add_argument("version", type=dict, location='json')
self.request_parse.add_argument("repository", type=str, location='json', default="")
self.request_parse.add_argument("branch", type=dict, location='json', default="")
self.request_parse.add_argument("data", type=dict, location='json')
self.request_parse.add_argument("tags", type=list, location='json')
self.request_parse.add_argument("labels", type=dict, location='json')
def validate_fields(self, args: dict, create=True) -> dict:
if create:
self.validate_relation_pk(args, projectModel.Project, key="project_id")
spid = str(args.get("spid"))
if not spid.isalnum() or len(spid) != 3:
abort_response(400, 1001, msg="spid 不合法")
else:
if "project_id" in args:
self.validate_relation_pk(args, projectModel.Project, key="project_id")
if "spid" in args:
spid = str(args.get("spid"))
if not spid.isalnum() or len(spid) != 3:
abort_response(400, 1001, msg="spid 不合法")
return args
class ServerParse:
"""views 中使用的解析、校验方法"""
request_parse = None
def init_parse(self):
self.request_parse = reqparse.RequestParser()
self.request_parse.add_argument("data", type=dict, location='json')
self.request_parse.add_argument("tags", type=list, location='json')
self.request_parse.add_argument("labels", type=dict, location='json')
self.request_parse.add_argument("host_id", type=str, location='json')
self.request_parse.add_argument("domain", type=str, location='json', default="")
# status是枚举
self.request_parse.add_argument("status", choices=projectModel.Server.STATUS.keys(),
type=str, location='json', default="running")
self.request_parse.add_argument("port", type=int, location='json', default=0)
self.request_parse.add_argument("version", type=dict, location='json')
self.request_parse.add_argument("weight", type=int, location='json')
self.request_parse.add_argument("slot", type=int, location='json')
def validate_fields(self, args: dict, create=True) -> dict:
# if not create:
# if "channel_id" in args:
# self.validate_relation_pk(args, projectModel.Channel, key="channel_id")
# if "host_id" in args:
# self.validate_relation_pk(args, assetModel.Host, key="channel_id")
# else:
# self.validate_relation_pk(args, projectModel.Channel, key="channel_id")
version = args.get("version")
if version:
args["version"] = projectModel.Version(**version)
return args

View File

@ -1,8 +1,12 @@
"""数据模型常规增删查改的接口""" """数据模型常规增删查改的接口"""
from flask import current_app as app
from models.project import fields from models.project import fields
from models.asset import models as assetModel
from models.project.models import Project, Channel, Server from models.project.models import Project, Channel, Server
from common.views import ListMixin, CreateMixin, ListCreateViewSet, DetailViewSet from common.views import ListMixin, CreateMixin, ListCreateViewSet, DetailViewSet
from controller.project import parsers
class ProjectViews(ListMixin, CreateMixin): class ProjectViews(ListMixin, CreateMixin):
@ -16,21 +20,49 @@ class ProjectDetailViews(DetailViewSet):
fields = fields.ProjectFields fields = fields.ProjectFields
class ChannelViews(ListCreateViewSet): class ChannelViews(parsers.ChannelParse, ListCreateViewSet):
model = Channel model = Channel
fields = fields.ChannelFields fields = fields.ChannelFields
uniq_fields = (("spid", "project_id"),)
def __init__(self):
self.init_parse()
self.request_parse.add_argument("spid", type=str, location='json', required=True)
self.request_parse.add_argument("project_id", type=str, location='json', required=True)
super().__init__()
class ChannelDetailViews(DetailViewSet): class ChannelDetailViews(parsers.ChannelParse, DetailViewSet):
model = Channel model = Channel
fields = fields.ChannelFields fields = fields.ChannelFields
uniq_fields = (("spid", "project_id"),)
def __init__(self):
self.init_parse()
self.request_parse.add_argument("spid", type=str, location='json', required=False)
self.request_parse.add_argument("project_id", type=str, location='json', required=False)
super().__init__()
class ServerViews(ListCreateViewSet): class ServerViews(parsers.ServerParse, ListCreateViewSet):
model = Server model = Server
fields = fields.ServerFields fields = fields.ServerFields
uniq_fields = (("num", "channel_id"),)
relation_fields = (("host_id", assetModel.Host, ""), ("channel_id", Channel, True))
def __init__(self):
self.init_parse()
self.request_parse.add_argument("num", type=int, location='json', required=True)
self.request_parse.add_argument("channel_id", type=str, location='json', required=True)
class ServerDetailView(DetailViewSet): class ServerDetailView(parsers.ServerParse, DetailViewSet):
model = Server model = Server
fields = fields.ServerFields fields = fields.ServerFields
uniq_fields = (("num", "channel_id"),)
relation_fields = (("host_id", assetModel.Host, ""), ("channel_id", Channel, True))
def __init__(self):
self.init_parse()
self.request_parse.add_argument("num", type=int, location='json')
self.request_parse.add_argument("channel_id", type=str, location='json')

View File

@ -18,6 +18,7 @@ ProjectFields = {
"kw": fields.String, "kw": fields.String,
"desc": fields.String, "desc": fields.String,
"repository": fields.String, "repository": fields.String,
"data": fields.Raw,
"tags": fields.List(fields.String), "tags": fields.List(fields.String),
"labels": fields.Raw, "labels": fields.Raw,
} }
@ -48,6 +49,7 @@ ChannelFields = {
"version": fields.Nested(VersionFields), "version": fields.Nested(VersionFields),
"repository": fields.String, "repository": fields.String,
"branch": fields.String, "branch": fields.String,
"data": fields.Raw,
"tags": fields.List(fields.String), # List 必须指明类型 "tags": fields.List(fields.String), # List 必须指明类型
"labels": fields.Raw, "labels": fields.Raw,
} }
@ -70,6 +72,7 @@ ServerFields = {
"version": fields.Nested(VersionFields), "version": fields.Nested(VersionFields),
"weight": fields.Integer, "weight": fields.Integer,
"slot": fields.Integer, "slot": fields.Integer,
"data": fields.Raw,
"tags": fields.List(fields.String), "tags": fields.List(fields.String),
"labels": fields.Raw, "labels": fields.Raw,
} }
@ -94,6 +97,7 @@ AgentServerFields = {
"status": fields.String, "status": fields.String,
"weight": fields.Integer, "weight": fields.Integer,
"slot": fields.Integer, "slot": fields.Integer,
"data": fields.Raw,
# "tags": serializer.List(serializer.String), # "tags": serializer.List(serializer.String),
# "labels": serializer.Raw, # "labels": serializer.Raw,
} }

View File

@ -21,6 +21,7 @@ class Project(DocumentBase):
desc = mongo.StringField(max_length=256, default="") desc = mongo.StringField(max_length=256, default="")
# 项目的代码仓库地址 # 项目的代码仓库地址
repository = mongo.StringField("仓库", max_length=256, default="", help_text="仓库地址", null=True) repository = mongo.StringField("仓库", max_length=256, default="", help_text="仓库地址", null=True)
data = mongo.DictField(default={})
# 标记和标签 # 标记和标签
tags = mongo.ListField(mongo.StringField(), default=list) # tags 默认是空列表 tags = mongo.ListField(mongo.StringField(), default=list) # tags 默认是空列表
labels = mongo.DictField(default=dict) labels = mongo.DictField(default=dict)
@ -51,8 +52,8 @@ class Channel(DocumentBase):
# 项目的代码仓库地址 # 项目的代码仓库地址
repository = mongo.StringField(max_length=256, default="") repository = mongo.StringField(max_length=256, default="")
branch = mongo.StringField(max_length=32) branch = mongo.StringField(max_length=32, default="")
data = mongo.DictField(default={})
# 标记和标签 # 标记和标签
tags = mongo.ListField(mongo.StringField(), default=list) # tags 默认是空列表 tags = mongo.ListField(mongo.StringField(), default=list) # tags 默认是空列表
labels = mongo.DictField(default=dict) labels = mongo.DictField(default=dict)
@ -94,7 +95,7 @@ class Server(DocumentBase):
num = mongo.IntField(required=True, unique_with="channel_id") num = mongo.IntField(required=True, unique_with="channel_id")
# 关联channel表保存是一个channel._id的hex字符串 # 关联channel表保存是一个channel._id的hex字符串
channel_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string) channel_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string)
status = mongo.StringField(max_length=12, choices=STATUS.keys(), required=True, default="running") status = mongo.StringField(max_length=12, choices=STATUS.keys(), required=False, default="running")
# 机器字段TODO 先允许为空 # 机器字段TODO 先允许为空
host_id = mongo.StringField(max_length=128, null=True, validation=is_hex_string) host_id = mongo.StringField(max_length=128, null=True, validation=is_hex_string)
@ -103,7 +104,7 @@ class Server(DocumentBase):
version = mongo.EmbeddedDocumentField(Version) version = mongo.EmbeddedDocumentField(Version)
weight = mongo.IntField(default=1) weight = mongo.IntField(default=1)
slot = mongo.IntField(default=0) slot = mongo.IntField(default=0)
data = mongo.DictField(default={})
# 标记和标签 # 标记和标签
tags = mongo.ListField(mongo.StringField(), default=list) # tags 默认是空列表 tags = mongo.ListField(mongo.StringField(), default=list) # tags 默认是空列表
labels = mongo.DictField(default=dict) labels = mongo.DictField(default=dict)