新增session权限判断(运维后台登陆的session)

This commit is contained in:
chenzuoqing 2021-12-15 16:44:51 +08:00
parent 6edb81058a
commit 9b23d8a3a6
8 changed files with 80 additions and 4 deletions

2
app.py
View File

@ -7,6 +7,7 @@ from flask.logging import default_handler
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from database.mongodb import initialize_db from database.mongodb import initialize_db
from database.redis import initialize_redis
from api.routes import blueprint_api from api.routes import blueprint_api
from settings.dev import LOGGING from settings.dev import LOGGING
from settings.common import INSTANCE, VERSION from settings.common import INSTANCE, VERSION
@ -23,6 +24,7 @@ app.config.from_pyfile("settings/dev.py")
# 初始化数据库连接 # 初始化数据库连接
initialize_db(app) initialize_db(app)
initialize_redis(app)
# 注册蓝图 # 注册蓝图
app.register_blueprint(blueprint_api) app.register_blueprint(blueprint_api)

View File

@ -1,5 +1,6 @@
# 主要依赖如下 # 主要依赖如下
Flask==2.0.2 Flask==2.0.2
flask-mongoengine==1.0.0 flask-mongoengine==1.0.0
flask-redis==0.4.0
Flask-RESTful==0.3.9 Flask-RESTful==0.3.9
jsonschema==4.2.1 jsonschema==4.2.1

View File

