diff --git a/src/common/views.py b/src/common/views.py index 807ccd3..a346871 100644 --- a/src/common/views.py +++ b/src/common/views.py @@ -1,6 +1,8 @@ from flask import request from flask_restful import abort, Resource, marshal, fields, reqparse +from common.utils import abort_response + class ModelViewBase(Resource): """Mongodb 模型视图相关方法的扩展集合""" @@ -12,6 +14,8 @@ class ModelViewBase(Resource): request_parse: reqparse.RequestParser = None # 是否分页 paging = True + # 需要检查重复的字段 + uniq_fields = [] # method_decorators 可以用作权限校验,参考 `flask_restful.Resource` # 参数是装饰器列表(未指定方法将会对所有请求方法生效) @@ -56,6 +60,30 @@ class ModelViewBase(Resource): } return marshal(data, ret_format) + def validate_uniq_fields(self, args, obj=None): + """ + 判断传入的字段值是否重复,重复返回异常 + 注意:请先在 `request_parse` 中判断字段是否必须,若字段为空将不会检查重复 + :param args: `request_parse` 解析得到的参数 + :param obj: `update` 的 model 对象, `create` 时忽略为 None + """ + + if not self.uniq_fields: + return + + errors = [] + for field in self.uniq_fields: + val = args.get(field) + if val: + exists = self.model.objects(**{field: val}).first() + # 两种情况需要判断: + # 1. 创建时检查内容是否能匹配对象,匹配到返回异常 + # 2. 更新时若存在重复,并且不是当前修改的对象,返回异常 + if (not obj and exists) or (obj and obj.id != exists.id): + errors.append(f"提交的 {field} 已存在") + if errors: + abort_response(400, 1001, msg=f"部分字段需要唯一,请检查重复!", errors=errors) + class ListMixin(ModelViewBase): @@ -90,6 +118,9 @@ class CreateMixin(ModelViewBase): # 创建前钩子, validated_data = self.pre_create(args) + # 检查重复:检查提交的唯一字段值,是否有存在的对象 + self.validate_uniq_fields(args) + # 保存对象 obj = self.model(**validated_data) obj.save() @@ -136,12 +167,15 @@ class UpdateMixin(ModelViewBase): def put(self, pk): # 获取对象 - obj = self.get_object(pk) + # 解析参数 args = self.request_parse.parse_args() validated_data = self.pre_update(obj, args) + # 检查重复:检查提交的唯一字段值,是否有存在的对象,与当前对象id不同认为异常 + self.validate_uniq_fields(args, obj=obj) + # 更新对象、保存 obj.update(**validated_data) obj.save()