model对象关联信息检查封装,抽出输入参数解析方法

This commit is contained in:
chenzuoqing 2021-12-08 20:30:07 +08:00
parent 4e565ed4d8
commit 34511da090
7 changed files with 250 additions and 99 deletions

View File

@ -1,6 +1,6 @@
from typing import List, Dict
from flask import request
from flask import request, current_app as app
from flask_restful import abort, Resource, marshal, fields, reqparse
from common.utils import abort_response
@ -44,8 +44,11 @@ class ModelViewBase(Resource):
# 是否分页
paging = True
# 需要检查重复的字段
# TODO 可能还需要多个字段组合的判断并设默认值dict可能比较合适
# ["field_name", ("filed_name1", "filed_name2"), {"f1": "default_val", "f2": "default_val"}]
uniq_fields = []
# 需要检查是否存在的关联字段, required为false时不判断空值
# [("host_id", models.Host, required), ("xx_id", models.xx, False)]
relation_fields = []
# method_decorators 可以用作权限校验,参考 `flask_restful.Resource`
# 参数是装饰器列表(未指定方法将会对所有请求方法生效)
@ -116,6 +119,42 @@ class ModelViewBase(Resource):
def validate_fields(self, args: dict, create=True) -> dict:
return args
@staticmethod
def validate_relation_pk(val, model, key, pk="id", need_hex=True):
"""
校验关联主键适合非mongo外键字段关联但保存有关联对象的信息的
:param val: 解析过的参数字典
:param model: mongo 模型
:param key: 检查的字段
:param pk: 当前模型保存的key对应关联模型中的字段名默认 `id`
:param need_hex: 是否检查值为16进制字符串如果是 mongo `id` 需要是 16 进制字符串
:return:
"""
try:
if need_hex:
assert int(val, 16), f"{key} 不合法"
assert model.objects(**{pk: val}).first(), f"{key} 不存在"
except (AssertionError, ValueError, TypeError):
app.logger.exception(f"{key} 异常")
abort_response(400, 1002, msg=f"{key} 不存在")
def validate_relation_fields(self, args):
"""
注意请在解析 args 时对必须的字段设置为 required=True否则为 None 不判断
校验关联字段请设置 `relation_fields`(("host_id", models.Host, required), )
:param args: 请求参数
:return:
"""
if not self.relation_fields:
return
for field, model, required in self.relation_fields:
val = args.get(field)
# 判断 required: 参数有提交值才校验,未在参数中或零值不校验
if not required and not val:
continue
self.validate_relation_pk(val, model, field)
class ListMixin(ModelViewBase):
@ -148,6 +187,9 @@ class CreateMixin(ModelViewBase):
args = self.request_parse.parse_args()
validated_data = self.validate_fields(args, create=True)
# 校验关联字段
self.validate_relation_fields(validated_data)
# 创建前钩子,
validated_data = self.pre_create(validated_data)
@ -155,31 +197,17 @@ class CreateMixin(ModelViewBase):
self.validate_uniq_fields(args)
# 保存对象
try:
obj = self.model(**validated_data)
obj.save()
except Exception as e:
app.logger.exception(f"{self.model} 创建对象失败! data={args}")
abort_response(500, 1500, msg=f"保存对象失败!{str(e)}")
return
# 返回创建信息
return marshal(obj, self.fields)
# TODO 模型校验
# class CreateFormMixin(ModelViewBase):
#
# def pre_create(self, args):
# """创建前钩子,接收参数,可以对参数进行处理,最后保存此方法返回的数据"""
# return args
#
# def post(self):
# """创建对象"""
# self.form = self.form if self.form else model_form(self.model)
# print("===> json", request.json)
# form = self.form(request.json, meta={'csrf': False})
# if not form.validate():
# print(form.errors)
# return {"msg": "error"}
# return {"msg": "ok"}
class RetrieveMixin(ModelViewBase):
def get(self, pk):
@ -205,16 +233,23 @@ class UpdateMixin(ModelViewBase):
# 解析参数
args = self.request_parse.parse_args()
validated_data = self.validate_fields(args, create=False)
# 校验关联字段
self.validate_relation_fields(validated_data)
validated_data = self.pre_update(obj, validated_data)
# 检查重复检查提交的唯一字段值是否有存在的对象与当前对象id不同认为异常
self.validate_uniq_fields(args, obj=obj)
# 更新对象、保存
try:
obj.update(**validated_data)
obj.save()
# 重新读取数据
obj.reload()
except Exception as e:
app.logger.exception(f"{self.model} 保存对象失败pk={pk} data={args}")
abort_response(500, 1500, msg=f"保存对象失败!{str(e)}")
return
return marshal(obj, self.fields)

