diff --git a/src/controller/project/parsers.py b/src/controller/project/parsers.py index 50b040d..3f10b8f 100644 --- a/src/controller/project/parsers.py +++ b/src/controller/project/parsers.py @@ -35,29 +35,23 @@ class ChannelParse: def init_parse(self): self.request_parse = reqparse.RequestParser() - self.request_parse.add_argument("name", type=str, location='json', default="") + self.request_parse.add_argument("name", type=str, location='json') self.request_parse.add_argument("version", type=dict, location='json') - self.request_parse.add_argument("repository", type=str, location='json', default="") - self.request_parse.add_argument("branch", type=dict, location='json', default="") + self.request_parse.add_argument("repository", type=str, location='json') + self.request_parse.add_argument("branch", type=str, location='json') self.request_parse.add_argument("data", type=dict, location='json') self.request_parse.add_argument("tags", type=list, location='json') self.request_parse.add_argument("labels", type=dict, location='json') - def validate_fields(self, args: dict, create=True) -> dict: - if create: - self.validate_relation_pk(args, projectModel.Project, key="project_id") - spid = str(args.get("spid")) - if not spid.isalnum() or len(spid) != 3: - abort_response(400, 1001, msg="spid 不合法") - else: - if "project_id" in args: - self.validate_relation_pk(args, projectModel.Project, key="project_id") - if "spid" in args: - spid = str(args.get("spid")) - if not spid.isalnum() or len(spid) != 3: - abort_response(400, 1001, msg="spid 不合法") - + def validate_fields(self, args: dict, **kwargs) -> dict: + """处理内嵌的version字段""" + version = args.get("version", {}) + if version and isinstance(version, dict): + defaults = {} + for field in projectModel.Version._fields: + defaults[field] = version.get(field, "") + args["version"] = projectModel.Version(**defaults) return args @@ -83,17 +77,14 @@ class ServerParse: self.request_parse.add_argument("weight", type=int, location='json') self.request_parse.add_argument("slot", type=int, location='json') - def validate_fields(self, args: dict, create=True) -> dict: - # if not create: - # if "channel_id" in args: - # self.validate_relation_pk(args, projectModel.Channel, key="channel_id") - # if "host_id" in args: - # self.validate_relation_pk(args, assetModel.Host, key="channel_id") - # else: - # self.validate_relation_pk(args, projectModel.Channel, key="channel_id") - version = args.get("version") - if version: - args["version"] = projectModel.Version(**version) + def validate_fields(self, args: dict, **kwargs) -> dict: + version = args.get("version", {}) + if version and isinstance(version, dict): + defaults = {} + # 以模型的字段为基准,从version参数中填充相同key的值 + for field in projectModel.Version._fields: + defaults[field] = version.get(field, "") + args["version"] = projectModel.Version(**defaults) game_db_id = args.get("game_db_id") if game_db_id: try: diff --git a/src/controller/project/views.py b/src/controller/project/views.py index ceaef69..f79690e 100644 --- a/src/controller/project/views.py +++ b/src/controller/project/views.py @@ -41,6 +41,7 @@ class ChannelViews(parsers.ChannelParse, ListCreateViewSet): model = Channel fields = fields.ChannelFields uniq_fields = (("spid", "project_id"),) + relation_fields = (("project_id", Project, True),) method_decorators = [session_or_token_required] filter_fields = (("name", ""), ("spid", ""), ("tags", "")) @@ -67,6 +68,7 @@ class ChannelDetailViews(parsers.ChannelParse, DetailViewSet): model = Channel fields = fields.ChannelFields uniq_fields = (("spid", "project_id"),) + relation_fields = (("project_id", Project, True),) method_decorators = [session_or_token_required] def __init__(self):