diff --git a/src/common/views.py b/src/common/views.py index 3bbacde..7c5e84b 100644 --- a/src/common/views.py +++ b/src/common/views.py @@ -5,6 +5,7 @@ from flask import request from flask_restful import abort, Resource, marshal, fields, reqparse from common.utils import abort_response +from common.crypto import quick_crypto logger = logging.getLogger("views") @@ -316,3 +317,34 @@ class ListCreateViewSet(ListMixin, CreateMixin): class DetailViewSet(RetrieveMixin, UpdateMixin, DestroyMixin): """带 `pk` 参数的视图集合""" + + +class EncryptRequiredCreateView(CreateMixin): + """创建对象时,加密字段的视图""" + # 需要加密的字段 + encrypt_fields = [] + + def pre_create(self, args): + """加密保存密码,若提交了密码,必须加密存储""" + args = super().pre_create(args) + for field in self.encrypt_fields: + if field in args: + password = args.get(field) + args[field] = quick_crypto(password) + return args + + +class EncryptRequiredUpdateView(UpdateMixin): + """更新加密字段的视图""" + # 需要加密的字段 + encrypt_fields = [] + + def pre_update(self, obj, args: dict): + """加密保存密码,若提交了密码,必须加密存储""" + data = super().pre_update(obj, args) + for field in self.encrypt_fields: + if field in data: + password = data.get(field) + if password != obj.password: + data[field] = quick_crypto(password) + return data diff --git a/src/controller/asset/parsers.py b/src/controller/asset/parsers.py index a83d25e..b3aaaa1 100644 --- a/src/controller/asset/parsers.py +++ b/src/controller/asset/parsers.py @@ -20,6 +20,7 @@ class HostParse: self.request_parse.add_argument("cpu_core", required=False, type=int, location='json') self.request_parse.add_argument("memory", required=False, type=int, location='json') self.request_parse.add_argument("tags", required=False, type=list, location='json') + self.request_parse.add_argument("data", required=False, type=dict, location='json') self.request_parse.add_argument("labels", required=False, type=dict, location='json') def validate_fields(self, args: dict, create=True) -> dict: diff --git a/src/controller/asset/views.py b/src/controller/asset/views.py index e7834a1..2671e53 100644 --- a/src/controller/asset/views.py +++ b/src/controller/asset/views.py @@ -5,7 +5,10 @@ from flask_restful import reqparse, marshal, Resource from models.asset import fields as assetField from models.asset import models as assetModel -from common.views import ListCreateViewSet, DetailViewSet, CreateMixin, UpdateMixin, DestroyMixin +from common.views import ( + ListCreateViewSet, DetailViewSet, ListMixin, UpdateMixin, DestroyMixin, RetrieveMixin, + EncryptRequiredCreateView, EncryptRequiredUpdateView +) from common.permission import session_or_token_required from common.utils import abort_response from controller.asset import parsers @@ -39,12 +42,13 @@ class HostDetailViews(parsers.HostParse, DetailViewSet): super(HostDetailViews, self).__init__() -class MySQLInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet): +class MySQLInstanceViews(parsers.DatabaseServerParse, ListMixin, EncryptRequiredCreateView): model = assetModel.MySQLInstance fields = assetField.MySQLInstanceFields method_decorators = [session_or_token_required] filter_fields = (("name", "icontains"), ("host", "icontains"), ("manage", ""), ("tags", ""), ("databases__name", ""), ) + encrypt_fields = ["password"] def __init__(self): self.init_parse() @@ -69,10 +73,11 @@ class MySQLInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet): return args -class MySQLInstanceDetail(parsers.DatabaseServerParse, DetailViewSet): +class MySQLInstanceDetail(parsers.DatabaseServerParse, RetrieveMixin, DestroyMixin, EncryptRequiredUpdateView): model = assetModel.MySQLInstance fields = assetField.MySQLInstanceFields method_decorators = [session_or_token_required] + encrypt_fields = ["password"] def __init__(self): """对象修改的参数解析""" @@ -196,11 +201,12 @@ class DatabaseDetailViews(parsers.DatabaseParse, Resource): return marshal(db_obj, self.db_fields) -class RedisInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet): +class RedisInstanceViews(parsers.DatabaseServerParse, ListMixin, EncryptRequiredCreateView): model = assetModel.RedisInstance fields = assetField.RedisInstanceFields method_decorators = [session_or_token_required] filter_fields = (("name", "icontains"), ("host", "icontains"), ("manage", ""), ("tags", "")) + encrypt_fields = ["password"] def __init__(self): self.init_parse() @@ -214,11 +220,20 @@ class RedisInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet): self.request_parse.add_argument("replicas", required=False, type=int, location='json') super(RedisInstanceViews, self).__init__() + def pre_create(self, args): + """加密保存密码,若提交了密码,必须加密存储""" + args = super().pre_create(args) + if "password" in args: + password = args.get("password") + args["password"] = quick_crypto(password) + return args -class RedisInstanceDetail(parsers.DatabaseServerParse, DetailViewSet): + +class RedisInstanceDetail(parsers.DatabaseServerParse, DestroyMixin, RetrieveMixin, EncryptRequiredUpdateView): model = assetModel.RedisInstance fields = assetField.RedisInstanceFields method_decorators = [session_or_token_required] + encrypt_fields = ["password"] def __init__(self): self.init_parse()