diff --git a/src/common/views.py b/src/common/views.py index 1a20ea0..560de7b 100644 --- a/src/common/views.py +++ b/src/common/views.py @@ -159,6 +159,9 @@ class ListMixin(ModelViewBase): filter_fields = [] def get_queryset(self): + self.queryset = self.model.objects + + def filter_queryset(self): """简单的过滤,默认操作符为等于""" query_params = {} # 搜索字段必须在模型对象中 @@ -180,7 +183,7 @@ class ListMixin(ModelViewBase): try: # 字段值异常会抛错 - queryset = self.model.objects(**query_params) + queryset = self.queryset.filter(**query_params) except: # 返回空 queryset queryset = self.model.objects.none() @@ -190,15 +193,18 @@ class ListMixin(ModelViewBase): def get(self): """获取列表数据""" # 过滤后的数据, + + self.get_queryset() + # TODO 匹配过滤参数 - queryset = self.get_queryset() + self.queryset = self.filter_queryset() if self.paging: - return self.paginate_queryset(queryset) + return self.paginate_queryset(self. queryset) # TODO 这里还未了解到比较妥的办法,似乎 `marshal` 返回的信息必须是 `dict`,返回铺平对象的 `[{obj}, {obj}]` 没找到方法,待研究 # 不分页,需要取出对象,铺平 - return [marshal(obj, self.fields) for obj in queryset] + return [marshal(obj, self.fields) for obj in self.queryset] class CreateMixin(ModelViewBase): diff --git a/src/controller/project/operation.py b/src/controller/project/operation.py index 19933cc..223f0df 100644 --- a/src/controller/project/operation.py +++ b/src/controller/project/operation.py @@ -76,7 +76,7 @@ class ServerSyncView(CreateMixin): if count != len(data): return make_response(400, 1001, "quantity mismatch") - project_id = str(self.project.id) + project_id = self.project.id hosts = {} channels = {} errors = [] @@ -100,12 +100,12 @@ class ServerSyncView(CreateMixin): if not created and proj not in hostObj.tags: hostObj.tags.append(proj) hostObj.save() - host = str(hostObj.id) + host = hostObj.id hosts[ip] = host if not channel: channelObj, _ = Channel.get_or_create( project_id=project_id, spid=spid, defaults=dict(project_id=project_id, spid=spid)) - channel = str(channelObj.id) + channel = channelObj.id channels[spid] = channel # 更新、创建的参数 diff --git a/src/controller/project/views.py b/src/controller/project/views.py index 0ddb372..ceaef69 100644 --- a/src/controller/project/views.py +++ b/src/controller/project/views.py @@ -1,5 +1,6 @@ """数据模型常规增删查改的接口""" +from flask import request from flask_restful import inputs from models.project import fields @@ -49,6 +50,18 @@ class ChannelViews(parsers.ChannelParse, ListCreateViewSet): self.request_parse.add_argument("project_id", type=str, location='json', required=True) super().__init__() + def get_queryset(self): + self.queryset = self.model.objects + project_full_name = request.args.get("project") + + # 若包含项目名参数 + if project_full_name: + try: + name, fork = project_full_name.split("/") + self.queryset = self.model.filter_by_project(name=name, fork=fork, queryset=self.queryset) + except: + self.queryset = self.model.objects.none() + class ChannelDetailViews(parsers.ChannelParse, DetailViewSet): model = Channel @@ -69,13 +82,37 @@ class ServerViews(parsers.ServerParse, ListCreateViewSet): method_decorators = [session_or_token_required] uniq_fields = (("num", "channel_id"),) relation_fields = (("host_id", assetModel.Host, ""), ("channel_id", Channel, True)) - filter_fields = (("num", ""), ("spid", ""), ("tags", "")) + filter_fields = (("num", ""), ("spid", ""), ("tags", ""), ("channel_id", ""), ("host_id", "")) def __init__(self): self.init_parse() self.request_parse.add_argument("num", type=int, location='json', required=True) self.request_parse.add_argument("channel_id", type=str, location='json', required=True) + def get_queryset(self): + """这些都是有关联的查询,模型内部字段的匹配字段放在 `filter_fields` 中""" + project = None + self.queryset = self.model.objects + project_full_name = request.args.get("project") + + # 若包含项目名参数 + if project_full_name: + try: + name, fork = project_full_name.split("/") + project = Project.objects(name=name, fork=fork).first() + except: + pass + self.queryset = self.model.filter_by_project(project) + + # 若包含 spid 参数 + spid = request.args.get("spid") + if spid: + self.queryset = self.model.filter_by_spid(spid, queryset=self.queryset) + + ip = request.args.get("ip") + if ip: + self.queryset = self.model.filter_by_ip(ip, queryset=self.queryset) + class ServerDetailView(parsers.ServerParse, DetailViewSet): model = Server diff --git a/src/models/project/models.py b/src/models/project/models.py index edb1659..8a1cb40 100644 --- a/src/models/project/models.py +++ b/src/models/project/models.py @@ -1,4 +1,5 @@ import mongoengine as mongo +from mongoengine import Q from settings import common from common.document import DocumentBase @@ -45,7 +46,8 @@ class Version(mongo.EmbeddedDocument): class Channel(DocumentBase): """渠道""" # spid项目内唯一 - project_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string) + # project_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string) + project_id = mongo.ObjectIdField(required=True) name = mongo.StringField(max_length=32, default="") spid = mongo.StringField(max_length=3, min_length=3, required=True, unique_with="project_id", validation=isalnum) version = mongo.EmbeddedDocumentField(Version) # 缺失时获取对象的此字段为 None @@ -79,6 +81,16 @@ class Channel(DocumentBase): if isinstance(val, Project): self.project_id = str(val.id) + @classmethod + def filter_by_project(cls, name, fork, queryset=None): + """过滤项目的所有渠道""" + if not queryset: + queryset = cls.objects + project = Project.objects(name=name, fork=fork).first() + if project: + return queryset.filter(project_id=project.id) + return cls.objects().none() + class Server(DocumentBase): """服务""" @@ -94,11 +106,13 @@ class Server(DocumentBase): num = mongo.IntField(required=True, unique_with="channel_id") # 关联channel表,保存是一个channel._id的hex字符串 - channel_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string) + # channel_id = mongo.StringField(max_length=128, required=True, validation=is_hex_string) + channel_id = mongo.ObjectIdField(required=True) status = mongo.StringField(max_length=12, choices=STATUS.keys(), required=False, default="running") # 机器字段,TODO 先允许为空 - host_id = mongo.StringField(max_length=128, null=True, validation=is_hex_string) + # host_id = mongo.StringField(max_length=128, null=True, validation=is_hex_string) + host_id = mongo.ObjectIdField(required=True) domain = mongo.StringField(max_length=128, required=False, default="") port = mongo.IntField() version = mongo.EmbeddedDocumentField(Version) @@ -186,3 +200,41 @@ class Server(DocumentBase): if self.channel: return self.channel.is_cross return False + + @classmethod + def filter_by_spid(cls, spid: str, project: Project = None, queryset=None): + """过滤渠道下的区服""" + if not queryset: + queryset = cls.objects + if project: + channel = Channel.objects(project_id=project.id, spid=spid).first() + else: + channel = Channel.objects(spid=spid).first() + if channel: + return queryset.filter(channel_id=channel.id) + return cls.objects.none() + + @classmethod + def filter_by_project(cls, project: Project, queryset=None): + """过滤某项目的区服""" + if not queryset: + queryset = cls.objects + if project: + channels = Channel.objects(project_id=project.id).values_list("id") + if channels: + return queryset.filter(channel_id__in=channels) + return cls.objects.none() + + @classmethod + def filter_by_ip(cls, ip, only_public=False, queryset=None): + """过滤某ip下的区服""" + if not queryset: + queryset = cls.objects + if only_public: + hosts = Host.objects(public_ip=ip).values_list("id") + else: + hosts = Host.objects(Q(public_ip=ip) | Q(private_ip=ip)).values_list("id") + + if hosts: + return queryset.filter(host_id__in=hosts) + return cls.objects.none()