diff --git a/src/common/views.py b/src/common/views.py index 6c7300d..2181151 100644 --- a/src/common/views.py +++ b/src/common/views.py @@ -1,6 +1,6 @@ from typing import List, Dict -from flask import request +from flask import request, current_app as app from flask_restful import abort, Resource, marshal, fields, reqparse from common.utils import abort_response @@ -44,8 +44,11 @@ class ModelViewBase(Resource): # 是否分页 paging = True # 需要检查重复的字段 - # TODO 可能还需要多个字段组合的判断,并设默认值,dict可能比较合适 + # ["field_name", ("filed_name1", "filed_name2"), {"f1": "default_val", "f2": "default_val"}] uniq_fields = [] + # 需要检查是否存在的关联字段, required为false时不判断空值 + # [("host_id", models.Host, required), ("xx_id", models.xx, False)] + relation_fields = [] # method_decorators 可以用作权限校验,参考 `flask_restful.Resource` # 参数是装饰器列表(未指定方法将会对所有请求方法生效) @@ -116,6 +119,42 @@ class ModelViewBase(Resource): def validate_fields(self, args: dict, create=True) -> dict: return args + @staticmethod + def validate_relation_pk(val, model, key, pk="id", need_hex=True): + """ + 校验关联主键,适合非mongo外键字段关联,但保存有关联对象的信息的 + :param val: 解析过的参数字典 + :param model: mongo 模型 + :param key: 检查的字段 + :param pk: 当前模型保存的key对应关联模型中的字段名,默认 `id` + :param need_hex: 是否检查值为16进制字符串,如果是 mongo 的 `id` 需要是 16 进制字符串 + :return: + """ + try: + if need_hex: + assert int(val, 16), f"{key} 不合法" + assert model.objects(**{pk: val}).first(), f"{key} 不存在" + except (AssertionError, ValueError, TypeError): + app.logger.exception(f"{key} 异常") + abort_response(400, 1002, msg=f"{key} 不存在") + + def validate_relation_fields(self, args): + """ + 注意请在解析 args 时对必须的字段设置为 required=True,否则为 None 不判断 + 校验关联字段,请设置 `relation_fields`,如:(("host_id", models.Host, required), ) + :param args: 请求参数 + :return: + """ + if not self.relation_fields: + return + + for field, model, required in self.relation_fields: + val = args.get(field) + # 判断 required: 参数有提交值才校验,未在参数中或零值不校验 + if not required and not val: + continue + self.validate_relation_pk(val, model, field) + class ListMixin(ModelViewBase): @@ -148,6 +187,9 @@ class CreateMixin(ModelViewBase): args = self.request_parse.parse_args() validated_data = self.validate_fields(args, create=True) + # 校验关联字段 + self.validate_relation_fields(validated_data) + # 创建前钩子, validated_data = self.pre_create(validated_data) @@ -155,31 +197,17 @@ class CreateMixin(ModelViewBase): self.validate_uniq_fields(args) # 保存对象 - obj = self.model(**validated_data) - obj.save() - + try: + obj = self.model(**validated_data) + obj.save() + except Exception as e: + app.logger.exception(f"{self.model} 创建对象失败! data={args}") + abort_response(500, 1500, msg=f"保存对象失败!{str(e)}") + return # 返回创建信息 return marshal(obj, self.fields) -# TODO 模型校验 -# class CreateFormMixin(ModelViewBase): -# -# def pre_create(self, args): -# """创建前钩子,接收参数,可以对参数进行处理,最后保存此方法返回的数据""" -# return args -# -# def post(self): -# """创建对象""" -# self.form = self.form if self.form else model_form(self.model) -# print("===> json", request.json) -# form = self.form(request.json, meta={'csrf': False}) -# if not form.validate(): -# print(form.errors) -# return {"msg": "error"} -# return {"msg": "ok"} - - class RetrieveMixin(ModelViewBase): def get(self, pk): @@ -205,16 +233,23 @@ class UpdateMixin(ModelViewBase): # 解析参数 args = self.request_parse.parse_args() validated_data = self.validate_fields(args, create=False) + # 校验关联字段 + self.validate_relation_fields(validated_data) validated_data = self.pre_update(obj, validated_data) # 检查重复:检查提交的唯一字段值,是否有存在的对象,与当前对象id不同认为异常 self.validate_uniq_fields(args, obj=obj) # 更新对象、保存 - obj.update(**validated_data) - obj.save() - # 重新读取数据 - obj.reload() + try: + obj.update(**validated_data) + obj.save() + # 重新读取数据 + obj.reload() + except Exception as e: + app.logger.exception(f"{self.model} 保存对象失败!pk={pk} data={args}") + abort_response(500, 1500, msg=f"保存对象失败!{str(e)}") + return return marshal(obj, self.fields) diff --git a/src/controller/asset/parsers.py b/src/controller/asset/parsers.py new file mode 100644 index 0000000..3abcf7f --- /dev/null +++ b/src/controller/asset/parsers.py @@ -0,0 +1,57 @@ +import datetime + +from flask_restful import reqparse + + +class HostParse: + model = None + request_parse = None + uniq_fields = ("public_ip", "minion_id") + + def init_parse(self): + self.request_parse = reqparse.RequestParser() + # 创建时必须,修改时可选 + # self.request_parse.add_argument("public_ip", type=str, required=True, + # help='not public_ip provided', location='json') + self.request_parse.add_argument("private_ip", required=False, type=str, location='json') + self.request_parse.add_argument("minion_id", required=False, type=str, location='json') + self.request_parse.add_argument("weights", required=False, type=int, location='json') + self.request_parse.add_argument("cpu_num", required=False, type=int, location='json') + self.request_parse.add_argument("cpu_core", required=False, type=int, location='json') + self.request_parse.add_argument("memory", required=False, type=int, location='json') + self.request_parse.add_argument("tags", required=False, type=list, location='json') + self.request_parse.add_argument("labels", required=False, type=dict, location='json') + + def validate_fields(self, args: dict, create=True) -> dict: + if create: + args["created"] = datetime.datetime.now() + else: + if "created" in args: + args.pop("created") + return args + + +class DatabaseServerParse: + request_parse = None + # 给子类用的唯一字段,用于校验 + uniq_fields = ("name", "host", "domain") + + def init_parse(self): + self.request_parse = reqparse.RequestParser() + self.request_parse.add_argument("domain", required=False, type=str, location='json') + self.request_parse.add_argument("data", required=False, type=dict, location='json') + self.request_parse.add_argument("tags", required=False, type=list, location='json') + self.request_parse.add_argument("labels", required=False, type=dict, location='json') + + +class MiddlewareParse: + request_parse = None + # 给子类用的唯一字段,用于校验 + uniq_fields = ("name", "host") + + def init_parse(self): + self.request_parse = reqparse.RequestParser() + # self.request_parse.add_argument("domain", required=False, type=str, location='json') + self.request_parse.add_argument("data", required=False, type=dict, location='json') + self.request_parse.add_argument("tags", required=False, type=list, location='json') + self.request_parse.add_argument("labels", required=False, type=dict, location='json') diff --git a/src/controller/asset/views.py b/src/controller/asset/views.py index 50bb5e7..ed08365 100644 --- a/src/controller/asset/views.py +++ b/src/controller/asset/views.py @@ -5,39 +5,10 @@ from flask_restful import reqparse from models.asset import fields as assetField from models.asset import models as assetModel from common.views import ListCreateViewSet, DetailViewSet +from controller.asset import parsers -class HostParse: - model = None - request_parse = None - uniq_fields = ("public_ip", "minion_id") - - def init_parse(self): - self.request_parse = reqparse.RequestParser() - # 创建时必须,修改时可选 - # self.request_parse.add_argument("public_ip", type=str, required=True, - # help='not public_ip provided', location='json') - self.request_parse.add_argument("private_ip", required=False, type=str, location='json') - self.request_parse.add_argument("minion_id", required=False, type=str, location='json') - self.request_parse.add_argument("weights", required=False, type=int, location='json') - self.request_parse.add_argument("cpu_num", required=False, type=int, location='json') - self.request_parse.add_argument("cpu_core", required=False, type=int, location='json') - self.request_parse.add_argument("memory", required=False, type=int, location='json') - self.request_parse.add_argument("tags", required=False, type=list, location='json') - self.request_parse.add_argument("labels", required=False, type=dict, location='json') - - def pre_create(self, args): - # 标记创建时间 - args["created"] = datetime.datetime.now() - return args - - def pre_update(self, obj, args): - if "created" in args: - args.pop("created") - return args - - -class HostViews(HostParse, ListCreateViewSet): +class HostViews(parsers.HostParse, ListCreateViewSet): model = assetModel.Host fields = assetField.HostFields @@ -48,7 +19,7 @@ class HostViews(HostParse, ListCreateViewSet): super(HostViews, self).__init__() -class HostDetailViews(HostParse, DetailViewSet): +class HostDetailViews(parsers.HostParse, DetailViewSet): model = assetModel.Host fields = assetField.HostFields @@ -58,20 +29,7 @@ class HostDetailViews(HostParse, DetailViewSet): super(HostDetailViews, self).__init__() -class DatabaseServerParse: - request_parse = None - # 给子类用的唯一字段,用于校验 - uniq_fields = ("name", "host", "domain") - - def init_parse(self): - self.request_parse = reqparse.RequestParser() - self.request_parse.add_argument("domain", required=False, type=str, location='json') - self.request_parse.add_argument("data", required=False, type=dict, location='json') - self.request_parse.add_argument("tags", required=False, type=list, location='json') - self.request_parse.add_argument("labels", required=False, type=dict, location='json') - - -class MySQLInstanceViews(DatabaseServerParse, ListCreateViewSet): +class MySQLInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet): model = assetModel.MySQLInstance fields = assetField.MySQLInstanceFields @@ -90,7 +48,7 @@ class MySQLInstanceViews(DatabaseServerParse, ListCreateViewSet): super(MySQLInstanceViews, self).__init__() -class MySQLInstanceDetail(DatabaseServerParse, DetailViewSet): +class MySQLInstanceDetail(parsers.DatabaseServerParse, DetailViewSet): model = assetModel.MySQLInstance fields = assetField.MySQLInstanceFields @@ -110,7 +68,7 @@ class MySQLInstanceDetail(DatabaseServerParse, DetailViewSet): super(MySQLInstanceDetail, self).__init__() -class RedisInstanceViews(DatabaseServerParse, ListCreateViewSet): +class RedisInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet): model = assetModel.RedisInstance fields = assetField.RedisInstanceFields @@ -127,7 +85,7 @@ class RedisInstanceViews(DatabaseServerParse, ListCreateViewSet): super(RedisInstanceViews, self).__init__() -class RedisInstanceDetail(DatabaseServerParse, DetailViewSet): +class RedisInstanceDetail(parsers.DatabaseServerParse, DetailViewSet): model = assetModel.RedisInstance fields = assetField.RedisInstanceFields @@ -144,20 +102,7 @@ class RedisInstanceDetail(DatabaseServerParse, DetailViewSet): super().__init__() -class MiddlewareParse: - request_parse = None - # 给子类用的唯一字段,用于校验 - uniq_fields = ("name", "host") - - def init_parse(self): - self.request_parse = reqparse.RequestParser() - # self.request_parse.add_argument("domain", required=False, type=str, location='json') - self.request_parse.add_argument("data", required=False, type=dict, location='json') - self.request_parse.add_argument("tags", required=False, type=list, location='json') - self.request_parse.add_argument("labels", required=False, type=dict, location='json') - - -class NginxInstanceViews(MiddlewareParse, ListCreateViewSet): +class NginxInstanceViews(parsers.MiddlewareParse, ListCreateViewSet): model = assetModel.NginxInstance fields = assetField.NginxInstanceFields @@ -171,7 +116,7 @@ class NginxInstanceViews(MiddlewareParse, ListCreateViewSet): super().__init__() -class NginxInstanceDetail(MiddlewareParse, DetailViewSet): +class NginxInstanceDetail(parsers.MiddlewareParse, DetailViewSet): model = assetModel.NginxInstance fields = assetField.NginxInstanceFields diff --git a/src/controller/project/parsers.py b/src/controller/project/parsers.py new file mode 100644 index 0000000..2091ad4 --- /dev/null +++ b/src/controller/project/parsers.py @@ -0,0 +1,77 @@ +""" +视图类继承的解析扩展,子类须继承自 `common.views.ModelViewBase` +有使用到该父类的部分方法,主要抽出部分请求参数定义、解析后的进一步校验,供 `post` 和 `put 请求方法使用 +""" + +from flask import current_app as app +from flask_restful import reqparse + +from models.asset import models as assetModel +from models.project import models as projectModel +from common.utils import abort_response + + +class ChannelParse: + request_parse = None + + def init_parse(self): + self.request_parse = reqparse.RequestParser() + + self.request_parse.add_argument("name", type=str, location='json', default="") + self.request_parse.add_argument("version", type=dict, location='json') + self.request_parse.add_argument("repository", type=str, location='json', default="") + self.request_parse.add_argument("branch", type=dict, location='json', default="") + + self.request_parse.add_argument("data", type=dict, location='json') + self.request_parse.add_argument("tags", type=list, location='json') + self.request_parse.add_argument("labels", type=dict, location='json') + + def validate_fields(self, args: dict, create=True) -> dict: + if create: + self.validate_relation_pk(args, projectModel.Project, key="project_id") + spid = str(args.get("spid")) + if not spid.isalnum() or len(spid) != 3: + abort_response(400, 1001, msg="spid 不合法") + else: + if "project_id" in args: + self.validate_relation_pk(args, projectModel.Project, key="project_id") + if "spid" in args: + spid = str(args.get("spid")) + if not spid.isalnum() or len(spid) != 3: + abort_response(400, 1001, msg="spid 不合法") + + return args + + +class ServerParse: + """views 中使用的解析、校验方法""" + request_parse = None + + def init_parse(self): + self.request_parse = reqparse.RequestParser() + self.request_parse.add_argument("data", type=dict, location='json') + self.request_parse.add_argument("tags", type=list, location='json') + self.request_parse.add_argument("labels", type=dict, location='json') + + self.request_parse.add_argument("host_id", type=str, location='json') + self.request_parse.add_argument("domain", type=str, location='json', default="") + # status是枚举 + self.request_parse.add_argument("status", choices=projectModel.Server.STATUS.keys(), + type=str, location='json', default="running") + self.request_parse.add_argument("port", type=int, location='json', default=0) + self.request_parse.add_argument("version", type=dict, location='json') + self.request_parse.add_argument("weight", type=int, location='json') + self.request_parse.add_argument("slot", type=int, location='json') + + def validate_fields(self, args: dict, create=True) -> dict: + # if not create: + # if "channel_id" in args: + # self.validate_relation_pk(args, projectModel.Channel, key="channel_id") + # if "host_id" in args: + # self.validate_relation_pk(args, assetModel.Host, key="channel_id") + # else: + # self.validate_relation_pk(args, projectModel.Channel, key="channel_id") + version = args.get("version") + if version: + args["version"] = projectModel.Version(**version) + return args diff --git a/src/controller/project/views.py b/src/controller/project/views.py index af251d5..3ac13e0 100644 --- a/src/controller/project/views.py +++ b/src/controller/project/views.py @@ -1,8 +1,12 @@ """数据模型常规增删查改的接口""" +from flask import current_app as app + from models.project import fields +from models.asset import models as assetModel from models.project.models import Project, Channel, Server from common.views import ListMixin, CreateMixin, ListCreateViewSet, DetailViewSet +from controller.project import parsers class ProjectViews(ListMixin, CreateMixin): @@ -16,21 +20,49 @@ class ProjectDetailViews(DetailViewSet): fields = fields.ProjectFields -class ChannelViews(ListCreateViewSet): +class ChannelViews(parsers.ChannelParse, ListCreateViewSet): model = Channel fields = fields.ChannelFields + uniq_fields = (("spid", "project_id"),) + + def __init__(self): + self.init_parse() + self.request_parse.add_argument("spid", type=str, location='json', required=True) + self.request_parse.add_argument("project_id", type=str, location='json', required=True) + super().__init__() -class ChannelDetailViews(DetailViewSet): +class ChannelDetailViews(parsers.ChannelParse, DetailViewSet): model = Channel fields = fields.ChannelFields + uniq_fields = (("spid", "project_id"),) + + def __init__(self): + self.init_parse() + self.request_parse.add_argument("spid", type=str, location='json', required=False) + self.request_parse.add_argument("project_id", type=str, location='json', required=False) + super().__init__() -class ServerViews(ListCreateViewSet): +class ServerViews(parsers.ServerParse, ListCreateViewSet): model = Server fields = fields.ServerFields + uniq_fields = (("num", "channel_id"),) + relation_fields = (("host_id", assetModel.Host, ""), ("channel_id", Channel, True)) + + def __init__(self): + self.init_parse() + self.request_parse.add_argument("num", type=int, location='json', required=True) + self.request_parse.add_argument("channel_id", type=str, location='json', required=True) -class ServerDetailView(DetailViewSet): +class ServerDetailView(parsers.ServerParse, DetailViewSet): model = Server fields = fields.ServerFields + uniq_fields = (("num", "channel_id"),) + relation_fields = (("host_id", assetModel.Host, ""), ("channel_id", Channel, True)) + + def __init__(self): + self.init_parse() + self.request_parse.add_argument("num", type=int, location='json') + self.request_parse.add_argument("channel_id", type=str, location='json') diff --git a/src/models/project/fields.py b/src/models/project/fields.py index c9116a9..12b26ff 100644 --- a/src/models/project/fields.py +++ b/src/models/project/fields.py @@ -18,6 +18,7 @@ ProjectFields = { "kw": fields.String, "desc": fields.String, "repository": fields.String, + "data": fields.Raw, "tags": fields.List(fields.String), "labels": fields.Raw, } @@ -48,6 +49,7 @@ ChannelFields = { "version": fields.Nested(VersionFields), "repository": fields.String, "branch": fields.String, + "data": fields.Raw, "tags": fields.List(fields.String), # List 必须指明类型 "labels": fields.Raw, } @@ -70,6 +72,7 @@ ServerFields = { "version": fields.Nested(VersionFields), "weight": fields.Integer, "slot": fields.Integer, + "data": fields.Raw, "tags": fields.List(fields.String), "labels": fields.Raw, } @@ -94,6 +97,7 @@ AgentServerFields = { "status": fields.String, "weight": fields.Integer, "slot": fields.Integer, + "data": fields.Raw, # "tags": serializer.List(serializer.String), # "labels": serializer.Raw, } diff --git a/src/models/project/models.py b/src/models/project/models.py index 3001368..edb1659 100644 --- a/src/models/project/models.py +++ b/src/models/project/models.py @@ -21,6 +21,7 @@ class Project(DocumentBase): desc = mongo.StringField(max_length=256, default="") # 项目的代码仓库地址 repository = mongo.StringField("仓库", max_length=256, default="", help_text="仓库地址", null=True) + data = mongo.DictField(default={}) # 标记和标签 tags = mongo.ListField(mongo.StringField(), default=list) # tags 默认是空列表 labels = mongo.DictField(default=dict) @@ -51,8 +52,8 @@ class Channel(DocumentBase): # 项目的代码仓库地址 repository = mongo.StringField(max_length=256, default="") - branch = mongo.StringField(max_length=32) - + branch = mongo.StringField(max_length=32, default="") + data = mongo.DictField(default={}) # 标记和标签 tags = mongo.ListField(mongo.StringField(), default=list) # tags 默认是空列表 labels = mongo.DictField(default=dict) @@ -94,7 +95,7 @@ class Server(DocumentBase): num = mongo.IntField(required=True, unique_with="channel_id") # 关联channel表,保存是一个channel._id的hex字符串 channel_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string) - status = mongo.StringField(max_length=12, choices=STATUS.keys(), required=True, default="running") + status = mongo.StringField(max_length=12, choices=STATUS.keys(), required=False, default="running") # 机器字段,TODO 先允许为空 host_id = mongo.StringField(max_length=128, null=True, validation=is_hex_string) @@ -103,7 +104,7 @@ class Server(DocumentBase): version = mongo.EmbeddedDocumentField(Version) weight = mongo.IntField(default=1) slot = mongo.IntField(default=0) - + data = mongo.DictField(default={}) # 标记和标签 tags = mongo.ListField(mongo.StringField(), default=list) # tags 默认是空列表 labels = mongo.DictField(default=dict)