View File

@ -0,0 +1,57 @@
import datetime
from flask_restful import reqparse
class HostParse:
model = None
request_parse = None
uniq_fields = ("public_ip", "minion_id")
def init_parse(self):
self.request_parse = reqparse.RequestParser()
# 创建时必须,修改时可选
# self.request_parse.add_argument("public_ip", type=str, required=True,
# help='not public_ip provided', location='json')
self.request_parse.add_argument("private_ip", required=False, type=str, location='json')
self.request_parse.add_argument("minion_id", required=False, type=str, location='json')
self.request_parse.add_argument("weights", required=False, type=int, location='json')
self.request_parse.add_argument("cpu_num", required=False, type=int, location='json')
self.request_parse.add_argument("cpu_core", required=False, type=int, location='json')
self.request_parse.add_argument("memory", required=False, type=int, location='json')
self.request_parse.add_argument("tags", required=False, type=list, location='json')
self.request_parse.add_argument("labels", required=False, type=dict, location='json')
def validate_fields(self, args: dict, create=True) -> dict:
if create:
args["created"] = datetime.datetime.now()
else:
if "created" in args:
args.pop("created")
return args
class DatabaseServerParse:
request_parse = None
# 给子类用的唯一字段,用于校验
uniq_fields = ("name", "host", "domain")
def init_parse(self):
self.request_parse = reqparse.RequestParser()
self.request_parse.add_argument("domain", required=False, type=str, location='json')
self.request_parse.add_argument("data", required=False, type=dict, location='json')
self.request_parse.add_argument("tags", required=False, type=list, location='json')
self.request_parse.add_argument("labels", required=False, type=dict, location='json')
class MiddlewareParse:
request_parse = None
# 给子类用的唯一字段,用于校验
uniq_fields = ("name", "host")
def init_parse(self):
self.request_parse = reqparse.RequestParser()
# self.request_parse.add_argument("domain", required=False, type=str, location='json')
self.request_parse.add_argument("data", required=False, type=dict, location='json')
self.request_parse.add_argument("tags", required=False, type=list, location='json')
self.request_parse.add_argument("labels", required=False, type=dict, location='json')

View File