@ -11,3 +11,7 @@ X_TOKEN_SET = (
CROSS_SPID_SET = ( CROSS_SPID_SET = (
"000", "111", "222", "333", "444" "000", "111", "222", "333", "444"
) )
# session 从请求头中获取ops1 中的命名如:
# jiggry:1:chenzuoqing:089dbbae5bec11ec97c3525400e547a1
SESSION_PREFIX = "jiggry:1"

View File

@ -13,6 +13,9 @@ MONGODB_SETTINGS = {
'host': 'mongodb://admin:111111@10.2.2.10:27017/ops_api?authSource=admin', 'host': 'mongodb://admin:111111@10.2.2.10:27017/ops_api?authSource=admin',
} }
# redis 地址
REDIS_URL = "redis://:@10.2.2.10:6379/1"
LOGGING = { LOGGING = {
'version': 1, 'version': 1,
'disable_existing_loggers': True, 'disable_existing_loggers': True,

View File

@ -4,16 +4,57 @@ from flask import request
from flask_restful import abort from flask_restful import abort
from settings import common from settings import common
from database.redis import redis_client
def has_token_permission() -> bool:
"""token 权限判断"""
token = request.headers.get("X-Token")
token_set = common.X_TOKEN_SET
return token in token_set
def has_session_permission() -> bool:
"""session 权限判断
TODO 可能需要走接口判断暂时先读 redis
"""
username = request.headers.get("Identity")
session_id = request.headers.get("Session")
if not username or not session_id:
return False
key_name = f"{common.SESSION_PREFIX}:{username}:{session_id}"
print("===> key name", key_name)
return bool(redis_client.exists(key_name))
def token_header_required(func): def token_header_required(func):
"""获取请求头携带的 token简单的过滤请求""" """获取请求头携带的 token简单的过滤请求"""
@wraps(func) @wraps(func)
def wrap_func(*args, **kwargs): def wrap_func(*args, **kwargs):
token = request.headers.get("X-Token") if not has_token_permission():
token_set = common.X_TOKEN_SET
if token not in token_set:
abort(403, msg="permission denied", code=1403) abort(403, msg="permission denied", code=1403)
return func(*args, **kwargs) return func(*args, **kwargs)
return wrap_func return wrap_func
def session_login_required(func):
"""需要 session 登陆权限校验ops1 session
"""
@wraps(func)
def wrap_func(*args, **kwargs):
if not has_session_permission():
abort(403, msg="permission denied", code=1403)
return func(*args, **kwargs)
return wrap_func
def session_or_token_required(func):
"""需要 token 或 session 认证,必须满足其一"""
@wraps(func)
def wrap_func(*args, **kwargs):
# 两种权限满足其一,否则 403
if has_token_permission() or has_session_permission():
return func(*args, **kwargs)
abort(403, msg="permission denied", code=1403)
return wrap_func

View File

@ -5,12 +5,14 @@ from flask_restful import reqparse
from models.asset import fields as assetField from models.asset import fields as assetField
from models.asset import models as assetModel from models.asset import models as assetModel
from common.views import ListCreateViewSet, DetailViewSet from common.views import ListCreateViewSet, DetailViewSet
from common.permission import session_or_token_required
from controller.asset import parsers from controller.asset import parsers
class HostViews(parsers.HostParse, ListCreateViewSet): class HostViews(parsers.HostParse, ListCreateViewSet):
model = assetModel.Host model = assetModel.Host
fields = assetField.HostFields fields = assetField.HostFields
method_decorators = [session_or_token_required]
filter_fields = (("public_ip", "contains"), ("private_ip", "contains"), ("tags", "")) filter_fields = (("public_ip", "contains"), ("private_ip", "contains"), ("tags", ""))
def __init__(self): def __init__(self):
@ -23,6 +25,7 @@ class HostViews(parsers.HostParse, ListCreateViewSet):
class HostDetailViews(parsers.HostParse, DetailViewSet): class HostDetailViews(parsers.HostParse, DetailViewSet):
model = assetModel.Host model = assetModel.Host
fields = assetField.HostFields fields = assetField.HostFields
method_decorators = [session_or_token_required]
def __init__(self): def __init__(self):
self.init_parse() self.init_parse()
@ -33,6 +36,7 @@ class HostDetailViews(parsers.HostParse, DetailViewSet):
class MySQLInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet): class MySQLInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet):
model = assetModel.MySQLInstance model = assetModel.MySQLInstance
fields = assetField.MySQLInstanceFields fields = assetField.MySQLInstanceFields
method_decorators = [session_or_token_required]
filter_fields = (("name", "icontains"), ("host", "icontains"), ("manage", ""), ("tags", "")) filter_fields = (("name", "icontains"), ("host", "icontains"), ("manage", ""), ("tags", ""))
def __init__(self): def __init__(self):
@ -53,6 +57,7 @@ class MySQLInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet):
class MySQLInstanceDetail(parsers.DatabaseServerParse, DetailViewSet): class MySQLInstanceDetail(parsers.DatabaseServerParse, DetailViewSet):
model = assetModel.MySQLInstance model = assetModel.MySQLInstance
fields = assetField.MySQLInstanceFields fields = assetField.MySQLInstanceFields
method_decorators = [session_or_token_required]
def __init__(self): def __init__(self):
"""对象修改的参数解析""" """对象修改的参数解析"""
@ -73,6 +78,7 @@ class MySQLInstanceDetail(parsers.DatabaseServerParse, DetailViewSet):
class RedisInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet): class RedisInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet):
model = assetModel.RedisInstance model = assetModel.RedisInstance
fields = assetField.RedisInstanceFields fields = assetField.RedisInstanceFields
method_decorators = [session_or_token_required]
filter_fields = (("name", "icontains"), ("host", "icontains"), ("manage", ""), ("tags", "")) filter_fields = (("name", "icontains"), ("host", "icontains"), ("manage", ""), ("tags", ""))
def __init__(self): def __init__(self):
@ -91,6 +97,7 @@ class RedisInstanceViews(parsers.DatabaseServerParse, ListCreateViewSet):
class RedisInstanceDetail(parsers.DatabaseServerParse, DetailViewSet): class RedisInstanceDetail(parsers.DatabaseServerParse, DetailViewSet):
model = assetModel.RedisInstance model = assetModel.RedisInstance
fields = assetField.RedisInstanceFields fields = assetField.RedisInstanceFields
method_decorators = [session_or_token_required]
def __init__(self): def __init__(self):
self.init_parse() self.init_parse()
@ -109,6 +116,7 @@ class NginxInstanceViews(parsers.MiddlewareParse, ListCreateViewSet):
model = assetModel.NginxInstance model = assetModel.NginxInstance
fields = assetField.NginxInstanceFields fields = assetField.NginxInstanceFields
filter_fields = (("host", "icontains"), ("manage", ""), ("url", "icontains"), ("tags", "")) filter_fields = (("host", "icontains"), ("manage", ""), ("url", "icontains"), ("tags", ""))
method_decorators = [session_or_token_required]
def __init__(self): def __init__(self):
self.init_parse() self.init_parse()
@ -123,6 +131,7 @@ class NginxInstanceViews(parsers.MiddlewareParse, ListCreateViewSet):
class NginxInstanceDetail(parsers.MiddlewareParse, DetailViewSet): class NginxInstanceDetail(parsers.MiddlewareParse, DetailViewSet):
model = assetModel.NginxInstance model = assetModel.NginxInstance
fields = assetField.NginxInstanceFields fields = assetField.NginxInstanceFields
method_decorators = [session_or_token_required]
def __init__(self): def __init__(self):
self.init_parse() self.init_parse()
@ -139,6 +148,7 @@ class CDNViews(ListCreateViewSet):
fields = assetField.CDNFields fields = assetField.CDNFields
uniq_fields = ("domain",) uniq_fields = ("domain",)
filter_fields = (("domain", "icontains"), ("tags", ""),) filter_fields = (("domain", "icontains"), ("tags", ""),)
method_decorators = [session_or_token_required]
def __init__(self): def __init__(self):
self.request_parse = reqparse.RequestParser() self.request_parse = reqparse.RequestParser()
@ -153,6 +163,7 @@ class CDNDetail(DetailViewSet):
model = assetModel.CDN model = assetModel.CDN
fields = assetField.CDNFields fields = assetField.CDNFields
uniq_fields = ("domain",) uniq_fields = ("domain",)
method_decorators = [session_or_token_required]
def __init__(self): def __init__(self):
self.request_parse = reqparse.RequestParser() self.request_parse = reqparse.RequestParser()

