区服信息搜索

This commit is contained in:
chenzuoqing 2021-12-17 11:45:47 +08:00
parent be964c54a4
commit 12cf916a80
4 changed files with 106 additions and 11 deletions

View File

@ -159,6 +159,9 @@ class ListMixin(ModelViewBase):
filter_fields = []
def get_queryset(self):
self.queryset = self.model.objects
def filter_queryset(self):
"""简单的过滤,默认操作符为等于"""
query_params = {}
# 搜索字段必须在模型对象中
@ -180,7 +183,7 @@ class ListMixin(ModelViewBase):
try:
# 字段值异常会抛错
queryset = self.model.objects(**query_params)
queryset = self.queryset.filter(**query_params)
except:
# 返回空 queryset
queryset = self.model.objects.none()
@ -190,15 +193,18 @@ class ListMixin(ModelViewBase):
def get(self):
"""获取列表数据"""
# 过滤后的数据,
self.get_queryset()
# TODO 匹配过滤参数
queryset = self.get_queryset()
self.queryset = self.filter_queryset()
if self.paging:
return self.paginate_queryset(queryset)
return self.paginate_queryset(self. queryset)
# TODO 这里还未了解到比较妥的办法,似乎 `marshal` 返回的信息必须是 `dict`,返回铺平对象的 `[{obj}, {obj}]` 没找到方法,待研究
# 不分页,需要取出对象,铺平
return [marshal(obj, self.fields) for obj in queryset]
return [marshal(obj, self.fields) for obj in self.queryset]
class CreateMixin(ModelViewBase):

View File

@ -76,7 +76,7 @@ class ServerSyncView(CreateMixin):
if count != len(data):
return make_response(400, 1001, "quantity mismatch")
project_id = str(self.project.id)
project_id = self.project.id
hosts = {}
channels = {}
errors = []
@ -100,12 +100,12 @@ class ServerSyncView(CreateMixin):
if not created and proj not in hostObj.tags:
hostObj.tags.append(proj)
hostObj.save()
host = str(hostObj.id)
host = hostObj.id
hosts[ip] = host
if not channel:
channelObj, _ = Channel.get_or_create(
project_id=project_id, spid=spid, defaults=dict(project_id=project_id, spid=spid))
channel = str(channelObj.id)
channel = channelObj.id
channels[spid] = channel
# 更新、创建的参数

View File

@ -1,5 +1,6 @@
"""数据模型常规增删查改的接口"""
from flask import request
from flask_restful import inputs
from models.project import fields
@ -49,6 +50,18 @@ class ChannelViews(parsers.ChannelParse, ListCreateViewSet):
self.request_parse.add_argument("project_id", type=str, location='json', required=True)
super().__init__()
def get_queryset(self):
self.queryset = self.model.objects
project_full_name = request.args.get("project")
# 若包含项目名参数
if project_full_name:
try:
name, fork = project_full_name.split("/")
self.queryset = self.model.filter_by_project(name=name, fork=fork, queryset=self.queryset)
except:
self.queryset = self.model.objects.none()
class ChannelDetailViews(parsers.ChannelParse, DetailViewSet):
model = Channel
@ -69,13 +82,37 @@ class ServerViews(parsers.ServerParse, ListCreateViewSet):
method_decorators = [session_or_token_required]
uniq_fields = (("num", "channel_id"),)
relation_fields = (("host_id", assetModel.Host, ""), ("channel_id", Channel, True))
filter_fields = (("num", ""), ("spid", ""), ("tags", ""))
filter_fields = (("num", ""), ("spid", ""), ("tags", ""), ("channel_id", ""), ("host_id", ""))
def __init__(self):
self.init_parse()
self.request_parse.add_argument("num", type=int, location='json', required=True)
self.request_parse.add_argument("channel_id", type=str, location='json', required=True)
def get_queryset(self):
"""这些都是有关联的查询,模型内部字段的匹配字段放在 `filter_fields` 中"""
project = None
self.queryset = self.model.objects
project_full_name = request.args.get("project")
# 若包含项目名参数
if project_full_name:
try:
name, fork = project_full_name.split("/")
project = Project.objects(name=name, fork=fork).first()
except:
pass
self.queryset = self.model.filter_by_project(project)
# 若包含 spid 参数
spid = request.args.get("spid")
if spid:
self.queryset = self.model.filter_by_spid(spid, queryset=self.queryset)
ip = request.args.get("ip")
if ip:
self.queryset = self.model.filter_by_ip(ip, queryset=self.queryset)
class ServerDetailView(parsers.ServerParse, DetailViewSet):
model = Server