@ -5,39 +5,10 @@ from flask_restful import reqparse
from models.asset import fields as assetField
from models.asset import models as assetModel
from common.views import ListCreateViewSet, DetailViewSet
from controller.asset import parsers
class HostParse:
model = None
request_parse = None
uniq_fields = ("public_ip", "minion_id")
def init_parse(self):
self.request_parse = reqparse.RequestParser()
# 创建时必须,修改时可选
# self.request_parse.add_argument("public_ip", type=str, required=True,
# help='not public_ip provided', location='json')
self.request_parse.add_argument("private_ip", required=False, type=str, location='json')
self.request_parse.add_argument("minion_id", required=False, type=str, location='json')
self.request_parse.add_argument("weights", required=False, type=int, location='json')
self.request_parse.add_argument("cpu_num", required=False, type=int, location='json')
self.request_parse.add_argument("cpu_core", required=False, type=int, location='json')
self.request_parse.add_argument("memory", required=False, type=int, location='json')
self.request_parse.add_argument("tags", required=False, type=list, location='json')
self.request_parse.add_argument("labels", required=False, type=dict, location='json')
def pre_create(self, args):
# 标记创建时间
args["created"] = datetime.datetime.now()
return args
def pre_update(self, obj, args):
if "created" in args:
args.pop("created")
return args
class HostViews(HostParse, ListCreateViewSet):
class HostViews(parsers.HostParse, ListCreateViewSet):
model = assetModel.Host
fields = assetField.HostFields
@ -48,7 +19,7 @@ class HostViews(HostParse, ListCreateViewSet):
super(HostViews, self).__init__()
class HostDetailViews(HostParse, DetailViewSet):
class HostDetailViews(parsers.HostParse, DetailViewSet):
model = assetModel.Host
fields = assetField.HostFields
@ -58,20 +29,7 @@ class HostDetailViews(HostParse, DetailViewSet):
super(HostDetailViews, self).__init__()
class DatabaseServerParse:
request_parse = None
# 给子类用的唯一字段,用于校验
uniq_fields = ("name", "host", "domain")
def init_parse(self):
self.request_parse = reqparse.RequestParser()
self.request_parse.add_argument("domain", required=False, type=str, location='json')
self.request_parse.add_argument("data", required=False, type=dict, location='json')
self.request_parse.add_argument("tags", required=False, type=list, location='json')
self.request_parse.add_argument("labels", required=False, type=dict, location='json')
class MySQLInstanceViews(DatabaseServerParse, ListCreateViewSet):
class MySQLInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet):
model = assetModel.MySQLInstance
fields = assetField.MySQLInstanceFields
@ -90,7 +48,7 @@ class MySQLInstanceViews(DatabaseServerParse, ListCreateViewSet):
super(MySQLInstanceViews, self).__init__()
class MySQLInstanceDetail(DatabaseServerParse, DetailViewSet):
class MySQLInstanceDetail(parsers.DatabaseServerParse, DetailViewSet):
model = assetModel.MySQLInstance
fields = assetField.MySQLInstanceFields
@ -110,7 +68,7 @@ class MySQLInstanceDetail(DatabaseServerParse, DetailViewSet):
super(MySQLInstanceDetail, self).__init__()
class RedisInstanceViews(DatabaseServerParse, ListCreateViewSet):
class RedisInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet):
model = assetModel.RedisInstance
fields = assetField.RedisInstanceFields
@ -127,7 +85,7 @@ class RedisInstanceViews(DatabaseServerParse, ListCreateViewSet):
super(RedisInstanceViews, self).__init__()
class RedisInstanceDetail(DatabaseServerParse, DetailViewSet):
class RedisInstanceDetail(parsers.DatabaseServerParse, DetailViewSet):
model = assetModel.RedisInstance
fields = assetField.RedisInstanceFields
@ -144,20 +102,7 @@ class RedisInstanceDetail(DatabaseServerParse, DetailViewSet):
super().__init__()
class MiddlewareParse:
request_parse = None
# 给子类用的唯一字段,用于校验
uniq_fields = ("name", "host")
def init_parse(self):
self.request_parse = reqparse.RequestParser()
# self.request_parse.add_argument("domain", required=False, type=str, location='json')
self.request_parse.add_argument("data", required=False, type=dict, location='json')
self.request_parse.add_argument("tags", required=False, type=list, location='json')
self.request_parse.add_argument("labels", required=False, type=dict, location='json')
class NginxInstanceViews(MiddlewareParse, ListCreateViewSet):
class NginxInstanceViews(parsers.MiddlewareParse, ListCreateViewSet):
model = assetModel.NginxInstance
fields = assetField.NginxInstanceFields
@ -171,7 +116,7 @@ class NginxInstanceViews(MiddlewareParse, ListCreateViewSet):
super().__init__()
class NginxInstanceDetail(MiddlewareParse, DetailViewSet):
class NginxInstanceDetail(parsers.MiddlewareParse, DetailViewSet):
model = assetModel.NginxInstance
fields = assetField.NginxInstanceFields

View File

