From 2a742c948a93fc5d9a2854c52bd9269ae260968f Mon Sep 17 00:00:00 2001 From: chenzuoqing Date: Wed, 8 Dec 2021 10:29:32 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=A7=86=E5=9B=BE=E6=B7=B7?= =?UTF-8?q?=E5=90=88=E7=B1=BB=E5=A2=9E=E5=8A=A0=E5=AF=B9=E9=87=8D=E5=A4=8D?= =?UTF-8?q?=E5=AD=97=E6=AE=B5=E7=9A=84=E5=88=A4=E6=96=AD=E5=A4=84=E7=90=86?= =?UTF-8?q?=EF=BC=8C=E5=88=9B=E5=BB=BA=E3=80=81=E6=9B=B4=E6=96=B0=E5=AF=B9?= =?UTF-8?q?=E8=B1=A1=E6=97=B6=E4=BD=BF=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/views.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/src/common/views.py b/src/common/views.py index 807ccd3..a346871 100644 --- a/src/common/views.py +++ b/src/common/views.py @@ -1,6 +1,8 @@ from flask import request from flask_restful import abort, Resource, marshal, fields, reqparse +from common.utils import abort_response + class ModelViewBase(Resource): """Mongodb 模型视图相关方法的扩展集合""" @@ -12,6 +14,8 @@ class ModelViewBase(Resource): request_parse: reqparse.RequestParser = None # 是否分页 paging = True + # 需要检查重复的字段 + uniq_fields = [] # method_decorators 可以用作权限校验,参考 `flask_restful.Resource` # 参数是装饰器列表(未指定方法将会对所有请求方法生效) @@ -56,6 +60,30 @@ class ModelViewBase(Resource): } return marshal(data, ret_format) + def validate_uniq_fields(self, args, obj=None): + """ + 判断传入的字段值是否重复,重复返回异常 + 注意:请先在 `request_parse` 中判断字段是否必须,若字段为空将不会检查重复 + :param args: `request_parse` 解析得到的参数 + :param obj: `update` 的 model 对象, `create` 时忽略为 None + """ + + if not self.uniq_fields: + return + + errors = [] + for field in self.uniq_fields: + val = args.get(field) + if val: + exists = self.model.objects(**{field: val}).first() + # 两种情况需要判断: + # 1. 创建时检查内容是否能匹配对象,匹配到返回异常 + # 2. 更新时若存在重复,并且不是当前修改的对象,返回异常 + if (not obj and exists) or (obj and obj.id != exists.id): + errors.append(f"提交的 {field} 已存在") + if errors: + abort_response(400, 1001, msg=f"部分字段需要唯一,请检查重复!", errors=errors) + class ListMixin(ModelViewBase): @@ -90,6 +118,9 @@ class CreateMixin(ModelViewBase): # 创建前钩子, validated_data = self.pre_create(args) + # 检查重复:检查提交的唯一字段值,是否有存在的对象 + self.validate_uniq_fields(args) + # 保存对象 obj = self.model(**validated_data) obj.save() @@ -136,12 +167,15 @@ class UpdateMixin(ModelViewBase): def put(self, pk): # 获取对象 - obj = self.get_object(pk) + # 解析参数 args = self.request_parse.parse_args() validated_data = self.pre_update(obj, args) + # 检查重复:检查提交的唯一字段值,是否有存在的对象,与当前对象id不同认为异常 + self.validate_uniq_fields(args, obj=obj) + # 更新对象、保存 obj.update(**validated_data) obj.save()