模型视图混合类增加对重复字段的判断处理,创建、更新对象时使用

This commit is contained in:
chenzuoqing 2021-12-08 10:29:32 +08:00
parent ae5e4b5de1
commit 2a742c948a
1 changed files with 35 additions and 1 deletions

View File

@ -1,6 +1,8 @@
from flask import request from flask import request
from flask_restful import abort, Resource, marshal, fields, reqparse from flask_restful import abort, Resource, marshal, fields, reqparse
from common.utils import abort_response
class ModelViewBase(Resource): class ModelViewBase(Resource):
"""Mongodb 模型视图相关方法的扩展集合""" """Mongodb 模型视图相关方法的扩展集合"""
@ -12,6 +14,8 @@ class ModelViewBase(Resource):
request_parse: reqparse.RequestParser = None request_parse: reqparse.RequestParser = None
# 是否分页 # 是否分页
paging = True paging = True
# 需要检查重复的字段
uniq_fields = []
# method_decorators 可以用作权限校验,参考 `flask_restful.Resource` # method_decorators 可以用作权限校验,参考 `flask_restful.Resource`
# 参数是装饰器列表(未指定方法将会对所有请求方法生效) # 参数是装饰器列表(未指定方法将会对所有请求方法生效)
@ -56,6 +60,30 @@ class ModelViewBase(Resource):
} }
return marshal(data, ret_format) return marshal(data, ret_format)
def validate_uniq_fields(self, args, obj=None):
"""
判断传入的字段值是否重复重复返回异常
注意请先在 `request_parse` 中判断字段是否必须若字段为空将不会检查重复
:param args: `request_parse` 解析得到的参数
:param obj: `update` model 对象 `create` 时忽略为 None
"""
if not self.uniq_fields:
return
errors = []
for field in self.uniq_fields:
val = args.get(field)
if val:
exists = self.model.objects(**{field: val}).first()
# 两种情况需要判断:
# 1. 创建时检查内容是否能匹配对象,匹配到返回异常
# 2. 更新时若存在重复,并且不是当前修改的对象,返回异常
if (not obj and exists) or (obj and obj.id != exists.id):
errors.append(f"提交的 {field} 已存在")
if errors:
abort_response(400, 1001, msg=f"部分字段需要唯一,请检查重复!", errors=errors)
class ListMixin(ModelViewBase): class ListMixin(ModelViewBase):
@ -90,6 +118,9 @@ class CreateMixin(ModelViewBase):
# 创建前钩子, # 创建前钩子,
validated_data = self.pre_create(args) validated_data = self.pre_create(args)
# 检查重复:检查提交的唯一字段值,是否有存在的对象
self.validate_uniq_fields(args)
# 保存对象 # 保存对象
obj = self.model(**validated_data) obj = self.model(**validated_data)
obj.save() obj.save()
@ -136,12 +167,15 @@ class UpdateMixin(ModelViewBase):
def put(self, pk): def put(self, pk):
# 获取对象 # 获取对象
obj = self.get_object(pk) obj = self.get_object(pk)
# 解析参数 # 解析参数
args = self.request_parse.parse_args() args = self.request_parse.parse_args()
validated_data = self.pre_update(obj, args) validated_data = self.pre_update(obj, args)
# 检查重复检查提交的唯一字段值是否有存在的对象与当前对象id不同认为异常
self.validate_uniq_fields(args, obj=obj)
# 更新对象、保存 # 更新对象、保存
obj.update(**validated_data) obj.update(**validated_data)
obj.save() obj.save()