@ -0,0 +1,77 @@
"""
视图类继承的解析扩展子类须继承自 `common.views.ModelViewBase`
有使用到该父类的部分方法主要抽出部分请求参数定义解析后的进一步校验 `post` `put 请求方法使用
"""
from flask import current_app as app
from flask_restful import reqparse
from models.asset import models as assetModel
from models.project import models as projectModel
from common.utils import abort_response
class ChannelParse:
request_parse = None
def init_parse(self):
self.request_parse = reqparse.RequestParser()
self.request_parse.add_argument("name", type=str, location='json', default="")
self.request_parse.add_argument("version", type=dict, location='json')
self.request_parse.add_argument("repository", type=str, location='json', default="")
self.request_parse.add_argument("branch", type=dict, location='json', default="")
self.request_parse.add_argument("data", type=dict, location='json')
self.request_parse.add_argument("tags", type=list, location='json')
self.request_parse.add_argument("labels", type=dict, location='json')
def validate_fields(self, args: dict, create=True) -> dict:
if create:
self.validate_relation_pk(args, projectModel.Project, key="project_id")
spid = str(args.get("spid"))
if not spid.isalnum() or len(spid) != 3:
abort_response(400, 1001, msg="spid 不合法")
else:
if "project_id" in args:
self.validate_relation_pk(args, projectModel.Project, key="project_id")
if "spid" in args:
spid = str(args.get("spid"))
if not spid.isalnum() or len(spid) != 3:
abort_response(400, 1001, msg="spid 不合法")
return args
class ServerParse:
"""views 中使用的解析、校验方法"""
request_parse = None
def init_parse(self):
self.request_parse = reqparse.RequestParser()
self.request_parse.add_argument("data", type=dict, location='json')
self.request_parse.add_argument("tags", type=list, location='json')
self.request_parse.add_argument("labels", type=dict, location='json')
self.request_parse.add_argument("host_id", type=str, location='json')
self.request_parse.add_argument("domain", type=str, location='json', default="")
# status是枚举
self.request_parse.add_argument("status", choices=projectModel.Server.STATUS.keys(),
type=str, location='json', default="running")
self.request_parse.add_argument("port", type=int, location='json', default=0)
self.request_parse.add_argument("version", type=dict, location='json')
self.request_parse.add_argument("weight", type=int, location='json')
self.request_parse.add_argument("slot", type=int, location='json')
def validate_fields(self, args: dict, create=True) -> dict:
# if not create:
# if "channel_id" in args:
# self.validate_relation_pk(args, projectModel.Channel, key="channel_id")
# if "host_id" in args:
# self.validate_relation_pk(args, assetModel.Host, key="channel_id")
# else:
# self.validate_relation_pk(args, projectModel.Channel, key="channel_id")
version = args.get("version")
if version:
args["version"] = projectModel.Version(**version)
return args

View File

@ -1,8 +1,12 @@
"""数据模型常规增删查改的接口"""
from flask import current_app as app
from models.project import fields
from models.asset import models as assetModel
from models.project.models import Project, Channel, Server
from common.views import ListMixin, CreateMixin, ListCreateViewSet, DetailViewSet
from controller.project import parsers
class ProjectViews(ListMixin, CreateMixin):
@ -16,21 +20,49 @@ class ProjectDetailViews(DetailViewSet):
fields = fields.ProjectFields
class ChannelViews(ListCreateViewSet):
class ChannelViews(parsers.ChannelParse, ListCreateViewSet):
model = Channel
fields = fields.ChannelFields
uniq_fields = (("spid", "project_id"),)
def __init__(self):
self.init_parse()
self.request_parse.add_argument("spid", type=str, location='json', required=True)
self.request_parse.add_argument("project_id", type=str, location='json', required=True)
super().__init__()
class ChannelDetailViews(DetailViewSet):
class ChannelDetailViews(parsers.ChannelParse, DetailViewSet):
model = Channel
fields = fields.ChannelFields
uniq_fields = (("spid", "project_id"),)
def __init__(self):
self.init_parse()
self.request_parse.add_argument("spid", type=str, location='json', required=False)
self.request_parse.add_argument("project_id", type=str, location='json', required=False)
super().__init__()
class ServerViews(ListCreateViewSet):
class ServerViews(parsers.ServerParse, ListCreateViewSet):
model = Server
fields = fields.ServerFields
uniq_fields = (("num", "channel_id"),)
relation_fields = (("host_id", assetModel.Host, ""), ("channel_id", Channel, True))
def __init__(self):
self.init_parse()
self.request_parse.add_argument("num", type=int, location='json', required=True)
self.request_parse.add_argument("channel_id", type=str, location='json', required=True)
class ServerDetailView(DetailViewSet):
class ServerDetailView(parsers.ServerParse, DetailViewSet):
model = Server
fields = fields.ServerFields
uniq_fields = (("num", "channel_id"),)
relation_fields = (("host_id", assetModel.Host, ""), ("channel_id", Channel, True))
def __init__(self):
self.init_parse()
self.request_parse.add_argument("num", type=int, location='json')
self.request_parse.add_argument("channel_id", type=str, location='json')