View File

@ -6,6 +6,7 @@ from models.project import fields
from models.asset import models as assetModel from models.asset import models as assetModel
from models.project.models import Project, Channel, Server from models.project.models import Project, Channel, Server
from common.views import ListMixin, CreateMixin, ListCreateViewSet, DetailViewSet from common.views import ListMixin, CreateMixin, ListCreateViewSet, DetailViewSet
from common.permission import session_or_token_required
from controller.project import parsers from controller.project import parsers
@ -13,18 +14,21 @@ class ProjectViews(ListMixin, CreateMixin):
model = Project model = Project
paging = False paging = False
fields = fields.ProjectFields fields = fields.ProjectFields
method_decorators = [session_or_token_required]
filter_fields = (("name", ""), ("fork", ""), ("ops_ip", ""), ("domain", "icontains"), ("tags", "")) filter_fields = (("name", ""), ("fork", ""), ("ops_ip", ""), ("domain", "icontains"), ("tags", ""))
class ProjectDetailViews(DetailViewSet): class ProjectDetailViews(DetailViewSet):
model = Project model = Project
fields = fields.ProjectFields fields = fields.ProjectFields
method_decorators = [session_or_token_required]
class ChannelViews(parsers.ChannelParse, ListCreateViewSet): class ChannelViews(parsers.ChannelParse, ListCreateViewSet):
model = Channel model = Channel
fields = fields.ChannelFields fields = fields.ChannelFields
uniq_fields = (("spid", "project_id"),) uniq_fields = (("spid", "project_id"),)
method_decorators = [session_or_token_required]
filter_fields = (("name", ""), ("spid", ""), ("tags", "")) filter_fields = (("name", ""), ("spid", ""), ("tags", ""))
def __init__(self): def __init__(self):
@ -38,6 +42,7 @@ class ChannelDetailViews(parsers.ChannelParse, DetailViewSet):
model = Channel model = Channel
fields = fields.ChannelFields fields = fields.ChannelFields
uniq_fields = (("spid", "project_id"),) uniq_fields = (("spid", "project_id"),)
method_decorators = [session_or_token_required]
def __init__(self): def __init__(self):
self.init_parse() self.init_parse()
@ -49,6 +54,7 @@ class ChannelDetailViews(parsers.ChannelParse, DetailViewSet):
class ServerViews(parsers.ServerParse, ListCreateViewSet): class ServerViews(parsers.ServerParse, ListCreateViewSet):
model = Server model = Server
fields = fields.ServerFields fields = fields.ServerFields
method_decorators = [session_or_token_required]
uniq_fields = (("num", "channel_id"),) uniq_fields = (("num", "channel_id"),)
relation_fields = (("host_id", assetModel.Host, ""), ("channel_id", Channel, True)) relation_fields = (("host_id", assetModel.Host, ""), ("channel_id", Channel, True))
filter_fields = (("num", ""), ("spid", ""), ("tags", "")) filter_fields = (("num", ""), ("spid", ""), ("tags", ""))
@ -63,6 +69,7 @@ class ServerDetailView(parsers.ServerParse, DetailViewSet):
model = Server model = Server
fields = fields.ServerFields fields = fields.ServerFields
uniq_fields = (("num", "channel_id"),) uniq_fields = (("num", "channel_id"),)
method_decorators = [session_or_token_required]
relation_fields = (("host_id", assetModel.Host, ""), ("channel_id", Channel, True)) relation_fields = (("host_id", assetModel.Host, ""), ("channel_id", Channel, True))
def __init__(self): def __init__(self):

7
src/database/redis.py Normal file
View File

@ -0,0 +1,7 @@
from flask_redis import FlaskRedis
redis_client = FlaskRedis()
def initialize_redis(app):
redis_client.init_app(app)