diff --git a/app.py b/app.py index 5e69d1a..1f4eee4 100644 --- a/app.py +++ b/app.py @@ -7,6 +7,7 @@ from flask.logging import default_handler sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) from database.mongodb import initialize_db +from database.redis import initialize_redis from api.routes import blueprint_api from settings.dev import LOGGING from settings.common import INSTANCE, VERSION @@ -23,6 +24,7 @@ app.config.from_pyfile("settings/dev.py") # 初始化数据库连接 initialize_db(app) +initialize_redis(app) # 注册蓝图 app.register_blueprint(blueprint_api) diff --git a/docs/requirements.txt b/docs/requirements.txt index 11a0d77..5497352 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,6 @@ # 主要依赖如下 Flask==2.0.2 flask-mongoengine==1.0.0 +flask-redis==0.4.0 Flask-RESTful==0.3.9 jsonschema==4.2.1 diff --git a/settings/common.py b/settings/common.py index a1f34bc..9011647 100644 --- a/settings/common.py +++ b/settings/common.py @@ -11,3 +11,7 @@ X_TOKEN_SET = ( CROSS_SPID_SET = ( "000", "111", "222", "333", "444" ) + +# session 从请求头中获取,ops1 中的命名如: +# jiggry:1:chenzuoqing:089dbbae5bec11ec97c3525400e547a1 +SESSION_PREFIX = "jiggry:1" diff --git a/settings/dev.py b/settings/dev.py index 5ff1965..f87ad78 100644 --- a/settings/dev.py +++ b/settings/dev.py @@ -13,6 +13,9 @@ MONGODB_SETTINGS = { 'host': 'mongodb://admin:111111@10.2.2.10:27017/ops_api?authSource=admin', } +# redis 地址 +REDIS_URL = "redis://:@10.2.2.10:6379/1" + LOGGING = { 'version': 1, 'disable_existing_loggers': True, diff --git a/src/common/permission.py b/src/common/permission.py index e037470..a8f1d34 100644 --- a/src/common/permission.py +++ b/src/common/permission.py @@ -4,16 +4,57 @@ from flask import request from flask_restful import abort from settings import common +from database.redis import redis_client + + +def has_token_permission() -> bool: + """token 权限判断""" + token = request.headers.get("X-Token") + token_set = common.X_TOKEN_SET + return token in token_set + + +def has_session_permission() -> bool: + """session 权限判断 + TODO 可能需要走接口判断,暂时先读 redis + """ + username = request.headers.get("Identity") + session_id = request.headers.get("Session") + if not username or not session_id: + return False + + key_name = f"{common.SESSION_PREFIX}:{username}:{session_id}" + print("===> key name", key_name) + return bool(redis_client.exists(key_name)) def token_header_required(func): """获取请求头携带的 token,简单的过滤请求""" @wraps(func) def wrap_func(*args, **kwargs): - token = request.headers.get("X-Token") - token_set = common.X_TOKEN_SET - - if token not in token_set: + if not has_token_permission(): abort(403, msg="permission denied", code=1403) return func(*args, **kwargs) return wrap_func + + +def session_login_required(func): + """需要 session 登陆权限,校验ops1 session + """ + @wraps(func) + def wrap_func(*args, **kwargs): + if not has_session_permission(): + abort(403, msg="permission denied", code=1403) + return func(*args, **kwargs) + return wrap_func + + +def session_or_token_required(func): + """需要 token 或 session 认证,必须满足其一""" + @wraps(func) + def wrap_func(*args, **kwargs): + # 两种权限满足其一,否则 403 + if has_token_permission() or has_session_permission(): + return func(*args, **kwargs) + abort(403, msg="permission denied", code=1403) + return wrap_func diff --git a/src/controller/asset/views.py b/src/controller/asset/views.py index 0f32894..c64c4dd 100644 --- a/src/controller/asset/views.py +++ b/src/controller/asset/views.py @@ -5,12 +5,14 @@ from flask_restful import reqparse from models.asset import fields as assetField from models.asset import models as assetModel from common.views import ListCreateViewSet, DetailViewSet +from common.permission import session_or_token_required from controller.asset import parsers class HostViews(parsers.HostParse, ListCreateViewSet): model = assetModel.Host fields = assetField.HostFields + method_decorators = [session_or_token_required] filter_fields = (("public_ip", "contains"), ("private_ip", "contains"), ("tags", "")) def __init__(self): @@ -23,6 +25,7 @@ class HostViews(parsers.HostParse, ListCreateViewSet): class HostDetailViews(parsers.HostParse, DetailViewSet): model = assetModel.Host fields = assetField.HostFields + method_decorators = [session_or_token_required] def __init__(self): self.init_parse() @@ -33,6 +36,7 @@ class HostDetailViews(parsers.HostParse, DetailViewSet): class MySQLInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet): model = assetModel.MySQLInstance fields = assetField.MySQLInstanceFields + method_decorators = [session_or_token_required] filter_fields = (("name", "icontains"), ("host", "icontains"), ("manage", ""), ("tags", "")) def __init__(self): @@ -53,6 +57,7 @@ class MySQLInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet): class MySQLInstanceDetail(parsers.DatabaseServerParse, DetailViewSet): model = assetModel.MySQLInstance fields = assetField.MySQLInstanceFields + method_decorators = [session_or_token_required] def __init__(self): """对象修改的参数解析""" @@ -73,6 +78,7 @@ class MySQLInstanceDetail(parsers.DatabaseServerParse, DetailViewSet): class RedisInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet): model = assetModel.RedisInstance fields = assetField.RedisInstanceFields + method_decorators = [session_or_token_required] filter_fields = (("name", "icontains"), ("host", "icontains"), ("manage", ""), ("tags", "")) def __init__(self): @@ -91,6 +97,7 @@ class RedisInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet): class RedisInstanceDetail(parsers.DatabaseServerParse, DetailViewSet): model = assetModel.RedisInstance fields = assetField.RedisInstanceFields + method_decorators = [session_or_token_required] def __init__(self): self.init_parse() @@ -109,6 +116,7 @@ class NginxInstanceViews(parsers.MiddlewareParse, ListCreateViewSet): model = assetModel.NginxInstance fields = assetField.NginxInstanceFields filter_fields = (("host", "icontains"), ("manage", ""), ("url", "icontains"), ("tags", "")) + method_decorators = [session_or_token_required] def __init__(self): self.init_parse() @@ -123,6 +131,7 @@ class NginxInstanceViews(parsers.MiddlewareParse, ListCreateViewSet): class NginxInstanceDetail(parsers.MiddlewareParse, DetailViewSet): model = assetModel.NginxInstance fields = assetField.NginxInstanceFields + method_decorators = [session_or_token_required] def __init__(self): self.init_parse() @@ -139,6 +148,7 @@ class CDNViews(ListCreateViewSet): fields = assetField.CDNFields uniq_fields = ("domain",) filter_fields = (("domain", "icontains"), ("tags", ""),) + method_decorators = [session_or_token_required] def __init__(self): self.request_parse = reqparse.RequestParser() @@ -153,6 +163,7 @@ class CDNDetail(DetailViewSet): model = assetModel.CDN fields = assetField.CDNFields uniq_fields = ("domain",) + method_decorators = [session_or_token_required] def __init__(self): self.request_parse = reqparse.RequestParser() diff --git a/src/controller/project/views.py b/src/controller/project/views.py index b3b19de..9d2f03c 100644 --- a/src/controller/project/views.py +++ b/src/controller/project/views.py @@ -6,6 +6,7 @@ from models.project import fields from models.asset import models as assetModel from models.project.models import Project, Channel, Server from common.views import ListMixin, CreateMixin, ListCreateViewSet, DetailViewSet +from common.permission import session_or_token_required from controller.project import parsers @@ -13,18 +14,21 @@ class ProjectViews(ListMixin, CreateMixin): model = Project paging = False fields = fields.ProjectFields + method_decorators = [session_or_token_required] filter_fields = (("name", ""), ("fork", ""), ("ops_ip", ""), ("domain", "icontains"), ("tags", "")) class ProjectDetailViews(DetailViewSet): model = Project fields = fields.ProjectFields + method_decorators = [session_or_token_required] class ChannelViews(parsers.ChannelParse, ListCreateViewSet): model = Channel fields = fields.ChannelFields uniq_fields = (("spid", "project_id"),) + method_decorators = [session_or_token_required] filter_fields = (("name", ""), ("spid", ""), ("tags", "")) def __init__(self): @@ -38,6 +42,7 @@ class ChannelDetailViews(parsers.ChannelParse, DetailViewSet): model = Channel fields = fields.ChannelFields uniq_fields = (("spid", "project_id"),) + method_decorators = [session_or_token_required] def __init__(self): self.init_parse() @@ -49,6 +54,7 @@ class ChannelDetailViews(parsers.ChannelParse, DetailViewSet): class ServerViews(parsers.ServerParse, ListCreateViewSet): model = Server fields = fields.ServerFields + method_decorators = [session_or_token_required] uniq_fields = (("num", "channel_id"),) relation_fields = (("host_id", assetModel.Host, ""), ("channel_id", Channel, True)) filter_fields = (("num", ""), ("spid", ""), ("tags", "")) @@ -63,6 +69,7 @@ class ServerDetailView(parsers.ServerParse, DetailViewSet): model = Server fields = fields.ServerFields uniq_fields = (("num", "channel_id"),) + method_decorators = [session_or_token_required] relation_fields = (("host_id", assetModel.Host, ""), ("channel_id", Channel, True)) def __init__(self): diff --git a/src/database/redis.py b/src/database/redis.py new file mode 100644 index 0000000..18b09b5 --- /dev/null +++ b/src/database/redis.py @@ -0,0 +1,7 @@ +from flask_redis import FlaskRedis + +redis_client = FlaskRedis() + + +def initialize_redis(app): + redis_client.init_app(app)