View File

@ -18,6 +18,7 @@ ProjectFields = {
"kw": fields.String,
"desc": fields.String,
"repository": fields.String,
"data": fields.Raw,
"tags": fields.List(fields.String),
"labels": fields.Raw,
}
@ -48,6 +49,7 @@ ChannelFields = {
"version": fields.Nested(VersionFields),
"repository": fields.String,
"branch": fields.String,
"data": fields.Raw,
"tags": fields.List(fields.String), # List 必须指明类型
"labels": fields.Raw,
}
@ -70,6 +72,7 @@ ServerFields = {
"version": fields.Nested(VersionFields),
"weight": fields.Integer,
"slot": fields.Integer,
"data": fields.Raw,
"tags": fields.List(fields.String),
"labels": fields.Raw,
}
@ -94,6 +97,7 @@ AgentServerFields = {
"status": fields.String,
"weight": fields.Integer,
"slot": fields.Integer,
"data": fields.Raw,
# "tags": serializer.List(serializer.String),
# "labels": serializer.Raw,
}

View File

@ -21,6 +21,7 @@ class Project(DocumentBase):
desc = mongo.StringField(max_length=256, default="")
# 项目的代码仓库地址
repository = mongo.StringField("仓库", max_length=256, default="", help_text="仓库地址", null=True)
data = mongo.DictField(default={})
# 标记和标签
tags = mongo.ListField(mongo.StringField(), default=list) # tags 默认是空列表
labels = mongo.DictField(default=dict)
@ -51,8 +52,8 @@ class Channel(DocumentBase):
# 项目的代码仓库地址
repository = mongo.StringField(max_length=256, default="")
branch = mongo.StringField(max_length=32)
branch = mongo.StringField(max_length=32, default="")
data = mongo.DictField(default={})
# 标记和标签
tags = mongo.ListField(mongo.StringField(), default=list) # tags 默认是空列表
labels = mongo.DictField(default=dict)
@ -94,7 +95,7 @@ class Server(DocumentBase):
num = mongo.IntField(required=True, unique_with="channel_id")
# 关联channel表保存是一个channel._id的hex字符串
channel_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string)
status = mongo.StringField(max_length=12, choices=STATUS.keys(), required=True, default="running")
status = mongo.StringField(max_length=12, choices=STATUS.keys(), required=False, default="running")
# 机器字段TODO 先允许为空
host_id = mongo.StringField(max_length=128, null=True, validation=is_hex_string)
@ -103,7 +104,7 @@ class Server(DocumentBase):
version = mongo.EmbeddedDocumentField(Version)
weight = mongo.IntField(default=1)
slot = mongo.IntField(default=0)
data = mongo.DictField(default={})
# 标记和标签
tags = mongo.ListField(mongo.StringField(), default=list) # tags 默认是空列表
labels = mongo.DictField(default=dict)