抽出需要加密字段的视图

This commit is contained in:
chenzuoqing 2021-12-22 18:13:55 +08:00
parent a351fc500b
commit b1f7c6081b
3 changed files with 53 additions and 5 deletions

View File

@ -5,6 +5,7 @@ from flask import request
from flask_restful import abort, Resource, marshal, fields, reqparse from flask_restful import abort, Resource, marshal, fields, reqparse
from common.utils import abort_response from common.utils import abort_response
from common.crypto import quick_crypto
logger = logging.getLogger("views") logger = logging.getLogger("views")
@ -316,3 +317,34 @@ class ListCreateViewSet(ListMixin, CreateMixin):
class DetailViewSet(RetrieveMixin, UpdateMixin, DestroyMixin): class DetailViewSet(RetrieveMixin, UpdateMixin, DestroyMixin):
"""带 `pk` 参数的视图集合""" """带 `pk` 参数的视图集合"""
class EncryptRequiredCreateView(CreateMixin):
"""创建对象时,加密字段的视图"""
# 需要加密的字段
encrypt_fields = []
def pre_create(self, args):
"""加密保存密码,若提交了密码,必须加密存储"""
args = super().pre_create(args)
for field in self.encrypt_fields:
if field in args:
password = args.get(field)
args[field] = quick_crypto(password)
return args
class EncryptRequiredUpdateView(UpdateMixin):
"""更新加密字段的视图"""
# 需要加密的字段
encrypt_fields = []
def pre_update(self, obj, args: dict):
"""加密保存密码,若提交了密码,必须加密存储"""
data = super().pre_update(obj, args)
for field in self.encrypt_fields:
if field in data:
password = data.get(field)
if password != obj.password:
data[field] = quick_crypto(password)
return data

View File

@ -20,6 +20,7 @@ class HostParse:
self.request_parse.add_argument("cpu_core", required=False, type=int, location='json') self.request_parse.add_argument("cpu_core", required=False, type=int, location='json')
self.request_parse.add_argument("memory", required=False, type=int, location='json') self.request_parse.add_argument("memory", required=False, type=int, location='json')
self.request_parse.add_argument("tags", required=False, type=list, location='json') self.request_parse.add_argument("tags", required=False, type=list, location='json')
self.request_parse.add_argument("data", required=False, type=dict, location='json')
self.request_parse.add_argument("labels", required=False, type=dict, location='json') self.request_parse.add_argument("labels", required=False, type=dict, location='json')
def validate_fields(self, args: dict, create=True) -> dict: def validate_fields(self, args: dict, create=True) -> dict:

View File

@ -5,7 +5,10 @@ from flask_restful import reqparse, marshal, Resource
from models.asset import fields as assetField from models.asset import fields as assetField
from models.asset import models as assetModel from models.asset import models as assetModel
from common.views import ListCreateViewSet, DetailViewSet, CreateMixin, UpdateMixin, DestroyMixin from common.views import (
ListCreateViewSet, DetailViewSet, ListMixin, UpdateMixin, DestroyMixin, RetrieveMixin,
EncryptRequiredCreateView, EncryptRequiredUpdateView
)
from common.permission import session_or_token_required from common.permission import session_or_token_required
from common.utils import abort_response from common.utils import abort_response
from controller.asset import parsers from controller.asset import parsers
@ -39,12 +42,13 @@ class HostDetailViews(parsers.HostParse, DetailViewSet):
super(HostDetailViews, self).__init__() super(HostDetailViews, self).__init__()
class MySQLInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet): class MySQLInstanceViews(parsers.DatabaseServerParse, ListMixin, EncryptRequiredCreateView):
model = assetModel.MySQLInstance model = assetModel.MySQLInstance
fields = assetField.MySQLInstanceFields fields = assetField.MySQLInstanceFields
method_decorators = [session_or_token_required] method_decorators = [session_or_token_required]
filter_fields = (("name", "icontains"), ("host", "icontains"), ("manage", ""), ("tags", ""), filter_fields = (("name", "icontains"), ("host", "icontains"), ("manage", ""), ("tags", ""),
("databases__name", ""), ) ("databases__name", ""), )
encrypt_fields = ["password"]
def __init__(self): def __init__(self):
self.init_parse() self.init_parse()
@ -69,10 +73,11 @@ class MySQLInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet):
return args return args
class MySQLInstanceDetail(parsers.DatabaseServerParse, DetailViewSet): class MySQLInstanceDetail(parsers.DatabaseServerParse, RetrieveMixin, DestroyMixin, EncryptRequiredUpdateView):
model = assetModel.MySQLInstance model = assetModel.MySQLInstance
fields = assetField.MySQLInstanceFields fields = assetField.MySQLInstanceFields
method_decorators = [session_or_token_required] method_decorators = [session_or_token_required]
encrypt_fields = ["password"]
def __init__(self): def __init__(self):
"""对象修改的参数解析""" """对象修改的参数解析"""
@ -196,11 +201,12 @@ class DatabaseDetailViews(parsers.DatabaseParse, Resource):
return marshal(db_obj, self.db_fields) return marshal(db_obj, self.db_fields)
class RedisInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet): class RedisInstanceViews(parsers.DatabaseServerParse, ListMixin, EncryptRequiredCreateView):
model = assetModel.RedisInstance model = assetModel.RedisInstance
fields = assetField.RedisInstanceFields fields = assetField.RedisInstanceFields
method_decorators = [session_or_token_required] method_decorators = [session_or_token_required]
filter_fields = (("name", "icontains"), ("host", "icontains"), ("manage", ""), ("tags", "")) filter_fields = (("name", "icontains"), ("host", "icontains"), ("manage", ""), ("tags", ""))
encrypt_fields = ["password"]
def __init__(self): def __init__(self):
self.init_parse() self.init_parse()
@ -214,11 +220,20 @@ class RedisInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet):
self.request_parse.add_argument("replicas", required=False, type=int, location='json') self.request_parse.add_argument("replicas", required=False, type=int, location='json')
super(RedisInstanceViews, self).__init__() super(RedisInstanceViews, self).__init__()
def pre_create(self, args):
"""加密保存密码,若提交了密码,必须加密存储"""
args = super().pre_create(args)
if "password" in args:
password = args.get("password")
args["password"] = quick_crypto(password)
return args
class RedisInstanceDetail(parsers.DatabaseServerParse, DetailViewSet):
class RedisInstanceDetail(parsers.DatabaseServerParse, DestroyMixin, RetrieveMixin, EncryptRequiredUpdateView):
model = assetModel.RedisInstance model = assetModel.RedisInstance
fields = assetField.RedisInstanceFields fields = assetField.RedisInstanceFields
method_decorators = [session_or_token_required] method_decorators = [session_or_token_required]
encrypt_fields = ["password"]
def __init__(self): def __init__(self):
self.init_parse() self.init_parse()