区服信息搜索
This commit is contained in:
parent
be964c54a4
commit
12cf916a80
|
@ -159,6 +159,9 @@ class ListMixin(ModelViewBase):
|
|||
filter_fields = []
|
||||
|
||||
def get_queryset(self):
|
||||
self.queryset = self.model.objects
|
||||
|
||||
def filter_queryset(self):
|
||||
"""简单的过滤,默认操作符为等于"""
|
||||
query_params = {}
|
||||
# 搜索字段必须在模型对象中
|
||||
|
@ -180,7 +183,7 @@ class ListMixin(ModelViewBase):
|
|||
|
||||
try:
|
||||
# 字段值异常会抛错
|
||||
queryset = self.model.objects(**query_params)
|
||||
queryset = self.queryset.filter(**query_params)
|
||||
except:
|
||||
# 返回空 queryset
|
||||
queryset = self.model.objects.none()
|
||||
|
@ -190,15 +193,18 @@ class ListMixin(ModelViewBase):
|
|||
def get(self):
|
||||
"""获取列表数据"""
|
||||
# 过滤后的数据,
|
||||
|
||||
self.get_queryset()
|
||||
|
||||
# TODO 匹配过滤参数
|
||||
queryset = self.get_queryset()
|
||||
self.queryset = self.filter_queryset()
|
||||
|
||||
if self.paging:
|
||||
return self.paginate_queryset(queryset)
|
||||
return self.paginate_queryset(self. queryset)
|
||||
|
||||
# TODO 这里还未了解到比较妥的办法,似乎 `marshal` 返回的信息必须是 `dict`,返回铺平对象的 `[{obj}, {obj}]` 没找到方法,待研究
|
||||
# 不分页,需要取出对象,铺平
|
||||
return [marshal(obj, self.fields) for obj in queryset]
|
||||
return [marshal(obj, self.fields) for obj in self.queryset]
|
||||
|
||||
|
||||
class CreateMixin(ModelViewBase):
|
||||
|
|
|
@ -76,7 +76,7 @@ class ServerSyncView(CreateMixin):
|
|||
|
||||
if count != len(data):
|
||||
return make_response(400, 1001, "quantity mismatch")
|
||||
project_id = str(self.project.id)
|
||||
project_id = self.project.id
|
||||
hosts = {}
|
||||
channels = {}
|
||||
errors = []
|
||||
|
@ -100,12 +100,12 @@ class ServerSyncView(CreateMixin):
|
|||
if not created and proj not in hostObj.tags:
|
||||
hostObj.tags.append(proj)
|
||||
hostObj.save()
|
||||
host = str(hostObj.id)
|
||||
host = hostObj.id
|
||||
hosts[ip] = host
|
||||
if not channel:
|
||||
channelObj, _ = Channel.get_or_create(
|
||||
project_id=project_id, spid=spid, defaults=dict(project_id=project_id, spid=spid))
|
||||
channel = str(channelObj.id)
|
||||
channel = channelObj.id
|
||||
channels[spid] = channel
|
||||
|
||||
# 更新、创建的参数
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""数据模型常规增删查改的接口"""
|
||||
|
||||
from flask import request
|
||||
from flask_restful import inputs
|
||||
|
||||
from models.project import fields
|
||||
|
@ -49,6 +50,18 @@ class ChannelViews(parsers.ChannelParse, ListCreateViewSet):
|
|||
self.request_parse.add_argument("project_id", type=str, location='json', required=True)
|
||||
super().__init__()
|
||||
|
||||
def get_queryset(self):
|
||||
self.queryset = self.model.objects
|
||||
project_full_name = request.args.get("project")
|
||||
|
||||
# 若包含项目名参数
|
||||
if project_full_name:
|
||||
try:
|
||||
name, fork = project_full_name.split("/")
|
||||
self.queryset = self.model.filter_by_project(name=name, fork=fork, queryset=self.queryset)
|
||||
except:
|
||||
self.queryset = self.model.objects.none()
|
||||
|
||||
|
||||
class ChannelDetailViews(parsers.ChannelParse, DetailViewSet):
|
||||
model = Channel
|
||||
|
@ -69,13 +82,37 @@ class ServerViews(parsers.ServerParse, ListCreateViewSet):
|
|||
method_decorators = [session_or_token_required]
|
||||
uniq_fields = (("num", "channel_id"),)
|
||||
relation_fields = (("host_id", assetModel.Host, ""), ("channel_id", Channel, True))
|
||||
filter_fields = (("num", ""), ("spid", ""), ("tags", ""))
|
||||
filter_fields = (("num", ""), ("spid", ""), ("tags", ""), ("channel_id", ""), ("host_id", ""))
|
||||
|
||||
def __init__(self):
|
||||
self.init_parse()
|
||||
self.request_parse.add_argument("num", type=int, location='json', required=True)
|
||||
self.request_parse.add_argument("channel_id", type=str, location='json', required=True)
|
||||
|
||||
def get_queryset(self):
|
||||
"""这些都是有关联的查询,模型内部字段的匹配字段放在 `filter_fields` 中"""
|
||||
project = None
|
||||
self.queryset = self.model.objects
|
||||
project_full_name = request.args.get("project")
|
||||
|
||||
# 若包含项目名参数
|
||||
if project_full_name:
|
||||
try:
|
||||
name, fork = project_full_name.split("/")
|
||||
project = Project.objects(name=name, fork=fork).first()
|
||||
except:
|
||||
pass
|
||||
self.queryset = self.model.filter_by_project(project)
|
||||
|
||||
# 若包含 spid 参数
|
||||
spid = request.args.get("spid")
|
||||
if spid:
|
||||
self.queryset = self.model.filter_by_spid(spid, queryset=self.queryset)
|
||||
|
||||
ip = request.args.get("ip")
|
||||
if ip:
|
||||
self.queryset = self.model.filter_by_ip(ip, queryset=self.queryset)
|
||||
|
||||
|
||||
class ServerDetailView(parsers.ServerParse, DetailViewSet):
|
||||
model = Server
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import mongoengine as mongo
|
||||
from mongoengine import Q
|
||||
|
||||
from settings import common
|
||||
from common.document import DocumentBase
|
||||
|
@ -45,7 +46,8 @@ class Version(mongo.EmbeddedDocument):
|
|||
class Channel(DocumentBase):
|
||||
"""渠道"""
|
||||
# spid项目内唯一
|
||||
project_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string)
|
||||
# project_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string)
|
||||
project_id = mongo.ObjectIdField(required=True)
|
||||
name = mongo.StringField(max_length=32, default="")
|
||||
spid = mongo.StringField(max_length=3, min_length=3, required=True, unique_with="project_id", validation=isalnum)
|
||||
version = mongo.EmbeddedDocumentField(Version) # 缺失时获取对象的此字段为 None
|
||||
|
@ -79,6 +81,16 @@ class Channel(DocumentBase):
|
|||
if isinstance(val, Project):
|
||||
self.project_id = str(val.id)
|
||||
|
||||
@classmethod
|
||||
def filter_by_project(cls, name, fork, queryset=None):
|
||||
"""过滤项目的所有渠道"""
|
||||
if not queryset:
|
||||
queryset = cls.objects
|
||||
project = Project.objects(name=name, fork=fork).first()
|
||||
if project:
|
||||
return queryset.filter(project_id=project.id)
|
||||
return cls.objects().none()
|
||||
|
||||
|
||||
class Server(DocumentBase):
|
||||
"""服务"""
|
||||
|
@ -94,11 +106,13 @@ class Server(DocumentBase):
|
|||
|
||||
num = mongo.IntField(required=True, unique_with="channel_id")
|
||||
# 关联channel表,保存是一个channel._id的hex字符串
|
||||
channel_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string)
|
||||
# channel_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string)
|
||||
channel_id = mongo.ObjectIdField(required=True)
|
||||
status = mongo.StringField(max_length=12, choices=STATUS.keys(), required=False, default="running")
|
||||
|
||||
# 机器字段,TODO 先允许为空
|
||||
host_id = mongo.StringField(max_length=128, null=True, validation=is_hex_string)
|
||||
# host_id = mongo.StringField(max_length=128, null=True, validation=is_hex_string)
|
||||
host_id = mongo.ObjectIdField(required=True)
|
||||
domain = mongo.StringField(max_length=128, required=False, default="")
|
||||
port = mongo.IntField()
|
||||
version = mongo.EmbeddedDocumentField(Version)
|
||||
|
@ -186,3 +200,41 @@ class Server(DocumentBase):
|
|||
if self.channel:
|
||||
return self.channel.is_cross
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def filter_by_spid(cls, spid: str, project: Project = None, queryset=None):
|
||||
"""过滤渠道下的区服"""
|
||||
if not queryset:
|
||||
queryset = cls.objects
|
||||
if project:
|
||||
channel = Channel.objects(project_id=project.id, spid=spid).first()
|
||||
else:
|
||||
channel = Channel.objects(spid=spid).first()
|
||||
if channel:
|
||||
return queryset.filter(channel_id=channel.id)
|
||||
return cls.objects.none()
|
||||
|
||||
@classmethod
|
||||
def filter_by_project(cls, project: Project, queryset=None):
|
||||
"""过滤某项目的区服"""
|
||||
if not queryset:
|
||||
queryset = cls.objects
|
||||
if project:
|
||||
channels = Channel.objects(project_id=project.id).values_list("id")
|
||||
if channels:
|
||||
return queryset.filter(channel_id__in=channels)
|
||||
return cls.objects.none()
|
||||
|
||||
@classmethod
|
||||
def filter_by_ip(cls, ip, only_public=False, queryset=None):
|
||||
"""过滤某ip下的区服"""
|
||||
if not queryset:
|
||||
queryset = cls.objects
|
||||
if only_public:
|
||||
hosts = Host.objects(public_ip=ip).values_list("id")
|
||||
else:
|
||||
hosts = Host.objects(Q(public_ip=ip) | Q(private_ip=ip)).values_list("id")
|
||||
|
||||
if hosts:
|
||||
return queryset.filter(host_id__in=hosts)
|
||||
return cls.objects.none()
|
||||
|
|
Loading…
Reference in New Issue