区服信息搜索
This commit is contained in:
parent
be964c54a4
commit
12cf916a80
|
@ -159,6 +159,9 @@ class ListMixin(ModelViewBase):
|
||||||
filter_fields = []
|
filter_fields = []
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
|
self.queryset = self.model.objects
|
||||||
|
|
||||||
|
def filter_queryset(self):
|
||||||
"""简单的过滤,默认操作符为等于"""
|
"""简单的过滤,默认操作符为等于"""
|
||||||
query_params = {}
|
query_params = {}
|
||||||
# 搜索字段必须在模型对象中
|
# 搜索字段必须在模型对象中
|
||||||
|
@ -180,7 +183,7 @@ class ListMixin(ModelViewBase):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 字段值异常会抛错
|
# 字段值异常会抛错
|
||||||
queryset = self.model.objects(**query_params)
|
queryset = self.queryset.filter(**query_params)
|
||||||
except:
|
except:
|
||||||
# 返回空 queryset
|
# 返回空 queryset
|
||||||
queryset = self.model.objects.none()
|
queryset = self.model.objects.none()
|
||||||
|
@ -190,15 +193,18 @@ class ListMixin(ModelViewBase):
|
||||||
def get(self):
|
def get(self):
|
||||||
"""获取列表数据"""
|
"""获取列表数据"""
|
||||||
# 过滤后的数据,
|
# 过滤后的数据,
|
||||||
|
|
||||||
|
self.get_queryset()
|
||||||
|
|
||||||
# TODO 匹配过滤参数
|
# TODO 匹配过滤参数
|
||||||
queryset = self.get_queryset()
|
self.queryset = self.filter_queryset()
|
||||||
|
|
||||||
if self.paging:
|
if self.paging:
|
||||||
return self.paginate_queryset(queryset)
|
return self.paginate_queryset(self. queryset)
|
||||||
|
|
||||||
# TODO 这里还未了解到比较妥的办法,似乎 `marshal` 返回的信息必须是 `dict`,返回铺平对象的 `[{obj}, {obj}]` 没找到方法,待研究
|
# TODO 这里还未了解到比较妥的办法,似乎 `marshal` 返回的信息必须是 `dict`,返回铺平对象的 `[{obj}, {obj}]` 没找到方法,待研究
|
||||||
# 不分页,需要取出对象,铺平
|
# 不分页,需要取出对象,铺平
|
||||||
return [marshal(obj, self.fields) for obj in queryset]
|
return [marshal(obj, self.fields) for obj in self.queryset]
|
||||||
|
|
||||||
|
|
||||||
class CreateMixin(ModelViewBase):
|
class CreateMixin(ModelViewBase):
|
||||||
|
|
|
@ -76,7 +76,7 @@ class ServerSyncView(CreateMixin):
|
||||||
|
|
||||||
if count != len(data):
|
if count != len(data):
|
||||||
return make_response(400, 1001, "quantity mismatch")
|
return make_response(400, 1001, "quantity mismatch")
|
||||||
project_id = str(self.project.id)
|
project_id = self.project.id
|
||||||
hosts = {}
|
hosts = {}
|
||||||
channels = {}
|
channels = {}
|
||||||
errors = []
|
errors = []
|
||||||
|
@ -100,12 +100,12 @@ class ServerSyncView(CreateMixin):
|
||||||
if not created and proj not in hostObj.tags:
|
if not created and proj not in hostObj.tags:
|
||||||
hostObj.tags.append(proj)
|
hostObj.tags.append(proj)
|
||||||
hostObj.save()
|
hostObj.save()
|
||||||
host = str(hostObj.id)
|
host = hostObj.id
|
||||||
hosts[ip] = host
|
hosts[ip] = host
|
||||||
if not channel:
|
if not channel:
|
||||||
channelObj, _ = Channel.get_or_create(
|
channelObj, _ = Channel.get_or_create(
|
||||||
project_id=project_id, spid=spid, defaults=dict(project_id=project_id, spid=spid))
|
project_id=project_id, spid=spid, defaults=dict(project_id=project_id, spid=spid))
|
||||||
channel = str(channelObj.id)
|
channel = channelObj.id
|
||||||
channels[spid] = channel
|
channels[spid] = channel
|
||||||
|
|
||||||
# 更新、创建的参数
|
# 更新、创建的参数
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
"""数据模型常规增删查改的接口"""
|
"""数据模型常规增删查改的接口"""
|
||||||
|
|
||||||
|
from flask import request
|
||||||
from flask_restful import inputs
|
from flask_restful import inputs
|
||||||
|
|
||||||
from models.project import fields
|
from models.project import fields
|
||||||
|
@ -49,6 +50,18 @@ class ChannelViews(parsers.ChannelParse, ListCreateViewSet):
|
||||||
self.request_parse.add_argument("project_id", type=str, location='json', required=True)
|
self.request_parse.add_argument("project_id", type=str, location='json', required=True)
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
def get_queryset(self):
|
||||||
|
self.queryset = self.model.objects
|
||||||
|
project_full_name = request.args.get("project")
|
||||||
|
|
||||||
|
# 若包含项目名参数
|
||||||
|
if project_full_name:
|
||||||
|
try:
|
||||||
|
name, fork = project_full_name.split("/")
|
||||||
|
self.queryset = self.model.filter_by_project(name=name, fork=fork, queryset=self.queryset)
|
||||||
|
except:
|
||||||
|
self.queryset = self.model.objects.none()
|
||||||
|
|
||||||
|
|
||||||
class ChannelDetailViews(parsers.ChannelParse, DetailViewSet):
|
class ChannelDetailViews(parsers.ChannelParse, DetailViewSet):
|
||||||
model = Channel
|
model = Channel
|
||||||
|
@ -69,13 +82,37 @@ class ServerViews(parsers.ServerParse, ListCreateViewSet):
|
||||||
method_decorators = [session_or_token_required]
|
method_decorators = [session_or_token_required]
|
||||||
uniq_fields = (("num", "channel_id"),)
|
uniq_fields = (("num", "channel_id"),)
|
||||||
relation_fields = (("host_id", assetModel.Host, ""), ("channel_id", Channel, True))
|
relation_fields = (("host_id", assetModel.Host, ""), ("channel_id", Channel, True))
|
||||||
filter_fields = (("num", ""), ("spid", ""), ("tags", ""))
|
filter_fields = (("num", ""), ("spid", ""), ("tags", ""), ("channel_id", ""), ("host_id", ""))
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.init_parse()
|
self.init_parse()
|
||||||
self.request_parse.add_argument("num", type=int, location='json', required=True)
|
self.request_parse.add_argument("num", type=int, location='json', required=True)
|
||||||
self.request_parse.add_argument("channel_id", type=str, location='json', required=True)
|
self.request_parse.add_argument("channel_id", type=str, location='json', required=True)
|
||||||
|
|
||||||
|
def get_queryset(self):
|
||||||
|
"""这些都是有关联的查询,模型内部字段的匹配字段放在 `filter_fields` 中"""
|
||||||
|
project = None
|
||||||
|
self.queryset = self.model.objects
|
||||||
|
project_full_name = request.args.get("project")
|
||||||
|
|
||||||
|
# 若包含项目名参数
|
||||||
|
if project_full_name:
|
||||||
|
try:
|
||||||
|
name, fork = project_full_name.split("/")
|
||||||
|
project = Project.objects(name=name, fork=fork).first()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
self.queryset = self.model.filter_by_project(project)
|
||||||
|
|
||||||
|
# 若包含 spid 参数
|
||||||
|
spid = request.args.get("spid")
|
||||||
|
if spid:
|
||||||
|
self.queryset = self.model.filter_by_spid(spid, queryset=self.queryset)
|
||||||
|
|
||||||
|
ip = request.args.get("ip")
|
||||||
|
if ip:
|
||||||
|
self.queryset = self.model.filter_by_ip(ip, queryset=self.queryset)
|
||||||
|
|
||||||
|
|
||||||
class ServerDetailView(parsers.ServerParse, DetailViewSet):
|
class ServerDetailView(parsers.ServerParse, DetailViewSet):
|
||||||
model = Server
|
model = Server
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import mongoengine as mongo
|
import mongoengine as mongo
|
||||||
|
from mongoengine import Q
|
||||||
|
|
||||||
from settings import common
|
from settings import common
|
||||||
from common.document import DocumentBase
|
from common.document import DocumentBase
|
||||||
|
@ -45,7 +46,8 @@ class Version(mongo.EmbeddedDocument):
|
||||||
class Channel(DocumentBase):
|
class Channel(DocumentBase):
|
||||||
"""渠道"""
|
"""渠道"""
|
||||||
# spid项目内唯一
|
# spid项目内唯一
|
||||||
project_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string)
|
# project_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string)
|
||||||
|
project_id = mongo.ObjectIdField(required=True)
|
||||||
name = mongo.StringField(max_length=32, default="")
|
name = mongo.StringField(max_length=32, default="")
|
||||||
spid = mongo.StringField(max_length=3, min_length=3, required=True, unique_with="project_id", validation=isalnum)
|
spid = mongo.StringField(max_length=3, min_length=3, required=True, unique_with="project_id", validation=isalnum)
|
||||||
version = mongo.EmbeddedDocumentField(Version) # 缺失时获取对象的此字段为 None
|
version = mongo.EmbeddedDocumentField(Version) # 缺失时获取对象的此字段为 None
|
||||||
|
@ -79,6 +81,16 @@ class Channel(DocumentBase):
|
||||||
if isinstance(val, Project):
|
if isinstance(val, Project):
|
||||||
self.project_id = str(val.id)
|
self.project_id = str(val.id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def filter_by_project(cls, name, fork, queryset=None):
|
||||||
|
"""过滤项目的所有渠道"""
|
||||||
|
if not queryset:
|
||||||
|
queryset = cls.objects
|
||||||
|
project = Project.objects(name=name, fork=fork).first()
|
||||||
|
if project:
|
||||||
|
return queryset.filter(project_id=project.id)
|
||||||
|
return cls.objects().none()
|
||||||
|
|
||||||
|
|
||||||
class Server(DocumentBase):
|
class Server(DocumentBase):
|
||||||
"""服务"""
|
"""服务"""
|
||||||
|
@ -94,11 +106,13 @@ class Server(DocumentBase):
|
||||||
|
|
||||||
num = mongo.IntField(required=True, unique_with="channel_id")
|
num = mongo.IntField(required=True, unique_with="channel_id")
|
||||||
# 关联channel表,保存是一个channel._id的hex字符串
|
# 关联channel表,保存是一个channel._id的hex字符串
|
||||||
channel_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string)
|
# channel_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string)
|
||||||
|
channel_id = mongo.ObjectIdField(required=True)
|
||||||
status = mongo.StringField(max_length=12, choices=STATUS.keys(), required=False, default="running")
|
status = mongo.StringField(max_length=12, choices=STATUS.keys(), required=False, default="running")
|
||||||
|
|
||||||
# 机器字段,TODO 先允许为空
|
# 机器字段,TODO 先允许为空
|
||||||
host_id = mongo.StringField(max_length=128, null=True, validation=is_hex_string)
|
# host_id = mongo.StringField(max_length=128, null=True, validation=is_hex_string)
|
||||||
|
host_id = mongo.ObjectIdField(required=True)
|
||||||
domain = mongo.StringField(max_length=128, required=False, default="")
|
domain = mongo.StringField(max_length=128, required=False, default="")
|
||||||
port = mongo.IntField()
|
port = mongo.IntField()
|
||||||
version = mongo.EmbeddedDocumentField(Version)
|
version = mongo.EmbeddedDocumentField(Version)
|
||||||
|
@ -186,3 +200,41 @@ class Server(DocumentBase):
|
||||||
if self.channel:
|
if self.channel:
|
||||||
return self.channel.is_cross
|
return self.channel.is_cross
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def filter_by_spid(cls, spid: str, project: Project = None, queryset=None):
|
||||||
|
"""过滤渠道下的区服"""
|
||||||
|
if not queryset:
|
||||||
|
queryset = cls.objects
|
||||||
|
if project:
|
||||||
|
channel = Channel.objects(project_id=project.id, spid=spid).first()
|
||||||
|
else:
|
||||||
|
channel = Channel.objects(spid=spid).first()
|
||||||
|
if channel:
|
||||||
|
return queryset.filter(channel_id=channel.id)
|
||||||
|
return cls.objects.none()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def filter_by_project(cls, project: Project, queryset=None):
|
||||||
|
"""过滤某项目的区服"""
|
||||||
|
if not queryset:
|
||||||
|
queryset = cls.objects
|
||||||
|
if project:
|
||||||
|
channels = Channel.objects(project_id=project.id).values_list("id")
|
||||||
|
if channels:
|
||||||
|
return queryset.filter(channel_id__in=channels)
|
||||||
|
return cls.objects.none()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def filter_by_ip(cls, ip, only_public=False, queryset=None):
|
||||||
|
"""过滤某ip下的区服"""
|
||||||
|
if not queryset:
|
||||||
|
queryset = cls.objects
|
||||||
|
if only_public:
|
||||||
|
hosts = Host.objects(public_ip=ip).values_list("id")
|
||||||
|
else:
|
||||||
|
hosts = Host.objects(Q(public_ip=ip) | Q(private_ip=ip)).values_list("id")
|
||||||
|
|
||||||
|
if hosts:
|
||||||
|
return queryset.filter(host_id__in=hosts)
|
||||||
|
return cls.objects.none()
|
||||||
|
|
Loading…
Reference in New Issue