View File

@ -1,4 +1,5 @@
import mongoengine as mongo
from mongoengine import Q
from settings import common
from common.document import DocumentBase
@ -45,7 +46,8 @@ class Version(mongo.EmbeddedDocument):
class Channel(DocumentBase):
"""渠道"""
# spid项目内唯一
project_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string)
# project_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string)
project_id = mongo.ObjectIdField(required=True)
name = mongo.StringField(max_length=32, default="")
spid = mongo.StringField(max_length=3, min_length=3, required=True, unique_with="project_id", validation=isalnum)
version = mongo.EmbeddedDocumentField(Version) # 缺失时获取对象的此字段为 None
@ -79,6 +81,16 @@ class Channel(DocumentBase):
if isinstance(val, Project):
self.project_id = str(val.id)
@classmethod
def filter_by_project(cls, name, fork, queryset=None):
"""过滤项目的所有渠道"""
if not queryset:
queryset = cls.objects
project = Project.objects(name=name, fork=fork).first()
if project:
return queryset.filter(project_id=project.id)
return cls.objects().none()
class Server(DocumentBase):
"""服务"""
@ -94,11 +106,13 @@ class Server(DocumentBase):
num = mongo.IntField(required=True, unique_with="channel_id")
# 关联channel表保存是一个channel._id的hex字符串
channel_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string)
# channel_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string)
channel_id = mongo.ObjectIdField(required=True)
status = mongo.StringField(max_length=12, choices=STATUS.keys(), required=False, default="running")
# 机器字段TODO 先允许为空
host_id = mongo.StringField(max_length=128, null=True, validation=is_hex_string)
# host_id = mongo.StringField(max_length=128, null=True, validation=is_hex_string)
host_id = mongo.ObjectIdField(required=True)
domain = mongo.StringField(max_length=128, required=False, default="")
port = mongo.IntField()
version = mongo.EmbeddedDocumentField(Version)
@ -186,3 +200,41 @@ class Server(DocumentBase):
if self.channel:
return self.channel.is_cross
return False
@classmethod
def filter_by_spid(cls, spid: str, project: Project = None, queryset=None):
"""过滤渠道下的区服"""
if not queryset:
queryset = cls.objects
if project:
channel = Channel.objects(project_id=project.id, spid=spid).first()
else:
channel = Channel.objects(spid=spid).first()
if channel:
return queryset.filter(channel_id=channel.id)
return cls.objects.none()
@classmethod
def filter_by_project(cls, project: Project, queryset=None):
"""过滤某项目的区服"""
if not queryset:
queryset = cls.objects
if project:
channels = Channel.objects(project_id=project.id).values_list("id")
if channels:
return queryset.filter(channel_id__in=channels)
return cls.objects.none()
@classmethod
def filter_by_ip(cls, ip, only_public=False, queryset=None):
"""过滤某ip下的区服"""
if not queryset:
queryset = cls.objects
if only_public:
hosts = Host.objects(public_ip=ip).values_list("id")
else:
hosts = Host.objects(Q(public_ip=ip) | Q(private_ip=ip)).values_list("id")
if hosts:
return queryset.filter(host_id__in=hosts)
return cls.objects.none()